add del label
This commit is contained in:
parent
52cd14ce8c
commit
81ba7e087c
|
@ -77,6 +77,7 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
|
|||
with torch.no_grad():
|
||||
if self.args.predict_with_generate:
|
||||
output_ids = inputs.pop("output_ids", None)
|
||||
del inputs["labels"]
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
|
|
Loading…
Reference in New Issue