update for transformers 4.46

zR 2024-10-29 01:40:11 +08:00
parent c2c28bc45c
commit 94776fb841
3 changed files with 28 additions and 36 deletions

View File

@ -4,19 +4,18 @@ Here is an example of making batch requests with glm-4-9b;
you build the conversation format yourself and then call the batch function to send the requests.
Please note that memory consumption in this demo is significantly higher than for single-request inference.
Note:
Using this demo with glm-4-9b-chat-hf requires `transformers>=4.46.0`.
"""
from typing import Optional, Union
from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
from typing import Union
from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM
MODEL_PATH = 'THUDM/glm-4-9b-chat'
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
encode_special_tokens=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval()
def process_model_outputs(inputs, outputs, tokenizer):
responses = []
@ -36,11 +35,17 @@ def batch(
do_sample: bool = True,
top_p: float = 0.8,
temperature: float = 0.8,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
logits_processor=None,
):
if logits_processor is None:
logits_processor = LogitsProcessorList()
messages = [messages] if isinstance(messages, str) else messages
batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
max_length=max_input_tokens).to(model.device)
batched_inputs = tokenizer(
messages,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_input_tokens).to(model.device)
gen_kwargs = {
"max_new_tokens": max_new_tokens,

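For reference, a minimal sketch of how the updated demo's batch() helper might be driven, assuming it accepts the model, the tokenizer, and a list of already-templated prompt strings and returns decoded replies (the full signature is not shown in this hunk); the conversations and generation settings below are illustrative placeholders.

# Illustrative only: build templated prompts with apply_chat_template and
# pass them to the batch() helper defined in this file. The call assumes
# batch(model, tokenizer, prompts, ...) returns decoded reply strings.
conversations = [
    [{"role": "user", "content": "Hello, who are you?"}],
    [{"role": "user", "content": "Write a haiku about autumn."}],
]
prompts = [
    tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
    for conv in conversations
]
responses = batch(model, tokenizer, prompts, max_new_tokens=256)
for response in responses:
    print(response)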
View File

@ -10,26 +10,21 @@ Note: The script includes a modification to handle markdown to plain text conversion,
ensuring that the CLI interface displays formatted text correctly.
If you use flash attention, you should install flash-attn and add attn_implementation="flash_attention_2" when loading the model.
Note:
Using this demo with glm-4-9b-chat-hf requires `transformers>=4.46.0`.
"""
import os
import torch
from threading import Thread
from transformers import (
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
TextIteratorStreamer,
GlmForCausalLM
)
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat-hf')
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = GlmForCausalLM.from_pretrained(
MODEL_PATH,
# attn_implementation="flash_attention_2", # Use Flash Attention
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
# attn_implementation="flash_attention_2", # Use Flash Attention
torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16
device_map="auto").eval()

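A minimal sketch, not taken from the demo's chat loop, of how the imported TextIteratorStreamer and Thread are typically combined to stream tokens as they are generated; the prompt and max_new_tokens value are placeholders.

# Sketch only: run generate() on a background thread and print streamed text.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=128)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
for new_text in streamer:
    print(new_text, end="", flush=True)
thread.join()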
View File

@ -4,21 +4,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
MODEL_PATH = 'THUDM/glm-4-9b-chat'
MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
def stress_test(token_len, n, num_gpu):
device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
padding_side="left"
)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
trust_remote_code=True,
torch_dtype=torch.bfloat16
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device).eval()
# Use INT4 weight infer
# model = AutoModelForCausalLM.from_pretrained(
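As a hedged illustration of the INT4 path hinted at by the comment above, one possible approach uses transformers' BitsAndBytesConfig with the bitsandbytes package installed; this is a sketch, not necessarily the demo's exact commented-out code.

# Hypothetical 4-bit load; requires the bitsandbytes package. With
# quantization, device placement is handled by device_map rather than .to().
from transformers import BitsAndBytesConfig

model_int4 = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
).eval()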