diff --git a/finetune_demo/finetune_vision.py b/finetune_demo/finetune_vision.py index e88ef06..e13f317 100644 --- a/finetune_demo/finetune_vision.py +++ b/finetune_demo/finetune_vision.py @@ -77,6 +77,7 @@ class Seq2SeqTrainer(_Seq2SeqTrainer): with torch.no_grad(): if self.args.predict_with_generate: output_ids = inputs.pop("output_ids", None) + del inputs["labels"] loss, generated_tokens, labels = super().prediction_step( model=model, inputs=inputs,