commit e14f187090
zR 2024-12-09 15:54:29 +08:00
3 changed files with 15 additions and 12 deletions

View File

@@ -22,7 +22,6 @@ from PIL import Image
from io import BytesIO
from pathlib import Path
-DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16
@asynccontextmanager
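
The surviving TORCH_TYPE line encodes a hardware check worth spelling out: bfloat16 is only selected when a CUDA device reports compute capability 8 or higher (Ampere and newer); everything else, including CPU-only hosts, falls back to float16. A standalone sketch of the same selection:

import torch

# bfloat16 needs hardware support, which NVIDIA exposes as compute
# capability >= 8 (Ampere and newer); older GPUs and CPU-only hosts
# fall back to float16.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    TORCH_TYPE = torch.bfloat16
else:
    TORCH_TYPE = torch.float16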
@@ -283,6 +282,7 @@ def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[
@torch.inference_mode()
def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: dict):
uploaded = False
messages = params["messages"]
temperature = float(params.get("temperature", 1.0))
repetition_penalty = float(params.get("repetition_penalty", 1.0))
@@ -314,6 +314,7 @@ def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: di
return_tensors="pt",
return_dict=True
).to(next(model.parameters()).device)
input_echo_len = len(model_inputs["input_ids"][0])
streamer = TextIteratorStreamer(
tokenizer=tokenizer,
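
The streamer built here follows the usual TextIteratorStreamer pattern: model.generate runs on a background thread while the handler consumes decoded text incrementally, which is why generation_thread.join() appears before the final yield further down. A minimal sketch of the pattern, assuming model, tokenizer, and model_inputs from the surrounding function are in scope; the argument values not visible in this hunk are illustrative:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(
    tokenizer=tokenizer,
    timeout=60.0,              # illustrative; the real timeout is not shown in this hunk
    skip_prompt=True,          # do not re-emit the prompt tokens
    skip_special_tokens=True,  # drop eos/role markers from the streamed text
)
gen_kwargs = {**model_inputs, "streamer": streamer, "max_new_tokens": 1024}

generation_thread = Thread(target=model.generate, kwargs=gen_kwargs)
generation_thread.start()

generated_text = ""
for new_text in streamer:       # blocks until the next decoded chunk arrives
    generated_text += new_text  # the server yields partial results here
generation_thread.join()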
@@ -328,6 +329,7 @@ def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: di
"top_p": top_p if temperature > 1e-5 else 0,
"top_k": 1,
'streamer': streamer,
"eos_token_id": [151329, 151336, 151338],
}
if temperature > 1e-5:
gen_kwargs["temperature"] = temperature
@@ -354,7 +356,7 @@ def generate_stream_glm4v(model: AutoModel, tokenizer: AutoTokenizer, params: di
},
}
generation_thread.join()
print('\033[91m--generated_text\033[0m', generated_text)
yield {
"text": generated_text,
"usage": {
@ -391,7 +393,7 @@ if __name__ == "__main__":
trust_remote_code=True,
encode_special_tokens=True
)
-model.eval().to(DEVICE)
+model.eval()
else:
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
@ -403,6 +405,7 @@ if __name__ == "__main__":
torch_dtype=TORCH_TYPE,
trust_remote_code=True,
device_map="auto",
-).eval().to(DEVICE)
+).eval()
uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
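
The .to(DEVICE) removals in both loading paths follow from device_map="auto": accelerate already dispatches the weights across available devices at load time, so moving the dispatched model afterwards is at best a no-op and can break offloading, which is also why the DEVICE constant at the top of the file could be deleted. A minimal sketch of the corrected loading path; the checkpoint name is a placeholder, not taken from this diff:

import torch
from transformers import AutoModel, AutoTokenizer

MODEL_PATH = "THUDM/glm-4v-9b"  # placeholder; the server reads this from its config

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,  # stands in for the file's TORCH_TYPE selection
    trust_remote_code=True,
    device_map="auto",           # accelerate places the shards; no .to(DEVICE) afterwards
).eval()

# Inputs are routed to wherever the first parameters landed, as the handler does:
# model_inputs = model_inputs.to(next(model.parameters()).device)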

View File

@@ -374,8 +374,8 @@ def compute_metrics(eval_preds: EvalPrediction, tokenizer):
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
-pred_txt = tokenizer.decode(pred_ids).strip()
-label_txt = tokenizer.decode(label_ids).strip()
+pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
+label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
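
The change in both finetune scripts is the same one-liner: without skip_special_tokens=True, tokenizer.decode keeps the pad tokens that the -100 -> pad_token_id substitution just injected, plus any EOS markers, and those strings leak into the ROUGE/BLEU computation. A small demonstration of the difference, using gpt2 purely as a stand-in tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in, not the GLM-4 tokenizer
tok.pad_token = tok.eos_token                # gpt2 ships without a pad token

# Simulate a label row after the -100 -> pad_token_id replacement above.
ids = tok.encode("a short reference answer") + [tok.pad_token_id] * 3

print(tok.decode(ids))
# -> 'a short reference answer<|endoftext|><|endoftext|><|endoftext|>'
print(tok.decode(ids, skip_special_tokens=True))
# -> 'a short reference answer'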

View File

@@ -420,8 +420,8 @@ def compute_metrics(eval_preds: EvalPrediction, tokenizer):
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
-pred_txt = tokenizer.decode(pred_ids).strip()
-label_txt = tokenizer.decode(label_ids).strip()
+pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
+label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
pred_tokens = list(jieba.cut(pred_txt))
label_tokens = list(jieba.cut(label_txt))
rouge = Rouge()
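
For reference, the loop both scripts feed these decoded strings into: jieba segments the text into words so the n-gram metrics operate on Chinese words rather than raw characters, then per-pair ROUGE F-scores and BLEU-4 are accumulated. A self-contained sketch of scoring one pair; the Rouge import and the BLEU smoothing choice are assumptions, since neither appears in these hunks:

import jieba
from rouge import Rouge  # the repo may use rouge_chinese instead, which has the same interface
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def score_pair(pred_txt: str, label_txt: str) -> dict:
    # Word-level segmentation, matching the jieba.cut calls above.
    pred_tokens = list(jieba.cut(pred_txt))
    label_tokens = list(jieba.cut(label_txt))

    # Rouge expects space-joined token strings.
    scores = Rouge().get_scores(" ".join(pred_tokens), " ".join(label_tokens))[0]
    return {
        "rouge-1": scores["rouge-1"]["f"],
        "rouge-2": scores["rouge-2"]["f"],
        "rouge-l": scores["rouge-l"]["f"],
        # Smoothed sentence-level BLEU-4; the smoothing method is an assumption.
        "bleu-4": sentence_bleu(
            [label_tokens], pred_tokens,
            smoothing_function=SmoothingFunction().method3,
        ),
    }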