Fix OpenAI stream function

zR 2024-06-15 20:59:23 +08:00
parent 920425c9fe
commit 5c4bf6201c
4 changed files with 73 additions and 37 deletions


@@ -5,6 +5,8 @@ import uvicorn
 import gc
 import json
 import torch
+import random
+import string

 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from fastapi import FastAPI, HTTPException, Response
@@ -40,8 +42,13 @@ app.add_middleware(
 )

+
+def generate_id(prefix: str, k=29) -> str:
+    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
+    return f"{prefix}{suffix}"
+

 class ModelCard(BaseModel):
-    id: str
+    id: str = ""
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "owner"
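
Note: the new helper gives responses OpenAI-style identifiers instead of empty strings. A minimal standalone sketch of what it produces (output values are illustrative):

```python
# Standalone sketch of the new generate_id helper; printed values are illustrative.
import random
import string

def generate_id(prefix: str, k=29) -> str:
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"

print(generate_id('chatcmpl-', 29))  # e.g. 'chatcmpl-' + 29 random alphanumerics
print(generate_id('fp_', 9))         # e.g. 'fp_' + 9 random alphanumerics
```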
@@ -60,7 +67,7 @@ class FunctionCall(BaseModel):
     arguments: str


-class FunctionCallResponse(BaseModel):
+class ChoiceDeltaToolCallFunction(BaseModel):
     name: Optional[str] = None
     arguments: Optional[str] = None
@@ -72,22 +79,24 @@ class UsageInfo(BaseModel):


 class ChatCompletionMessageToolCall(BaseModel):
-    id: str
+    index: Optional[int] = 0
+    id: Optional[str] = None
     function: FunctionCall
-    type: Literal["function"]
+    type: Optional[Literal["function"]] = 'function'


 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class DeltaMessage(BaseModel):
     role: Optional[Literal["user", "assistant", "system"]] = None
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class ChatCompletionResponseChoice(BaseModel):
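
Note: with this schema change, streamed tool calls travel in `tool_calls` (each entry carrying an `index`, as in OpenAI's chunk format) rather than in the legacy `function_call` field. A minimal sketch of building one such delta under the models above (the tool name and argument fragment are hypothetical):

```python
# Sketch: one streaming delta carrying a tool-call fragment.
# "get_weather" and its argument fragment are hypothetical examples.
tool_call = ChatCompletionMessageToolCall(
    index=0,
    function=FunctionCall(name="get_weather", arguments='{"city": '),
    type="function",
)
delta = DeltaMessage(content=None, role=None, function_call=None, tool_calls=[tool_call])
print(delta.model_dump_json(exclude_unset=True))
```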
@@ -104,10 +113,11 @@ class ChatCompletionResponseStreamChoice(BaseModel):


 class ChatCompletionResponse(BaseModel):
     model: str
-    id: str
+    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
     usage: Optional[UsageInfo] = None
@@ -153,7 +163,8 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
     if is_tool_call and use_tool:
         content = {
             "name": function_name,
-            "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
+            "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                    ensure_ascii=False)
         }
         if function_name == "simple_browser":
             search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
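
For reference, the `simple_browser` pattern above matches calls of the form `search("...", recency_days=N)`; a quick check (the query string is illustrative):

```python
import re

search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
m = search_pattern.match('search("glm-4 release notes", recency_days=7)')
print(m.groups())  # ('glm-4 release notes', '7')
```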
@@ -172,9 +183,6 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
     return output.strip()


-
-
-
 @torch.inference_mode()
 async def generate_stream_glm4(params):
     messages = params["messages"]
@@ -220,6 +228,7 @@ async def generate_stream_glm4(params):
             "finish_reason": output.outputs[0].finish_reason,
         }
         yield ret
+
     gc.collect()
     torch.cuda.empty_cache()
@@ -317,7 +326,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
     return processed_messages


-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -335,7 +343,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")

-
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
@@ -364,12 +371,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 logger.warning("Failed to parse tool call")

         if isinstance(function_call, dict):
-            function_call = FunctionCallResponse(**function_call)
+            function_call = ChoiceDeltaToolCallFunction(**function_call)
             generate = parse_output_text(request.model, output, function_call=function_call)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
             return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-
     response = ""
     async for response in generate_stream_glm4(gen_params):
         pass
@@ -390,7 +396,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if isinstance(function_call, dict):
         finish_reason = "tool_calls"
-        function_call_response = FunctionCallResponse(**function_call)
+        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
         function_call_instance = FunctionCall(
             name=function_call_response.name,
             arguments=function_call_response.arguments
@@ -421,7 +427,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
     return ChatCompletionResponse(
         model=request.model,
-        id="",
         choices=[choice_data],
         object="chat.completion",
         usage=usage
     )
@@ -433,6 +438,9 @@ async def predict_stream(model_id, gen_params):
     is_function_call = False
     has_send_first_chunk = False
     function_name = None
+    created_time = int(time.time())
+    response_id = generate_id('chatcmpl-', 29)
+    system_fingerprint = generate_id('fp_', 9)
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
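
Note: vLLM yields the cumulative text for a request, so `delta_text` is the suffix beyond what was already sent; and `created_time`, `response_id`, and `system_fingerprint` are now fixed once per stream so every chunk shares them, matching OpenAI behaviour. A tiny sketch of the delta computation:

```python
# Minimal sketch of the cumulative-to-delta logic used above.
output = ""
for decoded_unicode in ["He", "Hello", "Hello, world"]:
    delta_text = decoded_unicode[len(output):]
    output = decoded_unicode
    print(repr(delta_text))  # 'He', then 'llo', then ', world'
```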
@@ -446,10 +454,16 @@ async def predict_stream(model_id, gen_params):
         if is_function_call:
             for char in delta_text:
                 function_call = {"name": function_name, "arguments": char}
+                tool_call = ChatCompletionMessageToolCall(
+                    index=0,
+                    function=FunctionCall(**function_call),
+                    type="function"
+                )
                 message = DeltaMessage(
                     content=None,
-                    role="assistant",
-                    function_call=function_call
+                    role=None,
+                    function_call=None,
+                    tool_calls=[tool_call]
                 )
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
@@ -458,9 +472,10 @@ async def predict_stream(model_id, gen_params):
                 )
                 chunk = ChatCompletionResponse(
                     model=model_id,
-                    id="",
+                    id=response_id,
                     choices=[choice_data],
-                    created=int(time.time()),
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
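
On the wire, each of these tool-call chunks now serializes roughly as follows (all values illustrative; `exclude_unset=True` drops fields that were never assigned, such as the tool call's `id`):

```python
# Illustrative shape of one streamed tool-call chunk as sent over SSE.
{
    "model": "glm-4",
    "id": "chatcmpl-<29 random chars>",
    "object": "chat.completion.chunk",
    "choices": [{
        "index": 0,
        "delta": {
            "content": None,
            "role": None,
            "function_call": None,
            "tool_calls": [{
                "index": 0,
                "function": {"name": "get_weather", "arguments": "{"},
                "type": "function",
            }],
        },
        "finish_reason": None,
    }],
    "created": 1718456363,
    "system_fingerprint": "fp_<9 random chars>",
}
```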
@@ -480,9 +495,10 @@ async def predict_stream(model_id, gen_params):
             )
             chunk = ChatCompletionResponse(
                 model=model_id,
-                id="",
+                id=response_id,
                 choices=[choice_data],
-                created=int(time.time()),
+                created=created_time,
+                system_fingerprint=system_fingerprint,
                 object="chat.completion.chunk"
             )
             yield chunk.model_dump_json(exclude_unset=True)
@@ -501,20 +517,36 @@ async def predict_stream(model_id, gen_params):
     )
     chunk = ChatCompletionResponse(
         model=model_id,
-        id="",
+        id=response_id,
         choices=[choice_data],
-        created=int(time.time()),
+        created=created_time,
+        system_fingerprint=system_fingerprint,
         object="chat.completion.chunk"
     )
     yield chunk.model_dump_json(exclude_unset=True)

     if is_function_call:
-        yield json.dumps({"text": output})
-    else:
-        yield '[DONE]'
+        yield ChatCompletionResponse(
+            model=model_id,
+            id=response_id,
+            system_fingerprint=system_fingerprint,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(
+                        content=None,
+                        role=None,
+                        function_call=None,
+                    ),
+                    finish_reason="tool_calls"
+                )],
+            created=created_time,
+            object="chat.completion.chunk",
+            usage=None
+        ).model_dump_json(exclude_unset=True)


-async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
+async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
     delta = DeltaMessage(role="assistant", content=value)
     if function_call is not None:
         delta.function_call = function_call
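
Note: the tool-call path now ends the stream with a proper OpenAI-style terminal chunk whose `finish_reason` is `"tool_calls"`, instead of leaking a raw `{"text": ...}` payload to the client; `predict_stream` no longer emits `[DONE]` itself in that branch. The terminal chunk looks roughly like this (values illustrative):

```python
# Illustrative terminal chunk for the tool-call path; the id and fingerprint
# match the earlier chunks of the same stream.
{
    "model": "glm-4",
    "id": "chatcmpl-<same id as the earlier chunks>",
    "system_fingerprint": "fp_<same fingerprint>",
    "choices": [{
        "index": 0,
        "delta": {"content": None, "role": None, "function_call": None},
        "finish_reason": "tool_calls",
    }],
    "created": 1718456363,
    "object": "chat.completion.chunk",
    "usage": None,
}
```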
@@ -524,9 +556,13 @@ async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
         delta=delta,
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk"
+    )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
+    yield '[DONE]'


 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
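
With these changes, a stock OpenAI client can consume the stream, including tool-call deltas. A minimal client-side sketch, assuming the server is reachable at localhost:8000 under a `/v1` prefix and registered as model `glm-4` (base URL, API key, and model name are all assumptions):

```python
# Hypothetical client-side check; base_url, api_key, and model name are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    # Print only content deltas; tool-call deltas arrive in delta.tool_calls.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```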


@@ -1,5 +1,5 @@
 # use vllm
-# vllm>=0.4.3
+# vllm>=0.5.0
 torch>=2.3.0
 torchvision>=0.18.0


@@ -166,7 +166,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_
 Run **single-machine, single-GPU** fine-tuning with the following command.

 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 ```

 ## Fine-tuning from a checkpoint

@@ -179,7 +179,7 @@ python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 For example, here is a sample command to resume fine-tuning from the last checkpoint:

 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
 ```

 ## Using the fine-tuned model


@@ -1,5 +1,5 @@
 jieba>=0.42.1
-datasets>=2.19.1
-peft>=0.11.0
-deepspeed>=0.13.3
+datasets>2.20.0
+peft>=0.11.1
+deepspeed>=0.14.3
 nltk==3.8.1