219 lines
11 KiB
Python
219 lines
11 KiB
Python
import copy
|
||
import warnings
|
||
import logging
|
||
from typing import List, Tuple, Optional, Callable
|
||
|
||
import torch
|
||
from torch import nn
|
||
from transformers.utils import logging
|
||
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
||
|
||
from .modeling_chatglm import ChatGLMForConditionalGeneration, InvalidScoreLogitsProcessor
|
||
from .characterglm_generation_utils import CharacterGLMGenerationUtils, SessionMeta
|
||
|
||
|
||
# Module logger.  `logging` at this point is transformers' logging utility module
# (the `from transformers.utils import logging` import rebinds the name).
logger = logging.get_logger(__name__)

# Decoding defaults for CharacterGLM.  chat() and stream_chat() copy each entry into
# their generation kwargs only when that key is not already present (from the method
# signature or the caller's **kwargs).
default_generation_config = dict(
    do_sample=True,
    top_k=100,
    top_p=0.9,
    no_repeat_ngram_size=0,
    temperature=0.9,
    num_beams=1,
    length_penalty=1.6,
    repetition_penalty=1.3,
    # NOTE(review): presumably the tokenizer's newline token id, so a reply ends at the
    # end of its line -- confirm against the vocabulary.
    eos_token_id=13,
)
|
||
|
||
|
||
class CharacterGLMForConditionalGeneration(ChatGLMForConditionalGeneration):
    """CharacterGLM flavour of ``ChatGLMForConditionalGeneration``.

    CharacterGLM uses a different prompt format from ChatGLM.  This class reuses
    ``ChatGLMForConditionalGeneration.forward`` unchanged, re-implements
    ``build_inputs`` and ``build_stream_inputs``, and changes the signatures of
    ``chat`` and ``stream_chat`` to take an extra ``session_meta`` argument, with
    different default decoding parameters.
    """

    def build_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                     history: Optional[List[Tuple[str, str]]] = None):
        """Tokenize the full CharacterGLM prompt for ``query`` plus the earlier ``history``.

        The ChatGLM-style ``(query, response)`` history is converted to CharacterGLM's
        history format, rendered into one prompt string by ``CharacterGLMGenerationUtils``,
        tokenized as a batch of one and moved to ``self.device``.

        Returns:
            The tokenizer's batch encoding (PyTorch tensors) on ``self.device``.
        """
        character_glm_history = CharacterGLMGenerationUtils.convert_chatglm_history_to_characterglm_history(
            query, history or []
        )
        prompt = CharacterGLMGenerationUtils.build_inputs(session_meta, character_glm_history)
        inputs = tokenizer([prompt], return_tensors="pt")
        return inputs.to(self.device)

    def build_stream_inputs(self, tokenizer, session_meta: SessionMeta, query: str,
                            history: Optional[List[Tuple[str, str]]] = None):
        """Tokenize only the newest turn, for continuing from cached ``past_key_values``.

        The turn is rendered as a newline, ``[user_name]``, the query (its own newlines
        flattened to spaces), another newline and ``[bot_name]``.  ``history`` is accepted
        for signature parity with ``build_inputs`` but is unused: earlier turns are
        expected to already be in the key/value cache.

        Returns:
            The tokenizer's batch encoding (PyTorch tensors) on ``self.device``.
        """
        prompt = "\n[{}]{}\n[{}]".format(
            session_meta['user_name'],
            query.replace('\n', ' '),
            session_meta['bot_name']
        )
        input_ids = tokenizer.encode(prompt, add_special_tokens=False)
        # Drop the first token of the encoding.  NOTE(review): presumably a prefix token the
        # tokenizer emits before the leading newline (same trick as ChatGLM2's
        # build_stream_inputs) -- confirm against the tokenizer.
        input_ids = input_ids[1:]
        inputs = tokenizer.batch_encode_plus([(input_ids, None)], return_tensors="pt", add_special_tokens=False)
        return inputs.to(self.device)

    @torch.inference_mode()
    def chat(self, tokenizer, session_meta: SessionMeta, query: str,
             history: Optional[List[Tuple[str, str]]] = None, max_length: int = 8192, num_beams=1,
             do_sample=True, top_p=0.9, temperature=0.9, repetition_penalty=1.6, logits_processor=None, **kwargs):
        """Generate one complete reply to ``query`` in the session described by ``session_meta``.

        Args:
            tokenizer: tokenizer used to encode the prompt and decode the reply.
            session_meta: session metadata handed to the prompt builder.
            query: the user's new message.
            history: earlier ``(query, response)`` turns; the list passed in is not modified.
            max_length, num_beams, do_sample, top_p, temperature, repetition_penalty:
                forwarded to ``generate``.
            logits_processor: optional extra logits processors.  A copy is made and an
                ``InvalidScoreLogitsProcessor`` is appended to the copy.
            **kwargs: further ``generate`` arguments.  Any key present neither here nor in
                the explicit arguments is filled from ``default_generation_config``.

        Returns:
            ``(response, history)`` where ``history`` is a new list ending in ``(query, response)``.
        """
        if history is None:
            history = []
        # Copy so the caller's list does not gain another InvalidScoreLogitsProcessor per call.
        logits_processor = (LogitsProcessorList() if logits_processor is None
                            else LogitsProcessorList(logits_processor))
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor,
                      "repetition_penalty": repetition_penalty, **kwargs}
        # Module-level defaults only fill keys that were not supplied explicitly.
        gen_kwargs.update({k: v for k, v in default_generation_config.items() if k not in gen_kwargs})
        inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        outputs = self.generate(**inputs, **gen_kwargs)
        # Keep only the newly generated tokens of the single batch row.
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

    @torch.inference_mode()
    def stream_chat(self, tokenizer, session_meta: SessionMeta, query: str,
                    history: Optional[List[Tuple[str, str]]] = None, past_key_values=None,
                    max_length: int = 8192, do_sample=True, top_p=0.9, temperature=0.9, repetition_penalty=1.0,
                    logits_processor=None, return_past_key_values=False, **kwargs):
        """Stream a reply to ``query``, yielding the partial response as tokens are generated.

        Without ``past_key_values`` the full prompt (session, history and query) is encoded.
        With it, only the new turn is encoded (``build_stream_inputs``), and the position ids
        and attention mask are shifted to continue after the cached tokens.

        Yields:
            ``(response, new_history)``, or ``(response, new_history, past_key_values)`` when
            ``return_past_key_values`` is true, where ``new_history`` is
            ``history + [(query, response)]``.  Steps whose decoded text ends in an incomplete
            character are skipped.
        """
        if history is None:
            history = []
        # Copy so the caller's list does not gain another InvalidScoreLogitsProcessor per call.
        logits_processor = (LogitsProcessorList() if logits_processor is None
                            else LogitsProcessorList(logits_processor))
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor,
                      "repetition_penalty": repetition_penalty, **kwargs}
        gen_kwargs.update({k: v for k, v in default_generation_config.items() if k not in gen_kwargs})
        # NOTE(review): repetition_penalty is deliberately removed here, so neither the
        # `repetition_penalty` argument nor the module default affects streaming; whatever
        # generation_config stream_generate uses decides it.  Confirm this is intended.
        gen_kwargs.pop('repetition_penalty', None)

        if past_key_values is None:
            inputs = self.build_inputs(tokenizer, session_meta, query, history=history)
        else:
            inputs = self.build_stream_inputs(tokenizer, session_meta, query, history=history)
            # Number of tokens already in the cache.  NOTE(review): assumes the cache is laid
            # out sequence-first, as in ChatGLM2/3's own stream_chat -- confirm.
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                # Prefix-tuning slots are cached too but are not part of the text positions.
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            # Let the new tokens attend to every cached position.
            attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
            inputs['attention_mask'] = attention_mask

        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            return_past_key_values=return_past_key_values, **gen_kwargs):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
            response = tokenizer.decode(outputs)
            # A trailing U+FFFD replacement character means the newest token only covered
            # part of a multi-byte character; wait for the following tokens to complete it.
            if response and response[-1] != "\ufffd":
                response = self.process_response(response)
                new_history = history + [(query, response)]
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        """Sample (or greedily pick) one token per step, yielding the growing sequence each step.

        Args:
            input_ids: prompt token ids, batch-first.
            generation_config: decoding settings; defaults to ``self.generation_config``.
                A deep copy is updated with ``kwargs``, so the original is never modified.
            logits_processor, stopping_criteria, prefix_allowed_tokens_fn: extra hooks merged
                with the ones derived from the generation config.
            return_past_key_values: also yield the model's key/value cache after each step.
            **kwargs: generation-config fields; anything the config does not recognise is
                passed to the model as model kwargs (e.g. ``attention_mask``).

        Yields:
            ``input_ids`` (prompt plus all tokens generated so far) after every step, or
            ``(input_ids, past_key_values)`` when ``return_past_key_values`` is true.
            Stops once every row has produced an EOS id or a stopping criterion fires.
        """
        input_ids_seq_length = input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        # Work on a copy so per-call overrides never leak into the shared config.
        generation_config = copy.deepcopy(generation_config)
        # update() applies recognised fields and returns the leftovers, used as model kwargs.
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        eos_token_id = generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                # logger.warning, not the deprecated logger.warn; no extra positional argument,
                # which logging would try to %-format into a message with no placeholders.
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        # 1 while a row is still generating, 0 once it has emitted an EOS id.
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        # Scores are not accumulated; stopping criteria receive None.
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            if eos_token_id is not None:
                # A row is still unfinished only if its new token differs from *every* EOS id.
                # (Summing the per-id comparisons would never reach 0 with several EOS ids.)
                is_not_eos = torch.ones_like(next_tokens, dtype=torch.bool)
                for eos_id in eos_token_id:
                    is_not_eos &= next_tokens != eos_id
                unfinished_sequences = unfinished_sequences.mul(is_not_eos.long())
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break