fix openai stream function

zR 2024-06-15 20:59:23 +08:00
parent 920425c9fe
commit 5c4bf6201c
4 changed files with 73 additions and 37 deletions

View File

@@ -5,6 +5,8 @@ import uvicorn
import gc
import json
import torch
import random
import string
from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
from fastapi import FastAPI, HTTPException, Response
@@ -40,8 +42,13 @@ app.add_middleware(
)
def generate_id(prefix: str, k=29) -> str:
suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
return f"{prefix}{suffix}"
class ModelCard(BaseModel):
id: str
id: str = ""
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "owner"
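For context, the new `generate_id` helper builds OpenAI-style identifiers by appending `k` random ASCII letters and digits to a prefix. A hedged usage sketch (the example outputs are made up, not taken from a real run):

```python
import random
import string

def generate_id(prefix: str, k=29) -> str:
    # Helper added in this commit: prefix followed by k random alphanumerics.
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"

generate_id('chatcmpl-', 29)  # e.g. 'chatcmpl-8m3ZQw...' -- used as the response id
generate_id('fp_', 9)         # e.g. 'fp_a1B2c3D4e'      -- used as system_fingerprint
```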
@@ -60,7 +67,7 @@ class FunctionCall(BaseModel):
arguments: str
class FunctionCallResponse(BaseModel):
class ChoiceDeltaToolCallFunction(BaseModel):
name: Optional[str] = None
arguments: Optional[str] = None
@@ -72,22 +79,24 @@ class UsageInfo(BaseModel):
class ChatCompletionMessageToolCall(BaseModel):
id: str
index: Optional[int] = 0
id: Optional[str] = None
function: FunctionCall
type: Literal["function"]
type: Optional[Literal["function"]] = 'function'
class ChatMessage(BaseModel):
role: Literal["user", "assistant", "system", "tool"]
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
function_call: Optional[ChoiceDeltaToolCallFunction] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class DeltaMessage(BaseModel):
role: Optional[Literal["user", "assistant", "system"]] = None
content: Optional[str] = None
function_call: Optional[FunctionCallResponse] = None
function_call: Optional[ChoiceDeltaToolCallFunction] = None
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
class ChatCompletionResponseChoice(BaseModel):
@@ -104,10 +113,11 @@ class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponse(BaseModel):
model: str
id: str
id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
object: Literal["chat.completion", "chat.completion.chunk"]
choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
usage: Optional[UsageInfo] = None
@@ -153,7 +163,8 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
if is_tool_call and use_tool:
content = {
"name": function_name,
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
"arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
ensure_ascii=False)
}
if function_name == "simple_browser":
search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
@@ -172,9 +183,6 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
return output.strip()
@torch.inference_mode()
async def generate_stream_glm4(params):
messages = params["messages"]
@@ -220,6 +228,7 @@ async def generate_stream_glm4(params):
"finish_reason": output.outputs[0].finish_reason,
}
yield ret
gc.collect()
torch.cuda.empty_cache()
@@ -317,7 +326,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
return processed_messages
@app.get("/health")
async def health() -> Response:
"""Health check."""
@@ -335,7 +343,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
raise HTTPException(status_code=400, detail="Invalid request")
gen_params = dict(
messages=request.messages,
temperature=request.temperature,
@@ -364,12 +371,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
logger.warning("Failed to parse tool call")
if isinstance(function_call, dict):
function_call = FunctionCallResponse(**function_call)
function_call = ChoiceDeltaToolCallFunction(**function_call)
generate = parse_output_text(request.model, output, function_call=function_call)
return EventSourceResponse(generate, media_type="text/event-stream")
else:
return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
response = ""
async for response in generate_stream_glm4(gen_params):
pass
@@ -390,7 +396,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
if isinstance(function_call, dict):
finish_reason = "tool_calls"
function_call_response = FunctionCallResponse(**function_call)
function_call_response = ChoiceDeltaToolCallFunction(**function_call)
function_call_instance = FunctionCall(
name=function_call_response.name,
arguments=function_call_response.arguments
@@ -421,7 +427,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
return ChatCompletionResponse(
model=request.model,
id="",
choices=[choice_data],
object="chat.completion",
usage=usage
@@ -433,6 +438,9 @@ async def predict_stream(model_id, gen_params):
is_function_call = False
has_send_first_chunk = False
function_name = None
created_time = int(time.time())
response_id = generate_id('chatcmpl-', 29)
system_fingerprint = generate_id('fp_', 9)
async for new_response in generate_stream_glm4(gen_params):
decoded_unicode = new_response["text"]
delta_text = decoded_unicode[len(output):]
@@ -446,10 +454,16 @@ async def predict_stream(model_id, gen_params):
if is_function_call:
for char in delta_text:
function_call = {"name": function_name, "arguments": char}
tool_call = ChatCompletionMessageToolCall(
index=0,
function=FunctionCall(**function_call),
type="function"
)
message = DeltaMessage(
content=None,
role="assistant",
function_call=function_call
role=None,
function_call=None,
tool_calls=[tool_call]
)
choice_data = ChatCompletionResponseStreamChoice(
index=0,
@@ -458,9 +472,10 @@ async def predict_stream(model_id, gen_params):
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
id=response_id,
choices=[choice_data],
created=int(time.time()),
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
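To make the change concrete: during a tool call the stream now carries the arguments through `delta.tool_calls` (one character per chunk) rather than the old `delta.function_call`. A rough, illustrative sketch of one serialized chunk; the identifiers, timestamp, and tool name below are placeholders, not values from a real run:

```python
# Approximate shape of one chunk.model_dump_json(exclude_unset=True) payload
# during a streamed tool call (all values are illustrative placeholders).
{
    "model": "glm-4",
    "id": "chatcmpl-XXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
    "object": "chat.completion.chunk",
    "created": 1718000000,
    "system_fingerprint": "fp_XXXXXXXXX",
    "choices": [{
        "index": 0,
        "delta": {
            "tool_calls": [{
                "index": 0,
                "function": {"name": "get_weather", "arguments": "{"},  # one character of the arguments JSON
                "type": "function",
            }],
        },
        "finish_reason": None,
    }],
}
```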
@@ -480,9 +495,10 @@ async def predict_stream(model_id, gen_params):
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
id=response_id,
choices=[choice_data],
created=int(time.time()),
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
@@ -501,20 +517,36 @@ async def predict_stream(model_id, gen_params):
)
chunk = ChatCompletionResponse(
model=model_id,
id="",
id=response_id,
choices=[choice_data],
created=int(time.time()),
created=created_time,
system_fingerprint=system_fingerprint,
object="chat.completion.chunk"
)
yield chunk.model_dump_json(exclude_unset=True)
if is_function_call:
yield json.dumps({"text": output})
else:
yield '[DONE]'
yield ChatCompletionResponse(
model=model_id,
id=response_id,
system_fingerprint=system_fingerprint,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(
content=None,
role=None,
function_call=None,
),
finish_reason="tool_calls"
)],
created=created_time,
object="chat.completion.chunk",
usage=None
).model_dump_json(exclude_unset=True)
async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
delta = DeltaMessage(role="assistant", content=value)
if function_call is not None:
delta.function_call = function_call
@@ -524,9 +556,13 @@ async def parse_output_text(model_id: str, value: str, function_call: FunctionCa
delta=delta,
finish_reason=None
)
chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
chunk = ChatCompletionResponse(
model=model_id,
choices=[choice_data],
object="chat.completion.chunk"
)
yield "{}".format(chunk.model_dump_json(exclude_unset=True))
yield '[DONE]'
if __name__ == "__main__":
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
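For reference, since the endpoint keeps the OpenAI wire format, the fixed stream can be consumed with the standard `openai` Python client. The snippet below is a minimal, hedged sketch: the base URL `http://127.0.0.1:8000/v1`, the placeholder `EMPTY` api key, and the model name `glm-4` are assumptions to adjust for your deployment, and exact tool-call streaming still depends on the rest of the demo server.

```python
from openai import OpenAI

# Assumed local endpoint, placeholder key, and model name -- adjust to your deployment.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        # With this fix, tool-call arguments arrive here instead of delta.function_call.
        print(delta.tool_calls[0].function.arguments or "", end="", flush=True)
    elif delta.content:
        print(delta.content, end="", flush=True)
```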

View File

@@ -1,5 +1,5 @@
# use vllm
# vllm>=0.4.3
# vllm>=0.5.0
torch>=2.3.0
torchvision>=0.18.0

View File

@@ -166,7 +166,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_
Run **single-machine, single-GPU** fine-tuning with the following command.
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
```
## Fine-tuning from a checkpoint
@@ -179,7 +179,7 @@ python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yam
For example, this is an example command for continuing fine-tuning from the last checkpoint:
```shell
python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
```
## Using the fine-tuned model

View File

@@ -1,5 +1,5 @@
jieba>=0.42.1
datasets>=2.19.1
peft>=0.11.0
deepspeed>=0.13.3
datasets>2.20.0
peft>=0.11.1
deepspeed>=0.14.3
nltk==3.8.1