update for transformers 4.46

zR 2024-10-29 01:40:11 +08:00
parent c2c28bc45c
commit 94776fb841
3 changed files with 28 additions and 36 deletions

View File

@@ -4,19 +4,18 @@ Here is an example of using batch request glm-4-9b,
 here you need to build the conversation format yourself and then call the batch function to make batch requests.
 Please note that in this demo, the memory consumption is significantly higher.
+Note:
+    Using with glm-4-9b-chat-hf will require `transformers>=4.46.0`.
 """
-from typing import Optional, Union
-from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
+from typing import Union
+from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_PATH,
-    trust_remote_code=True,
-    encode_special_tokens=True)
-model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval()
 def process_model_outputs(inputs, outputs, tokenizer):
     responses = []
@@ -36,11 +35,17 @@ def batch(
         do_sample: bool = True,
         top_p: float = 0.8,
         temperature: float = 0.8,
-        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
+        logits_processor=None,
 ):
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
     messages = [messages] if isinstance(messages, str) else messages
-    batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
-                               max_length=max_input_tokens).to(model.device)
+    batched_inputs = tokenizer(
+        messages,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=max_input_tokens).to(model.device)
     gen_kwargs = {
         "max_new_tokens": max_new_tokens,

View File

@@ -10,26 +10,21 @@ Note: The script includes a modification to handle markdown to plain text conversion,
 ensuring that the CLI interface displays formatted text correctly.
 If you use flash attention, you should install the flash-attn and add attn_implementation="flash_attention_2" in model loading.
+Note:
+    Using with glm-4-9b-chat-hf will require `transformers>=4.46.0`.
 """
-import os
 import torch
 from threading import Thread
-from transformers import (
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    TextIteratorStreamer,
-    GlmForCausalLM
-)
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
-MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-model = GlmForCausalLM.from_pretrained(
-    MODEL_PATH,
-    # attn_implementation="flash_attention_2", # Use Flash Attention
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH, # attn_implementation="flash_attention_2", # Use Flash Attention
     torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16
     device_map="auto").eval()

View File

@@ -4,22 +4,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import torch
 from threading import Thread
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
 def stress_test(token_len, n, num_gpu):
     device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
-    tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        padding_side="left"
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16
-    ).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device).eval()
     # Use INT4 weight infer
     # model = AutoModelForCausalLM.from_pretrained(
     #     MODEL_PATH,
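
The commented-out block above is truncated here, but it points at an INT4 inference variant of the same loading call. One common way to get 4-bit weights with current transformers is bitsandbytes quantization; the settings below are an assumption for illustration, not the commit's own code, and require the bitsandbytes package.

from transformers import BitsAndBytesConfig

# Hypothetical 4-bit configuration; tune quant type and compute dtype as needed.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store weights as 4-bit NF4
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bf16
)
model_int4 = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=quant_config,
    device_map="auto",                      # let accelerate place the quantized weights
).eval()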