finetune and vision demo update
This commit is contained in:
parent b562723d59
commit 7fcaeba6cc
@@ -217,7 +217,7 @@ with torch.no_grad():
 If you want a deeper understanding of the GLM-4-9B series of open-source models, this repository provides developers with basic GLM-4-9B usage and development code through the following content
-+ [base](basic_demo/README.md): Contains
++ [basic_demo](basic_demo/README.md): Contains
   + Interaction code using the transformers and vLLM backends
   + OpenAI API backend interaction code
   + Batch inference code
@@ -226,7 +226,7 @@ Note: GLM-4V-9B does not support calling using vLLM method yet.
 If you want to learn more about the GLM-4-9B series of open-source models, this repository provides developers
 with basic GLM-4-9B usage and development code through the following content
-+ [base](basic_demo/README.md): Contains
++ [basic_demo](basic_demo/README.md): Contains
   + Interaction code using the transformers and vLLM backends
   + OpenAI API backend interaction code
   + Batch inference code
@@ -43,8 +43,7 @@ tokenizer = AutoTokenizer.from_pretrained(
 model = AutoModel.from_pretrained(
     MODEL_PATH,
     trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.bfloat16).eval()
+    device_map="auto").eval()
 
 
 class StopOnTokens(StoppingCriteria):
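Note on the hunk above: with torch_dtype removed, from_pretrained falls back to the transformers default (float32) unless the checkpoint's custom code sets a dtype itself. If bfloat16 weights are still wanted without hardcoding them, torch_dtype="auto" picks up the dtype recorded in the checkpoint's config. A minimal sketch of that alternative; the model id is illustrative, the demo defines its own MODEL_PATH:

    # Sketch: load in the checkpoint's recorded dtype instead of hardcoding one.
    from transformers import AutoModel

    MODEL_PATH = "THUDM/glm-4-9b-chat"  # illustrative; the demo defines its own MODEL_PATH

    model = AutoModel.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype="auto",  # use torch_dtype from config.json rather than float32
    ).eval()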
@@ -17,7 +17,7 @@ from transformers import (
     AutoTokenizer,
     StoppingCriteria,
     StoppingCriteriaList,
-    TextIteratorStreamer, AutoModel
+    TextIteratorStreamer, AutoModel, BitsAndBytesConfig
 )
 
 from PIL import Image
@@ -33,8 +33,17 @@ model = AutoModel.from_pretrained(
     MODEL_PATH,
     trust_remote_code=True,
     device_map="auto",
-    torch_dtype=torch.bfloat16).eval()
+    torch_dtype=torch.bfloat16
+).eval()
+
+## For INT4 inference
+# model = AutoModel.from_pretrained(
+#     MODEL_PATH,
+#     trust_remote_code=True,
+#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+#     torch_dtype=torch.bfloat16,
+#     low_cpu_mem_usage=True
+# ).eval()
 
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
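Note on the INT4 path above: it requires the bitsandbytes and accelerate packages and a CUDA device. BitsAndBytesConfig also exposes NF4 options that usually degrade quality less than the plain 4-bit default; a hedged sketch, with the model id again illustrative rather than taken from the demo:

    # Sketch: 4-bit NF4 quantization with an explicit compute dtype.
    # Requires: pip install bitsandbytes accelerate
    import torch
    from transformers import AutoModel, BitsAndBytesConfig

    MODEL_PATH = "THUDM/glm-4v-9b"  # illustrative; the demo defines its own MODEL_PATH

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",              # NF4 generally preserves quality better than FP4
        bnb_4bit_compute_dtype=torch.bfloat16,  # run matmuls in bf16
        bnb_4bit_use_double_quant=True,         # also quantize the quantization constants
    )

    model = AutoModel.from_pretrained(
        MODEL_PATH,
        trust_remote_code=True,
        quantization_config=quant_config,
        low_cpu_mem_usage=True,
    ).eval()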
@@ -83,7 +92,7 @@ if __name__ == "__main__":
         tokenize=True,
         return_tensors="pt",
         return_dict=True
-    ).to(model.device)
+    ).to(next(model.parameters()).device)
     streamer = TextIteratorStreamer(
         tokenizer=tokenizer,
         timeout=60,
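Note on the .to(...) change above: with device_map="auto", accelerate may shard the model across several devices, in which case a single model.device is not well defined; next(model.parameters()).device resolves to the device of the first parameter, which is where the input embeddings, and hence the input ids, must live. A self-contained toy illustration of the pattern:

    # Toy stand-in for the demo's model; the device-lookup pattern is identical.
    import torch
    from torch import nn

    model = nn.Linear(4, 4).to("cuda" if torch.cuda.is_available() else "cpu")

    # Move inputs to wherever the first parameter landed; this also works
    # when accelerate has dispatched the real model across multiple GPUs.
    inputs = torch.randn(1, 4).to(next(model.parameters()).device)
    print(model(inputs).device)  # same device as the first parameter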
@@ -239,7 +239,7 @@ def process_batch(
         loss_masks = [False, False]
         for message in conv:
             message = process_message(message)
-            loss_mask_val = False if message['role'] in ('system', 'user') else True
+            loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
             new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
             new_loss_masks = [loss_mask_val] * len(new_input_ids)
             input_ids += new_input_ids
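Note on the loss-mask change above: adding 'observation' excludes tool-return messages from the loss, just like system and user turns. In causal-LM finetuning such masks are commonly materialized as labels set to -100 at masked positions, the ignore_index that cross-entropy skips; a minimal sketch (build_labels is an illustrative helper, not the repo's own):

    # Sketch: convert per-token loss masks into labels CrossEntropyLoss ignores.
    # -100 is the default ignore_index in both torch and transformers.
    def build_labels(input_ids: list[int], loss_masks: list[bool]) -> list[int]:
        return [tok if keep else -100 for tok, keep in zip(input_ids, loss_masks)]

    print(build_labels([101, 7, 8, 9], [False, False, True, True]))  # [-100, -100, 8, 9]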