CharacterGLM-6B_a1411420922.../modeling_characterglm.py

219 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import warnings
import logging
from typing import List, Tuple, Optional, Callable
import torch
from torch import nn
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from .modeling_chatglm import ChatGLMForConditionalGeneration, InvalidScoreLogitsProcessor
from .characterglm_generation_utils import CharacterGLMGenerationUtils, SessionMeta
logger = logging.get_logger(__name__)
# Decoding defaults. `chat` and `stream_chat` copy each entry into their
# generation kwargs only when the caller has not already supplied that key.
default_generation_config = {
    "do_sample": True,
    "top_k": 100,
    "top_p": 0.9,
    "no_repeat_ngram_size": 0,
    "temperature": 0.9,
    "num_beams": 1,
    "length_penalty": 1.6,
    "repetition_penalty": 1.3,
    "eos_token_id": 13,
}
class CharacterGLMForConditionalGeneration(ChatGLMForConditionalGeneration):
    """ChatGLM model adapted to the CharacterGLM prompt format.

    The CharacterGLM prompt format differs from ChatGLM's. This class reuses
    the ``forward`` method of ``ChatGLMForConditionalGeneration``,
    re-implements ``build_inputs`` and ``build_stream_inputs``, and adjusts
    the signatures of ``chat`` and ``stream_chat`` to take a ``session_meta``
    argument and to use different default decoding parameters.
    """

    def build_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                     history: Optional[List[Tuple[str, str]]] = None):
        """Build tokenized model inputs for the whole conversation.

        Args:
            tokenizer: Callable invoked as ``tokenizer([prompt], return_tensors="pt")``.
            session_meta: Session description forwarded to
                ``CharacterGLMGenerationUtils.build_inputs``.
            query: The new user utterance.
            history: Previous ``(query, response)`` pairs; ``None`` is treated
                as an empty history.

        Returns:
            The tokenizer output moved to ``self.device``.
        """
        turns = CharacterGLMGenerationUtils.convert_chatglm_history_to_characterglm_history(
            query, history or []
        )
        prompt = CharacterGLMGenerationUtils.build_inputs(session_meta, turns)
        encoded = tokenizer([prompt], return_tensors="pt")
        return encoded.to(self.device)

    def build_stream_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                            history: Optional[List[Tuple[str, str]]] = None):
        """Build tokenized inputs for only the newest turn.

        Used by ``stream_chat`` when a ``past_key_values`` cache is supplied,
        so just the new user turn and the bot-name prefix are encoded.

        Args:
            tokenizer: Object providing ``encode`` and ``batch_encode_plus``.
            session_meta: Mapping read for the ``'user_name'`` and
                ``'bot_name'`` keys.
            query: The new user utterance; newlines are replaced by spaces.
            history: Accepted for signature parity with ``build_inputs``;
                not used here.

        Returns:
            The ``batch_encode_plus`` output moved to ``self.device``.
        """
        prompt = "\n[{}]{}\n[{}]".format(
            session_meta['user_name'],
            query.replace('\n', ' '),
            session_meta['bot_name'],
        )
        token_ids = tokenizer.encode(prompt, add_special_tokens=False)
        # The first encoded id is discarded before batching.
        # NOTE(review): presumably a prefix token the tokenizer emits for the
        # leading "\n" -- confirm against the tokenizer.
        token_ids = token_ids[1:]
        encoded = tokenizer.batch_encode_plus(
            [(token_ids, None)], return_tensors="pt", add_special_tokens=False
        )
        return encoded.to(self.device)

    @torch.inference_mode()
    def chat(self, tokenizer, session_meta: SessionMeta, query: str,
             history: Optional[List[Tuple[str, str]]] = None, max_length: int = 8192, num_beams=1,
             do_sample=True, top_p=0.9, temperature=0.9, repetition_penalty=1.6, logits_processor=None,
             **kwargs):
        """Generate one complete reply to ``query``.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode the reply.
            session_meta: Session description passed to ``build_inputs``.
            query: The new user utterance.
            history: Previous ``(query, response)`` pairs; ``None`` means empty.
            max_length, num_beams, do_sample, top_p, temperature,
            repetition_penalty: Forwarded to ``self.generate``.
            logits_processor: Optional processors. They are copied into a new
                ``LogitsProcessorList`` (the caller's list is not modified)
                and an ``InvalidScoreLogitsProcessor`` is appended.
            **kwargs: Extra generation options. Explicit arguments and
                ``kwargs`` take precedence over ``default_generation_config``.

        Returns:
            ``(response, history)`` where ``history`` is a new list equal to
            the given history with ``(query, response)`` appended.
        """
        if history is None:
            history = []
        # Copy first: appending to the caller's list would grow it by one
        # InvalidScoreLogitsProcessor on every call.
        logits_processor = LogitsProcessorList(logits_processor or [])
        logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            "repetition_penalty": repetition_penalty,
            **kwargs,
        }
        # Fill in module defaults only for options not already present.
        for option, value in default_generation_config.items():
            gen_kwargs.setdefault(option, value)

        inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        generated = self.generate(**inputs, **gen_kwargs)
        # Strip the prompt tokens so only the newly generated ids are decoded.
        prompt_length = len(inputs["input_ids"][0])
        reply_ids = generated.tolist()[0][prompt_length:]
        response = self.process_response(tokenizer.decode(reply_ids))
        return response, history + [(query, response)]

    @torch.inference_mode()
    def stream_chat(self, tokenizer, session_meta: SessionMeta, query: str,
                    history: Optional[List[Tuple[str, str]]] = None, past_key_values=None,
                    max_length: int = 8192, do_sample=True, top_p=0.9, temperature=0.9,
                    repetition_penalty=1.0, logits_processor=None,
                    return_past_key_values=False, **kwargs):
        """Generate a reply incrementally, yielding the partial text so far.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode output.
            session_meta: Session description used to build the prompt.
            query: The new user utterance.
            history: Previous ``(query, response)`` pairs; ``None`` means empty.
            past_key_values: Optional cache from an earlier call. When given,
                only the newest turn is encoded (``build_stream_inputs``) and
                position ids / attention mask are extended by the cached
                length.
            max_length, do_sample, top_p, temperature: Forwarded to
                ``stream_generate``.
            repetition_penalty: Accepted, but the ``repetition_penalty`` entry
                is removed from the generation options before generating, so
                this value currently has no effect.
            logits_processor: Optional processors. They are copied into a new
                ``LogitsProcessorList`` (the caller's list is not modified)
                and an ``InvalidScoreLogitsProcessor`` is appended.
            return_past_key_values: When true, each yielded tuple also carries
                the current cache.
            **kwargs: Extra generation options; they take precedence over
                ``default_generation_config``.

        Yields:
            ``(response, new_history)`` or, when ``return_past_key_values`` is
            true, ``(response, new_history, past_key_values)``.
        """
        if history is None:
            history = []
        # Copy first so the caller's processor list is left untouched.
        logits_processor = LogitsProcessorList(logits_processor or [])
        logits_processor.append(InvalidScoreLogitsProcessor())

        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            "repetition_penalty": repetition_penalty,
            **kwargs,
        }
        for option, value in default_generation_config.items():
            gen_kwargs.setdefault(option, value)
        # Existing behaviour, kept as is: no repetition penalty is passed on
        # to stream_generate, whatever the caller or the defaults specify.
        gen_kwargs.pop('repetition_penalty', None)

        if past_key_values is None:
            inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        else:
            inputs = self.build_stream_inputs(tokenizer, session_meta, query, history=history)
            # assumes dim 0 of the cached tensor is the cached sequence
            # length -- TODO confirm against modeling_chatglm.
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            # Shift the new tokens' positions past the cached ones and extend
            # the attention mask with ones covering the cached positions.
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask

        prompt_length = len(inputs["input_ids"][0])
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            return_past_key_values=return_past_key_values,
                                            **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            reply_ids = outputs.tolist()[0][prompt_length:]
            response = tokenizer.decode(reply_ids)
            # Skip this step while the decoded text ends in U+FFFD
            # (REPLACEMENT CHARACTER); presumably the tail is an incomplete
            # multi-byte character that a later token will complete.
            if response and response[-1] != "\ufffd":
                response = self.process_response(response)
                new_history = history + [(query, response)]
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        """Generate tokens one step at a time, yielding after every step.

        Args:
            input_ids: Prompt token ids; the last dimension is the sequence.
            generation_config: Base configuration; defaults to
                ``self.generation_config``. A deep copy is updated with
                ``kwargs``, so the original object is not modified.
            logits_processor: Extra logits processors, or ``None``.
            stopping_criteria: Extra stopping criteria, or ``None``.
            prefix_allowed_tokens_fn: Forwarded to ``_get_logits_processor``.
            return_past_key_values: When true, yield the cache alongside the
                ids.
            **kwargs: Generation-config overrides; entries the config does not
                consume are used as model keyword arguments.

        Yields:
            The full ``input_ids`` tensor (prompt plus generated tokens) after
            each step, or ``(input_ids, outputs.past_key_values)`` when
            ``return_past_key_values`` is true.
        """
        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache

        # Normalise the EOS setting to a list of ids; None means "no EOS id".
        eos_token_id = generation_config.eos_token_id
        if eos_token_id is None:
            eos_token_id = []
        elif isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # Logger methods treat extra positional arguments as %-format
                # arguments, so no warning category is passed here.
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        # One flag per batch row: 1 while the row is still generating.
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # A row is finished once its new token equals ANY of the EOS ids.
            # (Summing the per-id "not equal" masks is non-zero whenever there
            # is more than one EOS id, so rows would never be marked done.)
            hit_eos = torch.zeros_like(next_tokens, dtype=torch.bool)
            for eos in eos_token_id:
                hit_eos |= next_tokens == eos
            unfinished_sequences = unfinished_sequences.mul((~hit_eos).long())

            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break