75 lines
3.3 KiB
Python
75 lines
3.3 KiB
Python
|
import os
|
||
|
import torch
|
||
|
from typing import Union, Dict, Any
|
||
|
from modelscope.pipelines.builder import PIPELINES
|
||
|
from modelscope.models.builder import MODELS
|
||
|
from modelscope.utils.constant import Tasks
|
||
|
from modelscope.pipelines.base import Pipeline
|
||
|
from modelscope.outputs import OutputKeys
|
||
|
from modelscope.pipelines.nlp.text_generation_pipeline import TextGenerationPipeline
|
||
|
from modelscope.models.base import Model, TorchModel
|
||
|
from modelscope.utils.logger import get_logger
|
||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||
|
from transformers.generation.utils import GenerationConfig
|
||
|
|
||
|
@PIPELINES.register_module(Tasks.text_generation, module_name='Baichuan2-7B-chatbot-pipe')
class Baichuan7BChatTextGenerationPipeline(TextGenerationPipeline):
    """Text-generation pipeline registered as ``Baichuan2-7B-chatbot-pipe``.

    Pre- and post-processing are pass-throughs; all the work happens in
    :meth:`forward`, which tokenizes the input, calls ``generate`` on the
    wrapped causal LM and decodes the first returned sequence.
    """

    def __init__(
            self,
            model: Union[Model, str],
            *args,
            **kwargs):
        """Create the pipeline.

        Args:
            model: Either an already constructed model object or a string,
                in which case a ``Baichuan7BChatTextGeneration`` is built
                from it.
            *args: Accepted for signature compatibility; not forwarded
                (forwarding positionally would collide with ``model=``).
            **kwargs: Forwarded to ``TextGenerationPipeline.__init__``.
        """
        if isinstance(model, str):
            model = Baichuan7BChatTextGeneration(model)
        self.model = model
        # Hand the constructed instance (not the original string) to the
        # parent so the weights are not resolved and loaded a second time.
        super().__init__(model=model, **kwargs)

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        """Return ``inputs`` unchanged; tokenization is done in ``forward``."""
        return inputs

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route every call-time parameter to the forward step.

        Returns:
            A 3-tuple ``({}, pipeline_parameters, {})`` -- presumably the
            (preprocess, forward, postprocess) parameter dicts expected by
            the base pipeline; confirm against the ModelScope ``Pipeline``
            contract.
        """
        return {}, pipeline_parameters, {}

    # define the forward pass
    def forward(self, inputs: Dict, **forward_params) -> Dict[str, Any]:
        """Generate a completion for ``inputs``.

        Args:
            inputs: Passed as-is to the tokenizer.
                NOTE(review): annotated ``Dict`` but handed straight to the
                tokenizer, so a prompt string is presumably expected.
            **forward_params: Forwarded to ``generate``.

        Returns:
            ``{'text': <decoded first sequence, special tokens skipped>}``.
        """
        wrapper = self.model
        device = wrapper.model.device
        encoded = wrapper.tokenizer(inputs, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        pred = wrapper.model.generate(input_ids, **forward_params)
        # Only the first sequence of the returned batch is decoded.
        out = wrapper.tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
        return {'text': out}

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        """Return the forward output unchanged."""
        return input
|
||
|
|
||
|
|
||
|
@MODELS.register_module(Tasks.text_generation, module_name='Baichuan2-7B-Chat')
class Baichuan7BChatTextGeneration(TorchModel):
    """Model wrapper registered as ``Baichuan2-7B-Chat``.

    Holds a tokenizer and a causal LM loaded from ``model_dir`` and exposes
    a chat-style :meth:`forward` plus a plain-prompt :meth:`infer`.
    """

    def __init__(self, model_dir=None, *args, **kwargs):
        """Load tokenizer, model weights and generation config.

        Args:
            model_dir: Location passed to every ``from_pretrained`` call.
            *args: Forwarded to ``TorchModel.__init__``.
            **kwargs: Forwarded to ``TorchModel.__init__``.
        """
        super().__init__(model_dir, *args, **kwargs)
        self.logger = get_logger()
        # loading tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir, trust_remote_code=True)
        # Weights are loaded in float16 with automatic device placement.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True)
        self.model.generation_config = GenerationConfig.from_pretrained(model_dir)
        self.model = self.model.eval()

    def forward(self, input: list, *args, **kwargs) -> Dict[str, Any]:
        """Run one chat turn.

        Args:
            input: The conversation so far. It is copied and appended to,
                so it must be list-like; entries are presumably
                ``{'role': ..., 'content': ...}`` dicts like the one added
                here -- confirm against the model's ``chat`` method.
            *args: Forwarded to ``self.model.chat``.
            **kwargs: Forwarded to ``self.model.chat``.

        Returns:
            A dict with ``OutputKeys.RESPONSE`` (the value returned by
            ``chat``) and ``OutputKeys.HISTORY`` (a shallow copy of
            ``input`` with the assistant reply appended).
        """
        response = self.model.chat(self.tokenizer, input, *args, **kwargs)
        # Copy first so the caller's list is not mutated.
        history = input.copy()
        history.append({'role': 'assistant', 'content': response})
        return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}

    def quantize(self, bits: int):
        """Replace the wrapped model with ``self.model.quantize(bits)``.

        Returns:
            ``self``, so calls can be chained.
        """
        self.model = self.model.quantize(bits)
        return self

    def infer(self, input, **kwargs):
        """Generate a completion for ``input`` and return it as text.

        Args:
            input: Passed as-is to the tokenizer.
            **kwargs: Forwarded to ``self.model.generate``.

        Returns:
            The decoded first generated sequence, special tokens skipped.
        """
        device = self.model.device
        encoded = self.tokenizer(input, return_tensors="pt")
        input_ids = encoded.input_ids.to(device)
        pred = self.model.generate(input_ids, **kwargs)
        # Only the first sequence of the returned batch is decoded.
        return self.tokenizer.decode(pred.cpu()[0], skip_special_tokens=True)
|