update for transformers 4.46
commit 94776fb841
parent c2c28bc45c
@@ -4,19 +4,18 @@ Here is an example of using batch request glm-4-9b,
 here you need to build the conversation format yourself and then call the batch function to make batch requests.
 Please note that in this demo, the memory consumption is significantly higher.
 
+Note:
+Using glm-4-9b-chat-hf requires `transformers>=4.46.0`.
 
 """
 
-from typing import Optional, Union
-from transformers import AutoModel, AutoTokenizer, LogitsProcessorList
+from typing import Union
+from transformers import AutoTokenizer, LogitsProcessorList, AutoModelForCausalLM
 
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
 
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_PATH,
-    trust_remote_code=True,
-    encode_special_tokens=True)
-model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()
 
+tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto").eval()
 
 def process_model_outputs(inputs, outputs, tokenizer):
     responses = []
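The note added in this hunk is the heart of the commit: the batch-request demo now loads the glm-4-9b-chat-hf checkpoint through the stock AutoModelForCausalLM class instead of the remote-code AutoModel path, which only works on transformers 4.46.0 or newer. A minimal, hypothetical pre-flight check, not part of this diff (importlib.metadata is in the standard library and packaging ships alongside transformers):

    # Hypothetical version guard -- fail early if the installed transformers
    # is too old for the glm-4-9b-chat-hf checkpoint used by this demo.
    from importlib.metadata import version
    from packaging.version import Version

    if Version(version("transformers")) < Version("4.46.0"):
        raise RuntimeError(
            "glm-4-9b-chat-hf requires transformers>=4.46.0; "
            "run: pip install -U 'transformers>=4.46.0'"
        )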
@@ -36,11 +35,17 @@ def batch(
         do_sample: bool = True,
         top_p: float = 0.8,
         temperature: float = 0.8,
-        logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
+        logits_processor=None,
 ):
+    if logits_processor is None:
+        logits_processor = LogitsProcessorList()
     messages = [messages] if isinstance(messages, str) else messages
-    batched_inputs = tokenizer(messages, return_tensors="pt", padding="max_length", truncation=True,
-                               max_length=max_input_tokens).to(model.device)
+    batched_inputs = tokenizer(
+        messages,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=max_input_tokens).to(model.device)
 
     gen_kwargs = {
         "max_new_tokens": max_new_tokens,
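As the docstring says, the caller builds the conversation format and then calls the batch function. A rough usage sketch with the new -hf tokenizer; the exact positional signature of batch() beyond the messages and the sampling keywords is assumed from context rather than shown in this hunk:

    # Hypothetical caller -- render chat-template prompts, then batch-generate.
    conversations = [
        [{"role": "user", "content": "Tell me a joke."}],
        [{"role": "user", "content": "What is the capital of France?"}],
    ]
    prompts = [
        tokenizer.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
        for conv in conversations
    ]
    # batch(model, tokenizer, messages, ...) is the helper defined in this demo file.
    responses = batch(model, tokenizer, prompts, do_sample=True, top_p=0.8, temperature=0.8)
    for r in responses:
        print(r)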
@@ -10,26 +10,21 @@ Note: The script includes a modification to handle markdown to plain text conversion,
 ensuring that the CLI interface displays formatted text correctly.
 
 If you use flash attention, you should install flash-attn and add attn_implementation="flash_attention_2" in model loading.
 
+Note:
+Using glm-4-9b-chat-hf requires `transformers>=4.46.0`.
 """
 
-import os
 import torch
 from threading import Thread
-from transformers import (
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    TextIteratorStreamer,
-    GlmForCausalLM
-)
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
 
-MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
 
-model = GlmForCausalLM.from_pretrained(
-    MODEL_PATH,
-    # attn_implementation="flash_attention_2", # Use Flash Attention
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_PATH,  # attn_implementation="flash_attention_2", # Use Flash Attention
     torch_dtype=torch.bfloat16, # using flash-attn must use bfloat16 or float16
     device_map="auto").eval()
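The CLI demo's docstring keeps the flash-attn hint, and the new loading call leaves attn_implementation commented out. A hedged sketch of the opt-in path (assumes flash-attn is installed and the GPU supports it; not part of this diff):

    # pip install flash-attn --no-build-isolation
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",  # opt in to FlashAttention-2 kernels
        torch_dtype=torch.bfloat16,               # flash-attn needs bfloat16 or float16
        device_map="auto").eval()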
@@ -4,22 +4,14 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 import torch
 from threading import Thread
 
-MODEL_PATH = 'THUDM/glm-4-9b-chat'
+MODEL_PATH = 'THUDM/glm-4-9b-chat-hf'
 
 
 def stress_test(token_len, n, num_gpu):
     device = torch.device(f"cuda:{num_gpu - 1}" if torch.cuda.is_available() and num_gpu > 0 else "cpu")
-    tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        padding_side="left"
-    )
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16
-    ).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="left")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.bfloat16).to(device).eval()
 
     # Use INT4 weight infer
     # model = AutoModelForCausalLM.from_pretrained(
     #     MODEL_PATH,
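The stress-test hunk ends inside the commented-out "Use INT4 weight infer" block, so the arguments it passes are not visible here. For reference only, a typical 4-bit load with the transformers bitsandbytes integration looks roughly like the following; BitsAndBytesConfig and the bitsandbytes package are assumptions, not something this diff shows:

    # Hypothetical INT4 loading sketch -- not reconstructed from the truncated block above.
    # Requires the bitsandbytes package and a CUDA device.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,                      # keep weights in 4-bit
        bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bfloat16
    )
    model_int4 = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        quantization_config=quant_config,
        device_map="auto").eval()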