fix openai stream function
commit 5c4bf6201c
parent 920425c9fe
@@ -5,6 +5,8 @@ import uvicorn
 import gc
 import json
 import torch
+import random
+import string

 from vllm import SamplingParams, AsyncEngineArgs, AsyncLLMEngine
 from fastapi import FastAPI, HTTPException, Response
@@ -40,8 +42,13 @@ app.add_middleware(
 )


+def generate_id(prefix: str, k=29) -> str:
+    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
+    return f"{prefix}{suffix}"
+
+
 class ModelCard(BaseModel):
-    id: str
+    id: str = ""
     object: str = "model"
     created: int = Field(default_factory=lambda: int(time.time()))
     owned_by: str = "owner"
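As a quick illustration (not part of the diff), the new `generate_id` helper produces OpenAI-style identifiers; the `chatcmpl-` and `fp_` prefixes match the response ids and system fingerprints used later in this commit, and the suffixes are random:

```python
import random
import string


def generate_id(prefix: str, k=29) -> str:
    # Same helper as in the diff above: prefix + k random alphanumeric characters.
    suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=k))
    return f"{prefix}{suffix}"


print(generate_id('chatcmpl-', 29))  # e.g. chatcmpl- followed by 29 random characters
print(generate_id('fp_', 9))         # e.g. fp_ followed by 9 random characters
```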
@@ -60,7 +67,7 @@ class FunctionCall(BaseModel):
     arguments: str


-class FunctionCallResponse(BaseModel):
+class ChoiceDeltaToolCallFunction(BaseModel):
     name: Optional[str] = None
     arguments: Optional[str] = None

@@ -72,22 +79,24 @@ class UsageInfo(BaseModel):


 class ChatCompletionMessageToolCall(BaseModel):
-    id: str
+    index: Optional[int] = 0
+    id: Optional[str] = None
     function: FunctionCall
-    type: Literal["function"]
+    type: Optional[Literal["function"]] = 'function'


 class ChatMessage(BaseModel):
     role: Literal["user", "assistant", "system", "tool"]
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class DeltaMessage(BaseModel):
     role: Optional[Literal["user", "assistant", "system"]] = None
     content: Optional[str] = None
-    function_call: Optional[FunctionCallResponse] = None
+    function_call: Optional[ChoiceDeltaToolCallFunction] = None
+    tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


 class ChatCompletionResponseChoice(BaseModel):
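A rough sketch of how the reworked models fit together (not part of the diff; the tool name, arguments, and call id below are invented): an assistant `ChatMessage` carrying a tool call is built from `ChatCompletionMessageToolCall` and `FunctionCall` like this.

```python
# Hypothetical example values; assumes the pydantic models defined above are in scope.
call = ChatCompletionMessageToolCall(
    index=0,
    id="call_example123",
    function=FunctionCall(name="get_weather", arguments='{"city": "Beijing"}'),
    type="function",
)
message = ChatMessage(role="assistant", content=None, tool_calls=[call])

# Unset fields (here: function_call) are dropped, so the JSON mirrors the OpenAI tool-call shape.
print(message.model_dump_json(exclude_unset=True))
```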
@@ -104,10 +113,11 @@ class ChatCompletionResponseStreamChoice(BaseModel):

 class ChatCompletionResponse(BaseModel):
     model: str
-    id: str
+    id: Optional[str] = Field(default_factory=lambda: generate_id('chatcmpl-', 29))
     object: Literal["chat.completion", "chat.completion.chunk"]
     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+    system_fingerprint: Optional[str] = Field(default_factory=lambda: generate_id('fp_', 9))
     usage: Optional[UsageInfo] = None


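With these defaults, `id`, `created`, and `system_fingerprint` no longer have to be passed at every call site; a minimal sketch (values are random and time-dependent):

```python
# Assumes the ChatCompletionResponse model above is in scope.
resp = ChatCompletionResponse(
    model="glm-4",
    object="chat.completion.chunk",
    choices=[],
)
print(resp.id)                  # auto-generated, e.g. chatcmpl-...
print(resp.system_fingerprint)  # auto-generated, e.g. fp_...
print(resp.created)             # defaults to int(time.time())
```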
@@ -153,7 +163,8 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
         if is_tool_call and use_tool:
             content = {
                 "name": function_name,
-                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments, ensure_ascii=False)
+                "arguments": json.dumps(arguments_json if isinstance(arguments_json, dict) else arguments,
+                                        ensure_ascii=False)
             }
             if function_name == "simple_browser":
                 search_pattern = re.compile(r'search\("(.+?)"\s*,\s*recency_days\s*=\s*(\d+)\)')
@@ -172,9 +183,6 @@ def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
     return output.strip()


-
-
-
 @torch.inference_mode()
 async def generate_stream_glm4(params):
     messages = params["messages"]
@@ -220,6 +228,7 @@ async def generate_stream_glm4(params):
             "finish_reason": output.outputs[0].finish_reason,
         }
         yield ret
+
     gc.collect()
     torch.cuda.empty_cache()

@@ -317,7 +326,6 @@ def process_messages(messages, tools=None, tool_choice="none"):
     return processed_messages


-
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
|
@ -335,7 +343,6 @@ async def create_chat_completion(request: ChatCompletionRequest):
|
||||||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||||||
raise HTTPException(status_code=400, detail="Invalid request")
|
raise HTTPException(status_code=400, detail="Invalid request")
|
||||||
|
|
||||||
|
|
||||||
gen_params = dict(
|
gen_params = dict(
|
||||||
messages=request.messages,
|
messages=request.messages,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
|
@@ -364,12 +371,11 @@ async def create_chat_completion(request: ChatCompletionRequest):
                 logger.warning("Failed to parse tool call")
-
         if isinstance(function_call, dict):
-            function_call = FunctionCallResponse(**function_call)
+            function_call = ChoiceDeltaToolCallFunction(**function_call)
             generate = parse_output_text(request.model, output, function_call=function_call)
             return EventSourceResponse(generate, media_type="text/event-stream")
         else:
             return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

     response = ""
     async for response in generate_stream_glm4(gen_params):
         pass
@@ -390,7 +396,7 @@ async def create_chat_completion(request: ChatCompletionRequest):

     if isinstance(function_call, dict):
         finish_reason = "tool_calls"
-        function_call_response = FunctionCallResponse(**function_call)
+        function_call_response = ChoiceDeltaToolCallFunction(**function_call)
         function_call_instance = FunctionCall(
             name=function_call_response.name,
             arguments=function_call_response.arguments
@@ -421,7 +427,6 @@ async def create_chat_completion(request: ChatCompletionRequest):

     return ChatCompletionResponse(
         model=request.model,
-        id="",
         choices=[choice_data],
         object="chat.completion",
         usage=usage
@@ -433,6 +438,9 @@ async def predict_stream(model_id, gen_params):
     is_function_call = False
     has_send_first_chunk = False
     function_name = None
+    created_time = int(time.time())
+    response_id = generate_id('chatcmpl-', 29)
+    system_fingerprint = generate_id('fp_', 9)
     async for new_response in generate_stream_glm4(gen_params):
         decoded_unicode = new_response["text"]
         delta_text = decoded_unicode[len(output):]
||||||
|
@ -446,10 +454,16 @@ async def predict_stream(model_id, gen_params):
|
||||||
if is_function_call:
|
if is_function_call:
|
||||||
for char in delta_text:
|
for char in delta_text:
|
||||||
function_call = {"name": function_name, "arguments": char}
|
function_call = {"name": function_name, "arguments": char}
|
||||||
|
tool_call = ChatCompletionMessageToolCall(
|
||||||
|
index=0,
|
||||||
|
function=FunctionCall(**function_call),
|
||||||
|
type="function"
|
||||||
|
)
|
||||||
message = DeltaMessage(
|
message = DeltaMessage(
|
||||||
content=None,
|
content=None,
|
||||||
role="assistant",
|
role=None,
|
||||||
function_call=function_call
|
function_call=None,
|
||||||
|
tool_calls=[tool_call]
|
||||||
)
|
)
|
||||||
choice_data = ChatCompletionResponseStreamChoice(
|
choice_data = ChatCompletionResponseStreamChoice(
|
||||||
index=0,
|
index=0,
|
||||||
|
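Each character of the tool-call arguments is wrapped in its own delta, so a single streamed chunk now looks roughly like the sketch below (illustrative only; the tool name is invented and the ids come from the model defaults rather than the loop's `response_id`):

```python
# Mirrors the chunk built inside the loop above, for one argument character.
chunk = ChatCompletionResponse(
    model="glm-4",
    object="chat.completion.chunk",
    choices=[
        ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(
                tool_calls=[ChatCompletionMessageToolCall(
                    index=0,
                    function=FunctionCall(name="get_weather", arguments="{"),  # one character per chunk
                    type="function",
                )]
            ),
            finish_reason=None,
        )
    ],
)
print(chunk.model_dump_json(exclude_unset=True))
```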
@@ -458,9 +472,10 @@ async def predict_stream(model_id, gen_params):
                 )
                 chunk = ChatCompletionResponse(
                     model=model_id,
-                    id="",
+                    id=response_id,
                     choices=[choice_data],
-                    created=int(time.time()),
+                    created=created_time,
+                    system_fingerprint=system_fingerprint,
                     object="chat.completion.chunk"
                 )
                 yield chunk.model_dump_json(exclude_unset=True)
|
@ -480,9 +495,10 @@ async def predict_stream(model_id, gen_params):
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionResponse(
|
chunk = ChatCompletionResponse(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
id="",
|
id=response_id,
|
||||||
choices=[choice_data],
|
choices=[choice_data],
|
||||||
created=int(time.time()),
|
created=created_time,
|
||||||
|
system_fingerprint=system_fingerprint,
|
||||||
object="chat.completion.chunk"
|
object="chat.completion.chunk"
|
||||||
)
|
)
|
||||||
yield chunk.model_dump_json(exclude_unset=True)
|
yield chunk.model_dump_json(exclude_unset=True)
|
||||||
|
@ -501,20 +517,36 @@ async def predict_stream(model_id, gen_params):
|
||||||
)
|
)
|
||||||
chunk = ChatCompletionResponse(
|
chunk = ChatCompletionResponse(
|
||||||
model=model_id,
|
model=model_id,
|
||||||
id="",
|
id=response_id,
|
||||||
choices=[choice_data],
|
choices=[choice_data],
|
||||||
created=int(time.time()),
|
created=created_time,
|
||||||
|
system_fingerprint=system_fingerprint,
|
||||||
object="chat.completion.chunk"
|
object="chat.completion.chunk"
|
||||||
)
|
)
|
||||||
yield chunk.model_dump_json(exclude_unset=True)
|
yield chunk.model_dump_json(exclude_unset=True)
|
||||||
|
|
||||||
if is_function_call:
|
if is_function_call:
|
||||||
yield json.dumps({"text": output})
|
yield ChatCompletionResponse(
|
||||||
else:
|
model=model_id,
|
||||||
yield '[DONE]'
|
id=response_id,
|
||||||
|
system_fingerprint=system_fingerprint,
|
||||||
|
choices=[
|
||||||
|
ChatCompletionResponseStreamChoice(
|
||||||
|
index=0,
|
||||||
|
delta=DeltaMessage(
|
||||||
|
content=None,
|
||||||
|
role=None,
|
||||||
|
function_call=None,
|
||||||
|
),
|
||||||
|
finish_reason="tool_calls"
|
||||||
|
)],
|
||||||
|
created=created_time,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
usage=None
|
||||||
|
).model_dump_json(exclude_unset=True)
|
||||||
|
|
||||||
|
|
||||||
async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
|
async def parse_output_text(model_id: str, value: str, function_call: ChoiceDeltaToolCallFunction = None):
|
||||||
delta = DeltaMessage(role="assistant", content=value)
|
delta = DeltaMessage(role="assistant", content=value)
|
||||||
if function_call is not None:
|
if function_call is not None:
|
||||||
delta.function_call = function_call
|
delta.function_call = function_call
|
||||||
|
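On the client side, these chunks can be consumed with the standard OpenAI Python SDK by concatenating the per-character argument fragments. A minimal sketch, assuming the demo server is running locally on port 8000 and serves the usual `/v1/chat/completions` route, with a made-up tool definition (none of these are shown in this diff):

```python
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")  # assumed host/port

tools = [{"type": "function", "function": {
    "name": "get_weather",  # hypothetical tool
    "description": "Look up the weather for a city",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
}}]

stream = client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=tools,
    stream=True,
)

name, arguments = None, ""
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:  # per-character fragments emitted by predict_stream
        fragment = delta.tool_calls[0].function
        name = fragment.name or name
        arguments += fragment.arguments or ""
    elif delta.content:
        print(delta.content, end="")

print(name, arguments)  # e.g. get_weather {"city": "Beijing"}
```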
@@ -524,9 +556,13 @@ async def parse_output_text(model_id: str, value: str, function_call: FunctionCallResponse = None):
         delta=delta,
         finish_reason=None
     )
-    chunk = ChatCompletionResponse(model=model_id, id="", choices=[choice_data], object="chat.completion.chunk")
+    chunk = ChatCompletionResponse(
+        model=model_id,
+        choices=[choice_data],
+        object="chat.completion.chunk"
+    )
     yield "{}".format(chunk.model_dump_json(exclude_unset=True))
-    yield '[DONE]'
+

 if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
@@ -1,5 +1,5 @@
 # use vllm
-# vllm>=0.4.3
+# vllm>=0.5.0

 torch>=2.3.0
 torchvision>=0.18.0
@@ -166,7 +166,7 @@ OMP_NUM_THREADS=1 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_
 Run **single-node, single-GPU** fine-tuning with the following command.

 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 ```

 ## Fine-tuning from a checkpoint
@@ -179,7 +179,7 @@ python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml
 For example, this is an example command for continuing fine-tuning from the last checkpoint:

 ```shell
-python finetune_hf.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
+python finetune.py data/AdvertiseGen/ THUDM/glm-4-9b-chat configs/lora.yaml yes
 ```

 ## Using the fine-tuned model
@@ -1,5 +1,5 @@
 jieba>=0.42.1
-datasets>=2.19.1
-peft>=0.11.0
-deepspeed>=0.13.3
+datasets>2.20.0
+peft>=0.11.1
+deepspeed>=0.14.3
 nltk==3.8.1