Update to support agent chat SFT.
This commit is contained in:
parent
23773d94e2
commit
80e1b4cf9b
|
@ -1,5 +1,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import json
|
||||
import jieba
|
||||
import dataclasses as dc
|
||||
import functools
|
||||
|
@ -243,6 +244,18 @@ def process_message(message):
|
|||
v is not None}
|
||||
elif 'tools' in message:
|
||||
del message['tools']
|
||||
|
||||
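# Convert agent-chat training data: an assistant message whose content is a dict string with "name" and "arguments" becomes content = JSON-encoded arguments and metadata = name.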
|
||||
if message['role'] == 'assistant':
|
||||
content = message['content']
|
||||
if isinstance(content, str) and content.startswith("{") and content.endswith("}"):
|
||||
try:
|
||||
content_ = eval(content)
|
||||
if isinstance(content_, dict) and "name" in content_ and "arguments" in content_:
|
||||
message['content'] = json.dumps(content_["arguments"], ensure_ascii=False)
|
||||
message['metadata'] = content_["name"]
|
||||
except:
|
||||
pass
|
||||
return message
|
||||
|
||||
|
||||
|
@ -270,7 +283,7 @@ def process_batch(
|
|||
for message in conv:
|
||||
message = process_message(message)
|
||||
loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
input_ids += new_input_ids
|
||||
loss_masks += [loss_mask_val] * len(new_input_ids)
|
||||
|
||||
|
@ -324,7 +337,7 @@ def process_batch_eval(
|
|||
break
|
||||
else:
|
||||
message = process_message(message)
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
|
||||
new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]
|
||||
if message['role'] == 'assistant':
|
||||
output_prompt, output_ids = (
|
||||
new_input_ids[:1],
|
||||
|
|
Loading…
Reference in New Issue