from pathlib import Path
from typing import Annotated, Union

import typer
from peft import PeftModelForCausalLM
from transformers import (
    AutoModel,
    AutoTokenizer,
)
import torch

app = typer.Typer(pretty_exceptions_show_locals=False)


def load_model_and_tokenizer(
    model_dir: Union[str, Path], trust_remote_code: bool = True
):
    model_dir = Path(model_dir).expanduser().resolve()
    if (model_dir / "adapter_config.json").exists():
        # PEFT/LoRA checkpoint: load the base model referenced in the adapter
        # config, then attach the adapter weights on top of it.
        import json

        with open(model_dir / "adapter_config.json", "r", encoding="utf-8") as file:
            config = json.load(file)
        model = AutoModel.from_pretrained(
            config.get("base_model_name_or_path"),
            trust_remote_code=trust_remote_code,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        model = PeftModelForCausalLM.from_pretrained(
            model=model,
            model_id=model_dir,
            trust_remote_code=trust_remote_code,
        )
        tokenizer_dir = model.peft_config["default"].base_model_name_or_path
    else:
        # Full (non-adapter) checkpoint: model and tokenizer live in the same directory.
        model = AutoModel.from_pretrained(
            model_dir,
            trust_remote_code=trust_remote_code,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        tokenizer_dir = model_dir
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        trust_remote_code=trust_remote_code,
        encode_special_tokens=True,
        use_fast=False,
    )
    return model, tokenizer
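# Note: the loader above handles both a full fine-tuned checkpoint and a
# PEFT/LoRA adapter directory, so it can also be reused outside this CLI,
# e.g. (hypothetical path): load_model_and_tokenizer("output/checkpoint-3000")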
@app.command()
def main(
    model_dir: Annotated[
        str, typer.Argument(help="Path to a fine-tuned checkpoint or PEFT adapter directory")
    ],
):
    # For GLM-4 Finetune Without Tools
    messages = [
        {
            "role": "user",
            "content": "#裙子#夏天",  # keyword prompt: "#skirt #summer"
        }
    ]

    # For GLM-4 Finetune With Tools
    # messages = [
    #     {
    #         "role": "system", "content": "",
    #         "tools":
    #             [
    #                 {
    #                     "type": "function",
    #                     "function": {
    #                         "name": "create_calendar_event",
    #                         "description": "Create a new calendar event",
    #                         "parameters": {
    #                             "type": "object",
    #                             "properties": {
    #                                 "title": {
    #                                     "type": "string",
    #                                     "description": "The title of the event"
    #                                 },
    #                                 "start_time": {
    #                                     "type": "string",
    #                                     "description": "The start time of the event in the format YYYY-MM-DD HH:MM"
    #                                 },
    #                                 "end_time": {
    #                                     "type": "string",
    #                                     "description": "The end time of the event in the format YYYY-MM-DD HH:MM"
    #                                 }
    #                             },
    #                             "required": [
    #                                 "title",
    #                                 "start_time",
    #                                 "end_time"
    #                             ]
    #                         }
    #                     }
    #                 }
    #             ]
    #
    #     },
    #     {
    #         "role": "user",
    #         "content": "Can you help me create a calendar event for my meeting tomorrow? The title is \"Team Meeting\". It starts at 10:00 AM and ends at 11:00 AM."
    #     },
    # ]

    # For GLM-4V Finetune (uncommenting this also requires `from PIL import Image`)
    # messages = [
    #     {
    #         "role": "user",
    #         "content": "女孩可能希望观众做什么?",  # "What might the girl want the audience to do?"
    #         "image": Image.open("your Image").convert("RGB")
    #     }
    # ]

    model, tokenizer = load_model_and_tokenizer(model_dir)
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    generate_kwargs = {
        "max_new_tokens": 1024,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "repetition_penalty": 1.2,
        "eos_token_id": model.config.eos_token_id,
    }
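    # The kwargs above use nucleus sampling (top_p) with a mild repetition penalty;
    # adjust them to taste for your fine-tuned checkpoint.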
    outputs = model.generate(**inputs, **generate_kwargs)
    # Decode only the newly generated tokens, skipping the prompt portion.
    response = tokenizer.decode(
        outputs[0][len(inputs["input_ids"][0]):], skip_special_tokens=True
    ).strip()
    print("=========")
    print(response)


if __name__ == "__main__":
    app()
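

# Example invocation (assuming this file is saved as inference.py; the
# checkpoint path is illustrative):
#   python inference.py output/checkpoint-3000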