update to support agent chat sft.
This commit is contained in:
parent
80e1b4cf9b
commit
c6f0629007
|
@ -283,7 +283,7 @@ def process_batch(
|
|||
for message in conv:
|
||||
message = process_message(message)
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
input_ids += new_input_ids
|
||||
loss_masks += [loss_mask_val] * len(new_input_ids)
|
||||
|
||||
|
@ -337,7 +337,7 @@ def process_batch_eval(
|
|||
break
|
||||
else:
|
||||
message = process_message(message)
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
if message['role'] == 'assistant':
|
||||
output_prompt, output_ids = (
|
||||
new_input_ids[:1],
|
||||
|
|
Loading…
Reference in New Issue