glm4/finetune_demo/inference.py

143 lines
4.6 KiB
Python
Raw Permalink Normal View History

2024-06-05 10:22:16 +08:00
from pathlib import Path
from typing import Annotated, Union
import typer
2024-07-01 17:00:28 +08:00
from peft import PeftModelForCausalLM
2024-06-05 10:22:16 +08:00
from transformers import (
2024-07-01 17:00:28 +08:00
AutoModel,
2024-06-05 10:22:16 +08:00
AutoTokenizer,
)
2024-07-01 17:00:28 +08:00
import torch
2024-06-05 10:22:16 +08:00
# CLI application object; commands are registered on it via @app.command().
# Local variables are kept out of Typer's pretty exception output.
app = typer.Typer(pretty_exceptions_show_locals=False)
def load_model_and_tokenizer(
    model_dir: Union[str, Path], trust_remote_code: bool = True
):
    """Load a model and its tokenizer from ``model_dir``.

    Two directory layouts are handled:

    * ``model_dir`` contains ``adapter_config.json``: the base model named by
      that file's ``base_model_name_or_path`` field is loaded first, the
      adapter weights in ``model_dir`` are applied on top of it, and the
      tokenizer is loaded from the base model location.
    * otherwise: both the model and the tokenizer are loaded directly from
      ``model_dir``.

    Args:
        model_dir: Directory of the checkpoint or adapter. ``~`` is expanded
            and the path is made absolute before use.
        trust_remote_code: Forwarded unchanged to every ``from_pretrained``
            call made here.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``adapter_config.json`` exists but does not provide a
            non-empty ``base_model_name_or_path``.
    """
    model_dir = Path(model_dir).expanduser().resolve()
    adapter_config_path = model_dir / "adapter_config.json"

    if adapter_config_path.exists():
        import json

        with open(adapter_config_path, "r", encoding="utf-8") as file:
            config = json.load(file)

        # Fail early with a readable message instead of handing None to
        # AutoModel.from_pretrained.
        base_model_path = config.get("base_model_name_or_path")
        if not base_model_path:
            raise ValueError(
                f"{adapter_config_path} does not define a non-empty "
                "'base_model_name_or_path'; cannot locate the base model."
            )

        model = AutoModel.from_pretrained(
            base_model_path,
            trust_remote_code=trust_remote_code,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        model = PeftModelForCausalLM.from_pretrained(
            model=model,
            model_id=model_dir,
            trust_remote_code=trust_remote_code,
        )
        # The tokenizer is taken from the base model location recorded in the
        # adapter's config, not from the adapter directory.
        tokenizer_dir = model.peft_config["default"].base_model_name_or_path
    else:
        model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=trust_remote_code,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        tokenizer_dir = model_dir

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        trust_remote_code=trust_remote_code,
        encode_special_tokens=True,
        use_fast=False,
    )
    return model, tokenizer
@app.command()
def main(
    model_dir: Annotated[
        str,
        typer.Argument(
            help="Directory of a model checkpoint, or of a PEFT adapter "
            "(a directory containing adapter_config.json)."
        ),
    ],
):
    """Generate one sampled reply from a fine-tuned model and print it.

    The prompt is the hard-coded ``messages`` list in this function; edit it,
    or enable one of the commented-out alternatives, to try other inputs.
    """
    # For GLM-4 Finetune Without Tools
    messages = [
        {
            "role": "user",
            "content": "#裙子#夏天",
        }
    ]

    # For GLM-4 Finetune With Tools
    # messages = [
    #     {
    #         "role": "system", "content": "",
    #         "tools":
    #             [
    #                 {
    #                     "type": "function",
    #                     "function": {
    #                         "name": "create_calendar_event",
    #                         "description": "Create a new calendar event",
    #                         "parameters": {
    #                             "type": "object",
    #                             "properties": {
    #                                 "title": {
    #                                     "type": "string",
    #                                     "description": "The title of the event"
    #                                 },
    #                                 "start_time": {
    #                                     "type": "string",
    #                                     "description": "The start time of the event in the format YYYY-MM-DD HH:MM"
    #                                 },
    #                                 "end_time": {
    #                                     "type": "string",
    #                                     "description": "The end time of the event in the format YYYY-MM-DD HH:MM"
    #                                 }
    #                             },
    #                             "required": [
    #                                 "title",
    #                                 "start_time",
    #                                 "end_time"
    #                             ]
    #                         }
    #                     }
    #                 }
    #             ]
    #
    #     },
    #     {
    #         "role": "user",
    #         "content": "Can you help me create a calendar event for my meeting tomorrow? The title is \"Team Meeting\". It starts at 10:00 AM and ends at 11:00 AM."
    #     },
    # ]

    # For GLM-4V Finetune
    # NOTE(review): `Image` is not imported in the visible part of this file;
    # presumably PIL.Image -- add the import before enabling this example.
    # The sample prompt below reads "What might the girl want the audience to do?"
    # messages = [
    #     {
    #         "role": "user",
    #         "content": "女孩可能希望观众做什么?",
    #         "image": Image.open("your Image").convert("RGB")
    #     }
    # ]

    model, tokenizer = load_model_and_tokenizer(model_dir)

    # Render the chat into tensors and move them to the model's device.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    generate_kwargs = {
        "max_new_tokens": 1024,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }
    outputs = model.generate(**inputs, **generate_kwargs)

    # Skip the first len(prompt) tokens of the output so only the newly
    # generated continuation is decoded.
    prompt_length = len(inputs["input_ids"][0])
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    print("=========")
    print(response)


if __name__ == "__main__":
    app()