glm4/finetune_demo/inference.py

110 lines
3.8 KiB
Python
Raw Normal View History

2024-06-05 10:22:16 +08:00
from pathlib import Path
from typing import Annotated, Union
import typer
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast
)
# A loaded model is either a plain transformers model or a PEFT causal-LM wrapper.
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
# Union of the two tokenizer base classes imported from transformers.
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
# Typer CLI application object; the `main` command below is registered on it.
# pretty_exceptions_show_locals=False asks Typer not to print local variables
# in its formatted tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)
def load_model_and_tokenizer(
    model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    """Load a causal language model and its tokenizer from ``model_dir``.

    When ``model_dir`` contains an ``adapter_config.json`` file, the model is
    loaded through ``AutoPeftModelForCausalLM`` and the tokenizer is read from
    the base-model path recorded in the adapter's ``'default'`` PEFT config.
    Otherwise the model is loaded through ``AutoModelForCausalLM`` and the
    tokenizer is read from ``model_dir`` itself.

    Args:
        model_dir: Directory holding the checkpoint. ``~`` is expanded and the
            path is resolved to an absolute path before use.
        trust_remote_code: Forwarded unchanged to every ``from_pretrained``
            call made here (model and tokenizer).

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    resolved_dir = Path(model_dir).expanduser().resolve()
    has_adapter = (resolved_dir / 'adapter_config.json').exists()

    # Both loaders take the same arguments, so pick the class and call once.
    loader_cls = AutoPeftModelForCausalLM if has_adapter else AutoModelForCausalLM
    model = loader_cls.from_pretrained(
        resolved_dir, trust_remote_code=trust_remote_code, device_map='auto'
    )

    if has_adapter:
        # For an adapter checkpoint the tokenizer is taken from the base model
        # location stored in the adapter config, not from the adapter directory.
        tokenizer_source = model.peft_config['default'].base_model_name_or_path
    else:
        tokenizer_source = resolved_dir

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=trust_remote_code,
        encode_special_tokens=True,
        use_fast=False,
    )
    return model, tokenizer
def _build_demo_messages():
    """Return the fixed two-turn demo conversation used by ``main``.

    The first turn is a system message with empty content that carries a
    single tool definition (``create_calendar_event``); the second is a user
    request to create a calendar event.
    """
    event_properties = {
        "title": {
            "type": "string",
            "description": "The title of the event",
        },
        "start_time": {
            "type": "string",
            "description": "The start time of the event in the format YYYY-MM-DD HH:MM",
        },
        "end_time": {
            "type": "string",
            "description": "The end time of the event in the format YYYY-MM-DD HH:MM",
        },
    }
    calendar_tool = {
        "type": "function",
        "function": {
            "name": "create_calendar_event",
            "description": "Create a new calendar event",
            "parameters": {
                "type": "object",
                "properties": event_properties,
                "required": ["title", "start_time", "end_time"],
            },
        },
    }
    system_turn = {"role": "system", "content": "", "tools": [calendar_tool]}
    user_turn = {
        "role": "user",
        "content": "Can you help me create a calendar event for my meeting tomorrow? The title is \"Team Meeting\". It starts at 10:00 AM and ends at 11:00 AM.",
    }
    return [system_turn, user_turn]


# CLI command: load the model found at `model_dir`, run one sampled generation
# on the built-in demo conversation, and print the decoded reply to stdout.
# NOTE: documented with comments rather than a docstring on purpose -- Typer
# would surface a function docstring as the command's --help text.
@app.command()
def main(
    model_dir: Annotated[str, typer.Argument(help='')],
):
    messages = _build_demo_messages()
    model, tokenizer = load_model_and_tokenizer(model_dir)

    # Render the conversation through the tokenizer's chat template straight
    # to a tensor of token ids, then move it to the model's device.
    prompt_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    prompt_ids = prompt_ids.to(model.device)

    outputs = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.8,
        temperature=0.8,
        repetition_penalty=1.2,
        eos_token_id=model.config.eos_token_id,
    )

    # Decode only what comes after the prompt: slice off the first
    # len(prompt) positions of the returned sequence.
    prompt_length = len(prompt_ids[0])
    generated_ids = outputs[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    print("=========")
    print(response)
# Run the Typer application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    app()