fix openai stream function
parent 920425c9fe
commit 5c4bf6201c
@@ -5,6 +5,8 @@ import uvicorn
 import gc
 import json
 import torch
+import random
+import string
 
 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from fastapi import FastAPI, HTTPException, Response
@@ -40,8 +42,13 @@ app.add_middleware(
 )
 
 
+def generate_id(prefix: str, k=29) -> str:
+    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
+    return f"{prefix}{suffix}"
+
+
 class ModelCard(BaseModel):
-    id: str
+    id: str = ""
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "owner"
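For reference, a quick sketch of the identifiers the new helper produces; the `chatcmpl-` and `fp_` prefixes follow OpenAI-style id conventions, and the suffixes are random, so the sample values in the comments are illustrative only:

```python
import random
import string

def generate_id(prefix: str, k=29) -> str:
    # same helper as in the diff above: prefix + k random alphanumeric characters
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"

print(generate_id('chatcmpl-', 29))  # e.g. chatcmpl-8K1vQ... (29-char suffix)
print(generate_id('fp_', 9))         # e.g. fp_3kD91xQz7 (9-char suffix)
```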
@@ -60,7 +67,7 @@ class FunctionCall(BaseModel):
     arguments: str
 
 
-class FunctionCallResponse(BaseModel):
+class ChoiceDeltaToolCallFunction(BaseModel):
     name: Optional[str] = None
     arguments: Optional[str] = None
 
@@ -72,22 +79,24 @@ class UsageInfo(BaseModel):
 
 
 class ChatCompletionMessageToolCall(BaseModel):
-    id: str
+    index: Optional[int] = 0
+    id: Optional[str] = None
     function: FunctionCall
-    type: Literal["function"]
+    type: Optional[Literal["function"]] = 'function'
 
 
 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
     role: Optional[Literal["user", "assistant", "system"]] = None
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
@@ -104,10 +113,11 @@ class ChatCompletionResponseStreamChoice(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    id: str
+    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
     usage: Optional[UsageInfo] = None
 
 
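A minimal sketch of how the new `default_factory` fields behave under pydantic v2, and why the streaming code below still passes `id` and `system_fingerprint` explicitly: `model_dump_json(exclude_unset=True)` drops any field the caller did not set, defaults included. The `ChunkStub` model here is a hypothetical cut-down stand-in, not the server's full `ChatCompletionResponse`:

```python
import random
import string
import time
from typing import Optional

from pydantic import BaseModel, Field


def generate_id(prefix: str, k=29) -> str:
    return prefix + ''.join(random.choices(string.ascii_letters + string.digits, k=k))


class ChunkStub(BaseModel):
    # hypothetical cut-down stand-in for ChatCompletionResponse
    model: str
    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))


# Omitted fields are filled in by their default_factory...
print(ChunkStub(model="glm-4").model_dump_json())
# ...but exclude_unset=True drops anything the caller did not pass explicitly,
# so generated defaults disappear from the serialized output:
print(ChunkStub(model="glm-4").model_dump_json(exclude_unset=True))
# Passing the value explicitly keeps it in the serialized chunk:
print(ChunkStub(model="glm-4", id=generate_id('chatcmpl-', 29)).model_dump_json(exclude_unset=True))
```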
@@ -153,7 +163,8 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
     if is_tool_call and use_tool:
         content = {
             "name": function_name,
-            "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
+            "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                    ensure_ascii=False)
         }
         if function_name == "simple_browser":
             search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
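This hunk only rewraps the `json.dumps(...)` call; for context, the `ensure_ascii=False` flag it carries keeps non-ASCII tool arguments readable instead of escaping them. A quick illustration (the argument dict is made up):

```python
import json

args = {"city": "北京"}  # hypothetical tool arguments
print(json.dumps(args))                      # {"city": "\u5317\u4eac"}
print(json.dumps(args, ensure_ascii=False))  # {"city": "北京"}
```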
@@ -172,9 +183,6 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
     return output.strip()
 
 
-
-
-
 @torch.inference_mode()
 async def generate_stream_glm4(params):
     messages = params["messages"]
@@ -220,6 +228,7 @@ async def generate_stream_glm4(params):
             "finish_reason": output.outputs[0].finish_reason,
         }
         yield ret
+
     gc.collect()
     torch.cuda.empty_cache()
 
@@ -317,7 +326,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
     return processed_messages
 
 
-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
@@ -335,7 +343,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
     if len(request.messages) < 1 or request.messages[-1].role == "assistant":
         raise HTTPException(status_code=400, detail="Invalid request")
 
-
     gen_params = dict(
         messages=request.messages,
         temperature=request.temperature,
@@ -364,12 +371,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
             logger.warning("Failed to parse tool call")
 
         if isinstance(function_call, dict):
-            function_call = FunctionCallResponse(**function_call)
+            function_call = ChoiceDeltaToolCallFunction(**function_call)
             generate = parse_output_text(request.model, output, function_call=function_call)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
             return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")
-
     response = ""
     async for response in generate_stream_glm4(gen_params):
         pass
@@ -390,7 +396,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     if isinstance(function_call, dict):
         finish_reason = "tool_calls"
-        function_call_response = FunctionCallResponse(**function_call)
+        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
         function_call_instance = FunctionCall(
             name=function_call_response.name,
             arguments=function_call_response.arguments
@@ -421,7 +427,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
 
     return ChatCompletionResponse(
         model=request.model,
-        id="",
         choices=[choice_data],
         object="chat.completion",
         usage=usage
@@ -433,6 +438,9 @@ async def predict_stream(model_id, gen_params):
     is_function_call = False
     has_send_first_chunk = False
     function_name = None
+    created_time = int(time.time())
+    response_id = generate_id('chatcmpl-', 29)
+    system_fingerprint = generate_id('fp_', 9)
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
@@ -446,10 +454,16 @@ async def predict_stream(model_id, gen_params):
         if is_function_call:
             for char in delta_text:
                 function_call = {"name": function_name, "arguments": char}
+                tool_call = ChatCompletionMessageToolCall(
+                    index=0,
+                    function=FunctionCall(**function_call),
+                    type="function"
+                )
                 message = DeltaMessage(
                     content=None,
-                    role="assistant",
-                    function_call=function_call
+                    role=None,
+                    function_call=None,
+                    tool_calls=[tool_call]
                 )
                 choice_data = ChatCompletionResponseStreamChoice(
                     index=0,
@@ -458,9 +472,10 @@ async def predict_stream(model_id, gen_params):
                 )
                 chunk = ChatCompletionResponse(
                     model=model_id,
-                    id="",
+                    id=response_id,
                     choices=[choice_data],
-                    created=int(time.time()),
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
@@ -480,9 +495,10 @@ async def predict_stream(model_id, gen_params):
                 )
                 chunk = ChatCompletionResponse(
                     model=model_id,
-                    id="",
+                    id=response_id,
                     choices=[choice_data],
-                    created=int(time.time()),
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
@@ -501,20 +517,36 @@ async def predict_stream(model_id, gen_params):
             )
             chunk = ChatCompletionResponse(
                 model=model_id,
-                id="",
+                id=response_id,
                 choices=[choice_data],
-                created=int(time.time()),
+                created=created_time,
+                system_fingerprint=system_fingerprint,
                 object="chat.completion.chunk"
             )
             yield chunk.model_dump_json(exclude_unset=True)
 
     if is_function_call:
-        yield json.dumps({"text": output})
-    else:
-        yield '[DONE]'
+        yield ChatCompletionResponse(
+            model=model_id,
+            id=response_id,
+            system_fingerprint=system_fingerprint,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(
+                        content=None,
+                        role=None,
+                        function_call=None,
+                    ),
+                    finish_reason="tool_calls"
+                )],
+            created=created_time,
+            object="chat.completion.chunk",
+            usage=None
+        ).model_dump_json(exclude_unset=True)
 
 
-async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
+async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
     delta = DeltaMessage(role="assistant", content=value)
     if function_call is not None:
         delta.function_call = function_call
@@ -524,9 +556,13 @@ async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
         delta=delta,
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk"
+    )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
     yield '[DONE]'
 
 
 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
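To exercise the updated tool-call streaming path end to end, here is a client-side sketch using the `openai` Python SDK. The base URL and port, API key, model name, and the `get_weather` tool are assumptions for illustration only; they are not part of this diff:

```python
from openai import OpenAI

# Assumed local endpoint; adjust to wherever the server is actually running.
client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
    stream=True,
)

# The server streams the tool-call arguments one character per chunk,
# so concatenate delta.tool_calls[*].function.arguments across chunks.
name, arguments = None, ""
for chunk in stream:
    choice = chunk.choices[0]
    for call in choice.delta.tool_calls or []:
        if call.function.name:
            name = call.function.name
        arguments += call.function.arguments or ""
    if choice.finish_reason == "tool_calls":
        break

print(name, arguments)
```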

@@ -1,5 +1,5 @@
 # use vllm
-# vllm>=0.4.3
+# vllm>=0.5.0
 
 torch>=2.3.0
 torchvision>=0.18.0

@@ -166,7 +166,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_
 Run a **single-node, single-GPU** job with the following command.
 
 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 ```
 
 ## Fine-tuning from a checkpoint
@@ -179,7 +179,7 @@ python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 For example, here is an example command for resuming fine-tuning from the last checkpoint:
 
 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
 ```
 
 ## Using the fine-tuned model

@@ -1,5 +1,5 @@
 jieba>=0.42.1
-datasets>=2.19.1
-peft>=0.11.0
-deepspeed>=0.13.3
+datasets>2.20.0
+peft>=0.11.1
+deepspeed>=0.14.3
 nltk==3.8.1