update to support agent chat sft.
This commit is contained in:
parent
80e1b4cf9b
commit
c6f0629007
|
@ -283,7 +283,7 @@ def process_batch(
|
||||||
for message in conv:
|
for message in conv:
|
||||||
message = process_message(message)
|
message = process_message(message)
|
||||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||||
input_ids += new_input_ids
|
input_ids += new_input_ids
|
||||||
loss_masks += [loss_mask_val] * len(new_input_ids)
|
loss_masks += [loss_mask_val] * len(new_input_ids)
|
||||||
|
|
||||||
|
@ -337,7 +337,7 @@ def process_batch_eval(
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
message = process_message(message)
|
message = process_message(message)
|
||||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||||
if message['role'] == 'assistant':
|
if message['role'] == 'assistant':
|
||||||
output_prompt, output_ids = (
|
output_prompt, output_ids = (
|
||||||
new_input_ids[:1],
|
new_input_ids[:1],
|
||||||
|
|
Loading…
Reference in New Issue