diff --git a/finetune_demo/finetune.py b/finetune_demo/finetune.py
index 29ab3c4..6a348bf 100644
--- a/finetune_demo/finetune.py
+++ b/finetune_demo/finetune.py
@@ -261,7 +261,7 @@ def process_batch(
         for message in conv:
             message = process_message(message)
             loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
-            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+            new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
             new_loss_masks = [loss_mask_val] * len(new_input_ids)
             input_ids += new_input_ids
             loss_masks += new_loss_masks
@@ -300,7 +300,7 @@ def process_batch_eval(
                 break
             else:
                 message = process_message(message)
-                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
+                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
                 if message['role'] == 'assistant':
                     output_prompt, output_ids = (
                         new_input_ids[:1],
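
Note on the fix (my reading; the diff itself does not state the rationale): with recent transformers releases, `apply_chat_template(..., tokenize=True, return_dict=False)` appears to return a batched result, a list holding one list of token ids, so the added `[0]` selects that single sequence before `[2:]` strips the two leading special tokens. A minimal sketch under that assumption; the model id and `trust_remote_code` flag are illustrative, not taken from this diff:

```python
from transformers import AutoTokenizer

# Assumption: the GLM-4 chat tokenizer used by finetune_demo.
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

message = {"role": "user", "content": "Hello"}

# Assumed shape with newer transformers: a list containing one list of
# token ids, e.g. [[<gmask>, <sop>, ...]], rather than a flat list.
out = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)

# [0] picks the single tokenized conversation; [2:] then drops the two
# leading special tokens, matching how process_batch seeds input_ids.
new_input_ids = out[0][2:]
print(new_input_ids)
```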