remove peft in demo

This commit is contained in:
arkohut 2024-06-05 19:37:46 +08:00
parent a263b69376
commit 1dcb491479
2 changed files with 4 additions and 6 deletions

View File

@ -15,7 +15,6 @@ import torch
from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@ -27,7 +26,7 @@ from transformers import (
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')
@ -38,7 +37,7 @@ def load_model_and_tokenizer(
) -> tuple[ModelType, TokenizerType]:
model_dir = Path(model_dir).expanduser().resolve()
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto')
tokenizer_dir = model.peft_config['default'].base_model_name_or_path
else:

View File

@ -12,7 +12,6 @@ from threading import Thread
from typing import Union
from pathlib import Path
from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@ -24,7 +23,7 @@ from transformers import (
TextIteratorStreamer
)
ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
@ -40,7 +39,7 @@ def load_model_and_tokenizer(
) -> tuple[ModelType, TokenizerType]:
model_dir = _resolve_path(model_dir)
if (model_dir / 'adapter_config.json').exists():
model = AutoPeftModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, device_map='auto'
)
tokenizer_dir = model.peft_config['default'].base_model_name_or_path