diff --git a/finetune_demo/configs/lora.yaml b/finetune_demo/configs/lora.yaml
index 9b32278..1875424 100644
--- a/finetune_demo/configs/lora.yaml
+++ b/finetune_demo/configs/lora.yaml
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
@@ -45,4 +46,4 @@ peft_config:
   r: 8
   lora_alpha: 32
   lora_dropout: 0.1
-  target_modules: ["query_key_value"]
\ No newline at end of file
+  target_modules: ["query_key_value"]
diff --git a/finetune_demo/configs/ptuning_v2.yaml b/finetune_demo/configs/ptuning_v2.yaml
index 32fd460..3b7f284 100644
--- a/finetune_demo/configs/ptuning_v2.yaml
+++ b/finetune_demo/configs/ptuning_v2.yaml
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
diff --git a/finetune_demo/configs/sft.yaml b/finetune_demo/configs/sft.yaml
index 54e5a71..080594b 100644
--- a/finetune_demo/configs/sft.yaml
+++ b/finetune_demo/configs/sft.yaml
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
diff --git a/finetune_demo/finetune_vision.py b/finetune_demo/finetune_vision.py
index 56fec56..b86c91c 100644
--- a/finetune_demo/finetune_vision.py
+++ b/finetune_demo/finetune_vision.py
@@ -136,6 +136,7 @@ class FinetuningConfig(object):
     max_input_length: int
     max_output_length: int
     combine: bool
+    freezeV: bool
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
     )
@@ -449,6 +450,10 @@ def main(
 ):
     ft_config = FinetuningConfig.from_file(config_file)
     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
+
+    if ft_config.freezeV:
+        for param in model.transformer.vision.parameters():
+            param.requires_grad = False
     data_manager = DataManager(data_dir, ft_config.data_config)
     train_dataset = data_manager.get_dataset(