add del label
This commit is contained in:
parent
52cd14ce8c
commit
81ba7e087c
|
@ -77,6 +77,7 @@ class Seq2SeqTrainer(_Seq2SeqTrainer):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.args.predict_with_generate:
|
if self.args.predict_with_generate:
|
||||||
output_ids = inputs.pop("output_ids", None)
|
output_ids = inputs.pop("output_ids", None)
|
||||||
|
del inputs["labels"]
|
||||||
loss, generated_tokens, labels = super().prediction_step(
|
loss, generated_tokens, labels = super().prediction_step(
|
||||||
model=model,
|
model=model,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
|
Loading…
Reference in New Issue