remove peft in demo

This commit is contained in:
arkohut 2024-06-05 19:37:46 +08:00
parent a263b69376
commit 1dcb491479
2 changed files with 4 additions and 6 deletions

View File

@ -15,7 +15,6 @@ import torch
from threading import Thread from threading import Thread
from typing import Union from typing import Union
from pathlib import Path from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@ -27,7 +26,7 @@ from transformers import (
TextIteratorStreamer TextIteratorStreamer
) )
ModelType = Union[PreTrainedModel, PeftModelForCausalLM] ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b') MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')
@ -38,7 +37,7 @@ def load_model_and_tokenizer(
) -> tuple[ModelType, TokenizerType]: ) -> tuple[ModelType, TokenizerType]:
model_dir = Path(model_dir).expanduser().resolve() model_dir = Path(model_dir).expanduser().resolve()
if (model_dir / 'adapter_config.json').exists(): if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto') model_dir, trust_remote_code=trust_remote_code, device_map='auto')
tokenizer_dir = model.peft_config['default'].base_model_name_or_path tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else: else:

View File

@ -12,7 +12,6 @@ from threading import Thread
from typing import Union from typing import Union
from pathlib import Path from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import ( from transformers import (
AutoModelForCausalLM, AutoModelForCausalLM,
AutoTokenizer, AutoTokenizer,
@ -24,7 +23,7 @@ from transformers import (
TextIteratorStreamer TextIteratorStreamer
) )
ModelType = Union[PreTrainedModel, PeftModelForCausalLM] ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat') MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
@ -40,7 +39,7 @@ def load_model_and_tokenizer(
) -> tuple[ModelType, TokenizerType]: ) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir) model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists(): if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto' model_dir, trust_remote_code=trust_remote_code, device_map='auto'
) )
tokenizer_dir = model.peft_config['default'].base_model_name_or_path tokenizer_dir = model.peft_config['default'].base_model_name_or_path