diff --git a/basic_demo/trans_cli_demo.py b/basic_demo/trans_cli_demo.py
index fc3c33c..c07cf23 100644
--- a/basic_demo/trans_cli_demo.py
+++ b/basic_demo/trans_cli_demo.py
@@ -15,7 +15,6 @@ import torch
 from threading import Thread
 from typing import Union
 from pathlib import Path
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -27,7 +26,7 @@ from transformers import (
     TextIteratorStreamer
 )
 
-ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
+ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')
@@ -38,7 +37,7 @@ def load_model_and_tokenizer(
 ) -> tuple[ModelType, TokenizerType]:
     model_dir = Path(model_dir).expanduser().resolve()
     if (model_dir / 'adapter_config.json').exists():
-        model = AutoPeftModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_dir, trust_remote_code=trust_remote_code, device_map='auto')
         tokenizer_dir = model.peft_config['default'].base_model_name_or_path
     else:
diff --git a/basic_demo/trans_web_demo.py b/basic_demo/trans_web_demo.py
index a98af97..b9cb1a5 100644
--- a/basic_demo/trans_web_demo.py
+++ b/basic_demo/trans_web_demo.py
@@ -12,7 +12,6 @@
 from threading import Thread
 from typing import Union
 from pathlib import Path
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -24,7 +23,7 @@ from transformers import (
     TextIteratorStreamer
 )
 
-ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
+ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
@@ -40,7 +39,7 @@ def load_model_and_tokenizer(
 ) -> tuple[ModelType, TokenizerType]:
     model_dir = _resolve_path(model_dir)
     if (model_dir / 'adapter_config.json').exists():
-        model = AutoPeftModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_dir, trust_remote_code=trust_remote_code, device_map='auto'
         )
         tokenizer_dir = model.peft_config['default'].base_model_name_or_path