Update to support agent chat SFT.

This commit is contained in:
wylilong 2024-09-26 20:05:18 +08:00
parent 23773d94e2
commit 80e1b4cf9b
1 changed files with 15 additions and 2 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import os
import json
import jieba
import dataclasses as dc
import functools
@ -243,6 +244,18 @@ def process_message(message):
v is not None}
elif 'tools' in message:
del message['tools']
# Convert training data for agent chat.
if message['role'] == 'assistant':
content = message['content']
if isinstance(content, str) and content.startswith("{") and content.endswith("}"):
try:
content_ = eval(content)
if isinstance(content_, dict) and "name" in content_ and "arguments" in content_:
message['content'] = json.dumps(content_["arguments"], ensure_ascii=False)
message['metadata'] = content_["name"]
except:
pass
return message
@ -270,7 +283,7 @@ def process_batch(
for message in conv:
message = process_message(message)
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
input_ids += new_input_ids
loss_masks += [loss_mask_val] * len(new_input_ids)
@ -324,7 +337,7 @@ def process_batch_eval(
break
else:
message = process_message(message)
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
if message['role'] == 'assistant':
output_prompt, output_ids = (
new_input_ids[:1],