adapt transformers==4.46
parent 0b37cf22e5
commit af1d4f2f11

@@ -47,3 +47,4 @@ peft_config:
   lora_alpha: 32
   lora_dropout: 0.1
   target_modules: ["query_key_value"]
+  #target_modules: ["q_proj", "k_proj", "v_proj"] if model is glm-4-9b-chat-hf
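Note: the commented alternative reflects that glm-4-9b-chat-hf (the HF-native checkpoint) splits attention into separate q_proj/k_proj/v_proj layers, while the remote-code GLM checkpoints use a single fused query_key_value layer. A minimal sketch of selecting target_modules accordingly (the helper function and the substring check are hypothetical, not part of this repo):

def lora_target_modules(model_dir: str) -> list:
    # glm-4-9b-chat-hf: HF-native implementation with split attention projections
    if "glm-4-9b-chat-hf" in model_dir:
        return ["q_proj", "k_proj", "v_proj"]
    # remote-code GLM checkpoints: fused QKV projection
    return ["query_key_value"]
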
@@ -399,7 +399,7 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False,
+            empty_init=False,  # if transformers>=4.46 and loading glm-4-9b-chat-hf, delete this line
             use_cache=False,
             torch_dtype=torch.bfloat16  # Must use BFloat 16
         )
@@ -409,7 +409,7 @@ def load_tokenizer_and_model(
         model = AutoModelForCausalLM.from_pretrained(
             model_dir,
             trust_remote_code=True,
-            empty_init=False,
+            empty_init=False,  # if transformers>=4.46 and loading glm-4-9b-chat-hf, delete this line
             use_cache=False,
             torch_dtype=torch.bfloat16
         )
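For reference, the manual edit documented by the new comment could also be expressed as a version gate. A minimal sketch, not what this commit does; the kwargs mirror the diff above, and the substring check on model_dir is an assumption:

import torch
import transformers
from packaging import version
from transformers import AutoModelForCausalLM

def load_model(model_dir: str):
    kwargs = dict(
        trust_remote_code=True,
        use_cache=False,
        torch_dtype=torch.bfloat16,  # Must use BFloat 16
    )
    # Pass empty_init only where it is still accepted: the HF-native
    # glm-4-9b-chat-hf under transformers>=4.46 rejects this kwarg.
    hf_native = "glm-4-9b-chat-hf" in model_dir  # heuristic (assumption)
    if not (hf_native and version.parse(transformers.__version__) >= version.parse("4.46")):
        kwargs["empty_init"] = False
    return AutoModelForCausalLM.from_pretrained(model_dir, **kwargs)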