remove peft in demo
parent a263b69376
commit 1dcb491479
@@ -15,7 +15,6 @@ import torch
 from threading import Thread
 from typing import Union
 from pathlib import Path
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -27,7 +26,7 @@ from transformers import (
     TextIteratorStreamer
 )
 
-ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
+ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')
@@ -38,7 +37,7 @@ def load_model_and_tokenizer(
 ) -> tuple[ModelType, TokenizerType]:
     model_dir = Path(model_dir).expanduser().resolve()
     if (model_dir / 'adapter_config.json').exists():
-        model = AutoPeftModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_dir, trust_remote_code=trust_remote_code, device_map='auto')
         tokenizer_dir = model.peft_config['default'].base_model_name_or_path
     else:
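For orientation, a minimal peft-free sketch of what the demo's loader could look like after this change. Only the changed hunks appear in the diff, so the function signature, the tokenizer handling, and the path-resolution detail below are reconstructed assumptions, not the commit's actual full file:

import os
from pathlib import Path
from typing import Union

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b')


def load_model_and_tokenizer(
        model_dir: Union[str, Path], trust_remote_code: bool = True
) -> tuple[ModelType, TokenizerType]:
    # Hypothetical reconstruction: resolve local checkpoint directories,
    # but let Hugging Face hub ids like the MODEL_PATH default pass through.
    path = Path(model_dir).expanduser()
    if path.exists():
        model_dir = str(path.resolve())
    # With peft removed, merged fine-tuned checkpoints and base checkpoints
    # both go through the same transformers entry point.
    model = AutoModelForCausalLM.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, device_map='auto')
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code)
    return model, tokenizer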
@@ -12,7 +12,6 @@ from threading import Thread
 
 from typing import Union
 from pathlib import Path
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -24,7 +23,7 @@ from transformers import (
     TextIteratorStreamer
 )
 
-ModelType = Union[PreTrainedModel, PeftModelForCausalLM]
+ModelType = Union[PreTrainedModel, AutoModelForCausalLM]
 TokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
 
 MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/glm-4-9b-chat')
@@ -40,7 +39,7 @@ def load_model_and_tokenizer(
 ) -> tuple[ModelType, TokenizerType]:
     model_dir = _resolve_path(model_dir)
     if (model_dir / 'adapter_config.json').exists():
-        model = AutoPeftModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_dir, trust_remote_code=trust_remote_code, device_map='auto'
         )
         tokenizer_dir = model.peft_config['default'].base_model_name_or_path
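A usage sketch for the streaming chat demo this second file belongs to, assuming the loader sketched above. Thread and TextIteratorStreamer mirror the imports visible in the hunks; the prompt and generation settings are placeholders:

from threading import Thread

from transformers import TextIteratorStreamer

# Assumes load_model_and_tokenizer from the sketch above;
# 'THUDM/glm-4-9b-chat' is this file's MODEL_PATH default per the hunk.
model, tokenizer = load_model_and_tokenizer('THUDM/glm-4-9b-chat')

inputs = tokenizer('Hello, GLM-4!', return_tensors='pt').to(model.device)
streamer = TextIteratorStreamer(
    tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs in a background thread while the main thread consumes
# decoded text chunks from the streamer as they become available.
thread = Thread(target=model.generate,
                kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for chunk in streamer:
    print(chunk, end='', flush=True)
thread.join()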