finetune and vision demo update

zR 2024-06-07 16:53:56 +08:00
parent b562723d59
commit 7fcaeba6cc
5 changed files with 16 additions and 8 deletions

View File

@@ -217,7 +217,7 @@ with torch.no_grad():
 If you want to learn more about the GLM-4-9B series of open-source models, this repository provides developers with basic GLM-4-9B usage and development code through the following content:
-+ [base](basic_demo/README.md): Contains
++ [basic_demo](basic_demo/README.md): Contains
   + Interaction code using the transformers and vLLM backends
   + OpenAI API backend interaction code
   + Batch inference code

View File

@@ -226,7 +226,7 @@ Note: GLM-4V-9B does not support vLLM inference yet.
 If you want to learn more about the GLM-4-9B series of open-source models, this repository provides developers with basic GLM-4-9B usage and development code through the following content:
-+ [base](basic_demo/README.md): Contains
++ [basic_demo](basic_demo/README.md): Contains
   + Interaction code using the transformers and vLLM backends
   + OpenAI API backend interaction code
   + Batch inference code

View File

@@ -43,8 +43,7 @@ tokenizer = AutoTokenizer.from_pretrained(
 model = AutoModel.from_pretrained(
     MODEL_PATH,
     trust_remote_code=True,
-    device_map="auto",
-    torch_dtype=torch.bfloat16).eval()
+    device_map="auto").eval()
 class StopOnTokens(StoppingCriteria):
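
A side effect of dropping the explicit torch_dtype: transformers then falls back to its default (float32) unless the checkpoint's remote loading code overrides it. A minimal sketch to confirm which dtype actually got loaded; the checkpoint id below is an assumed stand-in for MODEL_PATH:

# Sketch: check the dtype the demo's model loads in once the explicit
# torch_dtype=torch.bfloat16 argument is gone. The checkpoint id is an
# assumption for illustration, not taken from this commit.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "THUDM/glm-4-9b-chat",  # assumed MODEL_PATH
    trust_remote_code=True,
    device_map="auto",
).eval()

# Without torch_dtype, transformers defaults to float32 unless the
# model's own loading code overrides it; print to confirm.
print(next(model.parameters()).dtype)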

View File

@@ -17,7 +17,7 @@ from transformers import (
     AutoTokenizer,
     StoppingCriteria,
     StoppingCriteriaList,
-    TextIteratorStreamer, AutoModel
+    TextIteratorStreamer, AutoModel, BitsAndBytesConfig
 )
 from PIL import Image
@@ -33,8 +33,17 @@ model = AutoModel.from_pretrained(
     MODEL_PATH,
     trust_remote_code=True,
     device_map="auto",
-    torch_dtype=torch.bfloat16).eval()
+    torch_dtype=torch.bfloat16
+).eval()
+## For INT4 inference
+# model = AutoModel.from_pretrained(
+#     MODEL_PATH,
+#     trust_remote_code=True,
+#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
+#     torch_dtype=torch.bfloat16,
+#     low_cpu_mem_usage=True
+# ).eval()
 class StopOnTokens(StoppingCriteria):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
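
The newly imported BitsAndBytesConfig powers the commented-out INT4 path above. As a hedged sketch (not part of the commit), one way to expose that choice behind a flag; the flag name and checkpoint id are illustrative assumptions:

# Sketch: load GLM-4V-9B either in bf16 or quantized to 4 bit, assuming
# the bitsandbytes package is installed. The int4 flag and checkpoint
# id are illustrative, not from this commit.
import torch
from transformers import AutoModel, BitsAndBytesConfig

MODEL_PATH = "THUDM/glm-4v-9b"  # assumed checkpoint

def load_model(int4: bool = False):
    kwargs = {"trust_remote_code": True, "torch_dtype": torch.bfloat16}
    if int4:
        # Quantizing the linear layers to 4 bit cuts GPU memory to
        # roughly a quarter, at some cost in quality and speed.
        kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
        kwargs["low_cpu_mem_usage"] = True
    else:
        kwargs["device_map"] = "auto"
    return AutoModel.from_pretrained(MODEL_PATH, **kwargs).eval()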
@@ -83,7 +92,7 @@ if __name__ == "__main__":
         tokenize=True,
         return_tensors="pt",
         return_dict=True
-    ).to(model.device)
+    ).to(next(model.parameters()).device)
     streamer = TextIteratorStreamer(
         tokenizer=tokenizer,
         timeout=60,
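
The .to(...) change matters because device_map="auto" can shard the model across several devices, in which case there is no single model.device to move inputs to; the device of the first parameter (normally the input embedding) is where input tensors must land. The pattern in isolation, as a small illustrative helper:

# Sketch of the input-placement pattern the diff adopts: move every
# tensor in a tokenizer batch to the device of the model's first
# parameter, which is where a device_map="auto"-dispatched model
# expects its input_ids.
import torch

def to_model_device(batch: dict, model: torch.nn.Module) -> dict:
    device = next(model.parameters()).device
    return {k: v.to(device) if torch.is_tensor(v) else v
            for k, v in batch.items()}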

View File

@@ -239,7 +239,7 @@ def process_batch(
     loss_masks = [False, False]
     for message in conv:
         message = process_message(message)
-        loss_mask_val = False if message['role'] in ('system', 'user') else True
+        loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
         new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
         new_loss_masks = [loss_mask_val] * len(new_input_ids)
         input_ids += new_input_ids
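
The added 'observation' role covers tool results fed back into the conversation; like system and user turns, they are inputs the model should condition on, not outputs it should learn to emit, so their tokens are now excluded from the fine-tuning loss. A small sketch of the rule (the helper name is illustrative):

# Sketch of the masking rule after this commit: loss is computed only
# on turns the model itself generates; system, user, and observation
# (tool result) turns are masked out.
NO_LOSS_ROLES = ('system', 'user', 'observation')

def loss_mask_for(role: str, num_tokens: int) -> list:
    keep = role not in NO_LOSS_ROLES
    return [keep] * num_tokens

assert loss_mask_for('observation', 3) == [False, False, False]
assert loss_mask_for('assistant', 2) == [True, True]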