This commit is contained in:
ljqiff 2024-07-06 17:51:59 +08:00
parent 468a56e91f
commit 8ce406d019
1 changed file with 2 additions and 2 deletions

View File

@@ -261,7 +261,7 @@ def process_batch(
for message in conv: for message in conv:
message = process_message(message) message = process_message(message)
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:] new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
new_loss_masks = [loss_mask_val] * len(new_input_ids) new_loss_masks = [loss_mask_val] * len(new_input_ids)
input_ids += new_input_ids input_ids += new_input_ids
loss_masks += new_loss_masks loss_masks += new_loss_masks
@@ -300,7 +300,7 @@ def process_batch_eval(
break break
else: else:
message = process_message(message) message = process_message(message)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:] new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
if message['role'] == 'assistant': if message['role'] == 'assistant':
output_prompt, output_ids = ( output_prompt, output_ids = (
new_input_ids[:1], new_input_ids[:1],