fix
This commit is contained in:
parent
468a56e91f
commit
8ce406d019
|
@ -261,7 +261,7 @@ def process_batch(
|
|||
for message in conv:
|
||||
message = process_message(message)
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
new_loss_masks = [loss_mask_val] * len(new_input_ids)
|
||||
input_ids += new_input_ids
|
||||
loss_masks += new_loss_masks
|
||||
|
@ -300,7 +300,7 @@ def process_batch_eval(
|
|||
break
|
||||
else:
|
||||
message = process_message(message)
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
if message['role'] == 'assistant':
|
||||
output_prompt, output_ids = (
|
||||
new_input_ids[:1],
|
||||
|
|
Loading…
Reference in New Issue