commit
5fe70a075b
|
@ -5,6 +5,7 @@ data_config:
|
||||||
num_proc: 1
|
num_proc: 1
|
||||||
|
|
||||||
combine: True
|
combine: True
|
||||||
|
freezeV: True
|
||||||
max_input_length: 512
|
max_input_length: 512
|
||||||
max_output_length: 512
|
max_output_length: 512
|
||||||
|
|
||||||
|
@ -45,4 +46,4 @@ peft_config:
|
||||||
r: 8
|
r: 8
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0.1
|
lora_dropout: 0.1
|
||||||
target_modules: ["query_key_value"]
|
target_modules: ["query_key_value"]
|
||||||
|
|
|
@ -5,6 +5,7 @@ data_config:
|
||||||
num_proc: 1
|
num_proc: 1
|
||||||
|
|
||||||
combine: True
|
combine: True
|
||||||
|
freezeV: True
|
||||||
max_input_length: 512
|
max_input_length: 512
|
||||||
max_output_length: 512
|
max_output_length: 512
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ data_config:
|
||||||
num_proc: 1
|
num_proc: 1
|
||||||
|
|
||||||
combine: True
|
combine: True
|
||||||
|
freezeV: True
|
||||||
max_input_length: 512
|
max_input_length: 512
|
||||||
max_output_length: 512
|
max_output_length: 512
|
||||||
|
|
||||||
|
|
|
@ -136,6 +136,7 @@ class FinetuningConfig(object):
|
||||||
max_input_length: int
|
max_input_length: int
|
||||||
max_output_length: int
|
max_output_length: int
|
||||||
combine: bool
|
combine: bool
|
||||||
|
freezeV: bool
|
||||||
|
|
||||||
training_args: Seq2SeqTrainingArguments = dc.field(
|
training_args: Seq2SeqTrainingArguments = dc.field(
|
||||||
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
|
||||||
|
@ -449,6 +450,10 @@ def main(
|
||||||
):
|
):
|
||||||
ft_config = FinetuningConfig.from_file(config_file)
|
ft_config = FinetuningConfig.from_file(config_file)
|
||||||
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
|
||||||
|
|
||||||
|
if ft_config.freezeV:
|
||||||
|
for param in model.transformer.vision.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
data_manager = DataManager(data_dir, ft_config.data_config)
|
data_manager = DataManager(data_dir, ft_config.data_config)
|
||||||
|
|
||||||
train_dataset = data_manager.get_dataset(
|
train_dataset = data_manager.get_dataset(
|
||||||
|
|
Loading…
Reference in New Issue