From 1dcb4914792819e2c765aaac202ff9aeaa5c53c0 Mon Sep 17 00:00:00 2001 From: arkohut <39525455+arkohut@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:37:46 +0800 Subject: [PATCH] remove peft in demo --- basic_demo/trans_cli_demo.py | 5 ++--- basic_demo/trans_web_demo.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/basic_demo/trans_cli_demo.py b/basic_demo/trans_cli_demo.py index fc3c33c..c07cf23 100644 --- a/basic_demo/trans_cli_demo.py +++ b/basic_demo/trans_cli_demo.py @@ -15,7 +15,6 @@ import torch from threading import Thread from typing import Union from pathlib import Path -from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -27,7 +26,7 @@ from transformers import ( TextIteratorStreamer ) -ModelType = Union[PreTrainedModel, PeftModelForCausalLM] +ModelType = Union[PreTrainedModel, AutoModelForCausalLM] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b') @@ -38,7 +37,7 @@ def load_model_and_tokenizer( ) -> tuple[ModelType, TokenizerType]: model_dir = Path(model_dir).expanduser().resolve() if (model_dir / 'adapter_config.json').exists(): - model = AutoPeftModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_dir, trust_remote_code=trust_remote_code, device_map='auto') tokenizer_dir = model.peft_config['default'].base_model_name_or_path else: diff --git a/basic_demo/trans_web_demo.py b/basic_demo/trans_web_demo.py index a98af97..b9cb1a5 100644 --- a/basic_demo/trans_web_demo.py +++ b/basic_demo/trans_web_demo.py @@ -12,7 +12,6 @@ from threading import Thread from typing import Union from pathlib import Path -from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -24,7 +23,7 @@ from transformers import ( TextIteratorStreamer ) -ModelType = Union[PreTrainedModel, PeftModelForCausalLM] +ModelType = Union[PreTrainedModel, AutoModelForCausalLM] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat') @@ -40,7 +39,7 @@ def load_model_and_tokenizer( ) -> tuple[ModelType, TokenizerType]: model_dir = _resolve_path(model_dir) if (model_dir / 'adapter_config.json').exists(): - model = AutoPeftModelForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_dir, trust_remote_code=trust_remote_code, device_map='auto' ) tokenizer_dir = model.peft_config['default'].base_model_name_or_path