Merge pull request #655 from zhipuch/main
correct compute_metrics function
This commit is contained in:
commit
c23abb0c59
|
@ -374,8 +374,8 @@ def compute_metrics(eval_preds: EvalPrediction, tokenizer):
|
||||||
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
||||||
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
||||||
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
||||||
pred_txt = tokenizer.decode(pred_ids).strip()
|
pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
|
||||||
label_txt = tokenizer.decode(label_ids).strip()
|
label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
|
||||||
pred_tokens = list(jieba.cut(pred_txt))
|
pred_tokens = list(jieba.cut(pred_txt))
|
||||||
label_tokens = list(jieba.cut(label_txt))
|
label_tokens = list(jieba.cut(label_txt))
|
||||||
rouge = Rouge()
|
rouge = Rouge()
|
||||||
|
|
|
@ -420,8 +420,8 @@ def compute_metrics(eval_preds: EvalPrediction, tokenizer):
|
||||||
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
batched_label_ids[batched_label_ids==-100] = tokenizer.pad_token_id
|
||||||
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
metrics_dct = {'rouge-1': [], 'rouge-2': [], 'rouge-l': [], 'bleu-4': []}
|
||||||
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
|
||||||
pred_txt = tokenizer.decode(pred_ids).strip()
|
pred_txt = tokenizer.decode(pred_ids, skip_special_tokens=True).strip()
|
||||||
label_txt = tokenizer.decode(label_ids).strip()
|
label_txt = tokenizer.decode(label_ids, skip_special_tokens=True).strip()
|
||||||
pred_tokens = list(jieba.cut(pred_txt))
|
pred_tokens = list(jieba.cut(pred_txt))
|
||||||
label_tokens = list(jieba.cut(label_txt))
|
label_tokens = list(jieba.cut(label_txt))
|
||||||
rouge = Rouge()
|
rouge = Rouge()
|
||||||
|
|
Loading…
Reference in New Issue