Merge pull request #550 from zhipuch/froze

Froze vision layers
This commit is contained in:
Yuxuan.Zhang 2024-09-07 01:00:29 +08:00 committed by GitHub
commit 5fe70a075b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 9 additions and 1 deletions

View File

@ -5,6 +5,7 @@ data_config:
num_proc: 1
combine: True
freezeV: True
max_input_length: 512
max_output_length: 512

View File

@ -5,6 +5,7 @@ data_config:
num_proc: 1
combine: True
freezeV: True
max_input_length: 512
max_output_length: 512

View File

@ -5,6 +5,7 @@ data_config:
num_proc: 1
combine: True
freezeV: True
max_input_length: 512
max_output_length: 512

View File

@ -136,6 +136,7 @@ class FinetuningConfig(object):
max_input_length: int
max_output_length: int
combine: bool
freezeV: bool
training_args: Seq2SeqTrainingArguments = dc.field(
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
@ -449,6 +450,10 @@ def main(
):
ft_config = FinetuningConfig.from_file(config_file)
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
if ft_config.freezeV:
for param in model.transformer.vision.parameters():
param.requires_grad = False
data_manager = DataManager(data_dir, ft_config.data_config)
train_dataset = data_manager.get_dataset(