Skip to content

Commit 739f303

Browse files
whybeyoung and jhinpan committed
feat(oai refactor): Replace openai_api with entrypoints/openai (sgl-project#7351)
Co-authored-by: Jin Pan <[email protected]>
1 parent c4b4d10 commit 739f303

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+673
-2406
lines changed

benchmark/hicache/data_processing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
get_gen_prefix_cache_path,
2121
)
2222
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
23-
from sglang.srt.openai_api.protocol import ChatCompletionMessageContentPart
23+
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
2424
from sglang.utils import encode_video_base64
2525

2626
# type of content fields, can be only prompts or with images/videos

docs/backend/openai_api_embeddings.ipynb

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -64,11 +64,14 @@
6464
"text = \"Once upon a time\"\n",
6565
"\n",
6666
"curl_text = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
67+
" -H \"Content-Type: application/json\" \\\n",
6768
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": \"{text}\"}}'\"\"\"\n",
6869
"\n",
69-
"text_embedding = json.loads(subprocess.check_output(curl_text, shell=True))[\"data\"][0][\n",
70-
" \"embedding\"\n",
71-
"]\n",
70+
"result = subprocess.check_output(curl_text, shell=True)\n",
71+
"\n",
72+
"print(result)\n",
73+
"\n",
74+
"text_embedding = json.loads(result)[\"data\"][0][\"embedding\"]\n",
7275
"\n",
7376
"print_highlight(f\"Text embedding (first 10): {text_embedding[:10]}\")"
7477
]
@@ -152,6 +155,7 @@
152155
"input_ids = tokenizer.encode(text)\n",
153156
"\n",
154157
"curl_ids = f\"\"\"curl -s http://localhost:{port}/v1/embeddings \\\n",
158+
" -H \"Content-Type: application/json\" \\\n",
155159
" -d '{{\"model\": \"Alibaba-NLP/gte-Qwen2-1.5B-instruct\", \"input\": {json.dumps(input_ids)}}}'\"\"\"\n",
156160
"\n",
157161
"input_ids_embedding = json.loads(subprocess.check_output(curl_ids, shell=True))[\"data\"][\n",

docs/backend/openai_api_vision.ipynb

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,7 @@
6767
"\n",
6868
"curl_command = f\"\"\"\n",
6969
"curl -s http://localhost:{port}/v1/chat/completions \\\\\n",
70+
" -H \"Content-Type: application/json\" \\\\\n",
7071
" -d '{{\n",
7172
" \"model\": \"Qwen/Qwen2.5-VL-7B-Instruct\",\n",
7273
" \"messages\": [\n",

docs/backend/vlm_query.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"import requests\n",
3737
"from PIL import Image\n",
3838
"\n",
39-
"from sglang.srt.openai_api.protocol import ChatCompletionRequest\n",
39+
"from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest\n",
4040
"from sglang.srt.conversation import chat_templates\n",
4141
"\n",
4242
"image = Image.open(\n",

python/sglang/srt/code_completion_parser.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@
1515

1616

1717
import dataclasses
18-
import json
1918
import logging
20-
import os
2119
from enum import auto
2220

2321
from sglang.srt.entrypoints.openai.protocol import CompletionRequest
@@ -57,46 +55,6 @@ class CompletionTemplate:
5755
completion_templates: dict[str, CompletionTemplate] = {}
5856

5957

60-
def load_completion_template_for_openai_api(completion_template_arg):
61-
global completion_template_name
62-
63-
logger.info(
64-
f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
65-
)
66-
67-
if not completion_template_exists(completion_template_arg):
68-
if not os.path.exists(completion_template_arg):
69-
raise RuntimeError(
70-
f"Completion template {completion_template_arg} is not a built-in template name "
71-
"or a valid completion template file path."
72-
)
73-
74-
assert completion_template_arg.endswith(
75-
".json"
76-
), "unrecognized format of completion template file"
77-
with open(completion_template_arg, "r") as filep:
78-
template = json.load(filep)
79-
try:
80-
fim_position = FimPosition[template["fim_position"]]
81-
except KeyError:
82-
raise ValueError(
83-
f"Unknown fim position: {template['fim_position']}"
84-
) from None
85-
register_completion_template(
86-
CompletionTemplate(
87-
name=template["name"],
88-
fim_begin_token=template["fim_begin_token"],
89-
fim_middle_token=template["fim_middle_token"],
90-
fim_end_token=template["fim_end_token"],
91-
fim_position=fim_position,
92-
),
93-
override=True,
94-
)
95-
completion_template_name = template["name"]
96-
else:
97-
completion_template_name = completion_template_arg
98-
99-
10058
def register_completion_template(template: CompletionTemplate, override: bool = False):
10159
"""Register a new completion template."""
10260
if not override:

