commit 5fe70a075b
finetune_demo
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
@@ -45,4 +46,4 @@ peft_config:
   r: 8
   lora_alpha: 32
   lora_dropout: 0.1
-  target_modules: ["query_key_value"]
+  target_modules: ["query_key_value"]
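A peft_config block with these keys typically maps directly onto the arguments of peft's LoraConfig. A minimal sketch of that mapping (an illustration only; the demo's own loader may construct the object differently):

```python
from peft import LoraConfig

# Values taken from the peft_config hunk above.
lora_config = LoraConfig(
    r=8,                                 # LoRA rank
    lora_alpha=32,                       # scaling factor for the LoRA update
    lora_dropout=0.1,                    # dropout on the LoRA path
    target_modules=["query_key_value"],  # fused QKV projection in GLM blocks
)
```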
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
@@ -5,6 +5,7 @@ data_config:
   num_proc: 1
 
 combine: True
+freezeV: True
 max_input_length: 512
 max_output_length: 512
 
@@ -136,6 +136,7 @@ class FinetuningConfig(object):
     max_input_length: int
     max_output_length: int
     combine: bool
+    freezeV: bool
 
     training_args: Seq2SeqTrainingArguments = dc.field(
         default_factory=lambda: Seq2SeqTrainingArguments(output_dir='./output')
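The dataclass fields mirror the top-level YAML keys, so `freezeV: True` in a config file surfaces as a boolean attribute on the loaded config. A minimal sketch of that pattern, assuming PyYAML and ignoring the nested sections (data_config, peft_config, training_args) that the real FinetuningConfig also parses:

```python
import dataclasses as dc
import yaml

@dc.dataclass
class TinyConfig:  # hypothetical stand-in for FinetuningConfig
    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    @classmethod
    def from_file(cls, path: str) -> "TinyConfig":
        with open(path) as f:
            raw = yaml.safe_load(f)
        # Keep only the keys this dataclass declares; unknown keys are dropped.
        names = {f.name for f in dc.fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in names})
```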
@@ -449,6 +450,10 @@ def main(
 ):
     ft_config = FinetuningConfig.from_file(config_file)
     tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
+
+    if ft_config.freezeV:
+        for param in model.transformer.vision.parameters():
+            param.requires_grad = False
     data_manager = DataManager(data_dir, ft_config.data_config)
 
     train_dataset = data_manager.get_dataset(
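The effect of the new flag: when freezeV is set, every parameter of the model's vision tower is excluded from gradient updates, so LoRA finetuning touches only the language side. A standalone sketch of the same pattern (assuming a PyTorch model exposing a transformer.vision submodule, as in the diff above):

```python
import torch

def freeze_vision(model: torch.nn.Module) -> None:
    # Detach the vision encoder from training; its weights stay fixed.
    for param in model.transformer.vision.parameters():
        param.requires_grad = False

def trainable_params(model: torch.nn.Module) -> int:
    # Sanity check: after freeze_vision(), this count should drop by the
    # number of parameters in the vision tower.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```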