update for transformers 4.46
commit 94776fb841 (parent c2c28bc45c)
@@ -4,19 +4,18 @@ Here is an example of using batch request glm-4-9b,
 here you need to build the conversation format yourself and then call the batch function to make batch requests.
 Please note that in this demo, the memory consumption is significantly higher.

+Note:
+Using with glm-4-9b-chat-hf will require `transformers>=4.46.0`.

 """

-from typing import Optional, Union
-from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
+from typing import Union
+from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM

-MODEL_PATH = 'THUDM/glm-4-9b-chat'
-
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_PATH,
-    trust_remote_code=True,
-    encode_special_tokens=True)
-model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
+
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval()

 def process_model_outputs(inputs, outputs, tokenizer):
     responses = []
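The `transformers>=4.46.0` note above can be enforced with a small import-time guard; the snippet below is only an illustrative sketch, not code from this commit:

from packaging.version import parse

import transformers

# glm-4-9b-chat-hf relies on the GLM model code that shipped in transformers 4.46.0.
if parse(transformers.__version__) < parse("4.46.0"):
    raise RuntimeError(
        f"transformers>=4.46.0 is required for glm-4-9b-chat-hf, found {transformers.__version__}"
    )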
@@ -36,10 +35,16 @@ def batch(
         do_sample: bool = True,
         top_p: float = 0.8,
         temperature: float = 0.8,
-        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
+        logits_processor=None,
 ):
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
     messages = [messages] if isinstance(messages, str) else messages
-    batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
+    batched_inputs = tokenizer(
+        messages,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
         max_length=max_input_tokens).to(model.device)

     gen_kwargs = {
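The switch from a `LogitsProcessorList()` default to `logits_processor=None` above follows the standard fix for Python's shared-mutable-default pitfall. A self-contained sketch of the pattern (the `append_twice_*` functions are made up for illustration):

# Mutable default: the list is created once, at definition time, and shared
# by every call that falls back to the default.
def append_twice_bad(item, bucket=[]):
    bucket.append(item)
    bucket.append(item)
    return bucket

# None sentinel: a fresh list is created on each call instead.
def append_twice_good(item, bucket=None):
    if bucket is None:
        bucket = []
    bucket.append(item)
    bucket.append(item)
    return bucket

print(append_twice_bad(1))   # [1, 1]
print(append_twice_bad(2))   # [1, 1, 2, 2]  <- state leaked from the first call
print(append_twice_good(1))  # [1, 1]
print(append_twice_good(2))  # [2, 2]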
@@ -10,26 +10,21 @@ Note: The script includes a modification to handle markdown to plain text conver
 ensuring that the CLI interface displays formatted text correctly.

 If you use flash attention, you should install the flash-attn and add attn_implementation="flash_attention_2" in model loading.

+Note:
+Using with glm-4-9b-chat-hf will require `transformers>=4.46.0`.
 """

 import os
 import torch
 from threading import Thread
-from transformers import (
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    TextIteratorStreamer,
-    GlmForCausalLM
-)
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

-MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
+MODEL_PATH = "/share/home/zyx/Models/glm-4-9b-chat-hf"

 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

-model = GlmForCausalLM.from_pretrained(
-    MODEL_PATH,
-    # attn_implementation="flash_attention_2", # Use Flash Attention
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,  # attn_implementation="flash_attention_2", # Use Flash Attention
     torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16
     device_map="auto").eval()
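The docstring above leaves Flash Attention opt-in. A hedged sketch of enabling it only when the flash-attn package is importable (the availability check is illustrative and not part of this diff; "sdpa" is the stock PyTorch attention backend in recent transformers releases):

import importlib.util

import torch
from transformers import AutoModelForCausalLM

MODEL_PATH = "THUDM/glm-4-9b-chat-hf"  # assumed Hub id; substitute a local path if needed

# Use Flash Attention 2 only when the flash-attn package is installed;
# otherwise fall back to PyTorch's scaled-dot-product attention.
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    attn_implementation=attn_impl,
    torch_dtype=torch.bfloat16,  # flash-attn requires bfloat16 or float16
    device_map="auto",
).eval()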
@@ -4,21 +4,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
 import torch
 from threading import Thread

-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'


 def stress_test(token_len, n, num_gpu):
     device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
-    tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        padding_side="left"
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16
-    ).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device).eval()

     # Use INT4 weight infer
     # model = AutoModelForCausalLM.from_pretrained(
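The stress test keeps `padding_side="left"` because decoder-only models continue from the rightmost token, so batched prompts of different lengths must be padded on the left. A minimal batched-generation sketch under that assumption (model id and prompts are placeholders, not from this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "THUDM/glm-4-9b-chat-hf"  # placeholder, as in the script above

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # some checkpoints ship without a pad token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

# Prompts of different lengths end up left-padded, so generation continues
# from the real last token of each prompt rather than from padding.
prompts = ["Hello", "Write one sentence about the sea."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

with torch.inference_mode():
    out = model.generate(**inputs, max_new_tokens=32, do_sample=False)

# Drop the prompt tokens before decoding so only the newly generated text remains.
new_tokens = out[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))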