python/sglang/srt/conversation.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,17 @@
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
1313
# ==============================================================================
14-
"""Conversation chat templates."""
14+
"""Conversation chat templates.
15+
16+
This module provides conversation template definitions, data structures, and utilities
17+
for managing chat templates across different model types in SGLang.
18+
19+
Key components:
20+
- Conversation class: Defines the structure and behavior of chat templates
21+
- SeparatorStyle enum: Different conversation formatting styles
22+
- Template registry: Functions to register and retrieve templates by name or model path
23+
- Built-in templates: Pre-defined templates for popular models
24+
"""
1525

1626
# Adapted from
1727
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@
2030
from enum import IntEnum, auto
2131
from typing import Callable, Dict, List, Optional, Tuple, Union
2232

23-
from sglang.srt.openai_api.protocol import ChatCompletionRequest
33+
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
2434
from sglang.srt.utils import read_system_prompt_from_file
2535

2636

@@ -618,7 +628,7 @@ def generate_chat_conv(
618628

619629

620630
# llama2 template
621-
# reference: https://huggingface.co/blog/codellama#conversational-instructions
631+
# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
622632
# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
623633
register_conv_template(
624634
Conversation(

python/sglang/srt/entrypoints/engine.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import torch
3838
import uvloop
3939

40-
from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
4140
from sglang.srt.entrypoints.EngineBase import EngineBase
4241
from sglang.srt.managers.data_parallel_controller import (
4342
run_data_parallel_controller_process,
@@ -58,11 +57,8 @@
5857
UpdateWeightsFromTensorReqInput,
5958
)
6059
from sglang.srt.managers.scheduler import run_scheduler_process
60+
from sglang.srt.managers.template_manager import TemplateManager
6161
from sglang.srt.managers.tokenizer_manager import TokenizerManager
62-
from sglang.srt.openai_api.adapter import (
63-
guess_chat_template_name_from_model_path,
64-
load_chat_template_for_openai_api,
65-
)
6662
from sglang.srt.server_args import PortArgs, ServerArgs
6763
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
6864
from sglang.srt.utils import (
@@ -123,12 +119,13 @@ def __init__(self, **kwargs):
123119
logger.info(f"{server_args=}")
124120

125121
# Launch subprocesses
126-
tokenizer_manager, scheduler_info = _launch_subprocesses(
122+
tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
127123
server_args=server_args,
128124
port_args=port_args,
129125
)
130126
self.server_args = server_args
131127
self.tokenizer_manager = tokenizer_manager
128+
self.template_manager = template_manager
132129
self.scheduler_info = scheduler_info
133130

134131
context = zmq.Context(2)
@@ -647,7 +644,7 @@ def sigquit_handler(signum, frame):
647644

648645
def _launch_subprocesses(
649646
server_args: ServerArgs, port_args: Optional[PortArgs] = None
650-
) -> Tuple[TokenizerManager, Dict]:
647+
) -> Tuple[TokenizerManager, TemplateManager, Dict]:
651648
"""
652649
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
653650
"""
@@ -732,7 +729,7 @@ def _launch_subprocesses(
732729

733730
if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
734731
# When using `Engine` as a Python API, we don't want to block here.
735-
return None, None
732+
return None, None, None
736733

737734
launch_dummy_health_check_server(server_args.host, server_args.port)
738735

@@ -741,7 +738,7 @@ def _launch_subprocesses(
741738
logger.error(
742739
f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
743740
)
744-
return None, None
741+
return None, None, None
745742

746743
# Launch detokenizer process
747744
detoken_proc = mp.Process(
@@ -755,15 +752,15 @@ def _launch_subprocesses(
755752

756753
# Launch tokenizer process
757754
tokenizer_manager = TokenizerManager(server_args, port_args)
758-
if server_args.chat_template:
759-
load_chat_template_for_openai_api(
760-
tokenizer_manager, server_args.chat_template, server_args.model_path
761-
)
762-
else:
763-
guess_chat_template_name_from_model_path(server_args.model_path)
764755

765-
if server_args.completion_template:
766-
load_completion_template_for_openai_api(server_args.completion_template)
756+
# Initialize templates
757+
template_manager = TemplateManager()
758+
template_manager.initialize_templates(
759+
tokenizer_manager=tokenizer_manager,
760+
model_path=server_args.model_path,
761+
chat_template=server_args.chat_template,
762+
completion_template=server_args.completion_template,
763+
)
767764

768765
# Wait for the model to finish loading
769766
scheduler_infos = []
@@ -787,4 +784,4 @@ def _launch_subprocesses(
787784
# Assume all schedulers have the same scheduler_info
788785
scheduler_info = scheduler_infos[0]
789786
tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
790-
return tokenizer_manager, scheduler_info
787+
return tokenizer_manager, template_manager, scheduler_info

0 commit comments

Comments
 (0)