update to support agent chat sft.

This commit is contained in:
wylilong 2024-09-27 18:00:16 +08:00
parent 80e1b4cf9b
commit c6f0629007
1 changed files with 2 additions and 2 deletions

View File

@ -283,7 +283,7 @@ def process_batch(
for message in conv:
message = process_message(message)
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
input_ids += new_input_ids
loss_masks += [loss_mask_val] * len(new_input_ids)
@ -337,7 +337,7 @@ def process_batch_eval(
break
else:
message = process_message(message)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],