Commit cce9a6e

reidliu41 authored and dbyoung18 committed
[Misc] refactor examples (vllm-project#16563)
Signed-off-by: reidliu41 <[email protected]>
Co-authored-by: reidliu41 <[email protected]>
1 parent 31e51ad commit cce9a6e

File tree

5 files changed (+110 −71 lines)

examples/offline_inference/disaggregated_prefill.py

Lines changed: 5 additions & 1 deletion
@@ -95,7 +95,7 @@ def run_decode(prefill_done):
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
 
-if __name__ == "__main__":
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, ))
     decode_process = Process(target=run_decode, args=(prefill_done, ))
@@ -109,3 +109,7 @@ def run_decode(prefill_done):
     # Terminate the prefill node when decode is finished
     decode_process.join()
     prefill_process.terminate()
+
+
+if __name__ == "__main__":
+    main()
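
A note on the pattern this commit applies everywhere: moving driver code into main() behind an `if __name__ == "__main__":` guard is what keeps these multiprocessing examples safe, because on platforms using the spawn start method each child re-imports the module, and unguarded module-level Process(...).start() calls would spawn recursively. A minimal runnable sketch of the Event/Process handshake, with placeholder workers standing in for the real vLLM prefill and decode engines:

from multiprocessing import Event, Process


def run_prefill(prefill_done):
    print("prefill: placeholder for KV-cache computation")
    prefill_done.set()  # signal the decode node that prefill is done


def run_decode(prefill_done):
    prefill_done.wait()  # block until the prefill node signals readiness
    print("decode: placeholder for token generation")


def main():
    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, ))
    decode_process = Process(target=run_decode, args=(prefill_done, ))
    prefill_process.start()
    decode_process.start()

    decode_process.join()        # wait for decode to finish
    prefill_process.terminate()  # then stop the prefill node


if __name__ == "__main__":
    main()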

examples/offline_inference/disaggregated_prefill_lmcache.py

Lines changed: 9 additions & 6 deletions
@@ -38,6 +38,10 @@
 # `naive` indicates using raw bytes of the tensor without any compression
 os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
 
+prompts = [
+    "Hello, how are you?" * 1000,
+]
+
 
 def run_prefill(prefill_done, prompts):
     # We use GPU 0 for prefill node.
@@ -106,12 +110,7 @@ def run_lmcache_server(port):
     return server_proc
 
 
-if __name__ == "__main__":
-
-    prompts = [
-        "Hello, how are you?" * 1000,
-    ]
-
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
     decode_process = Process(target=run_decode, args=(prefill_done, prompts))
@@ -128,3 +127,7 @@ def run_lmcache_server(port):
     prefill_process.terminate()
     lmcache_server_process.terminate()
     lmcache_server_process.wait()
+
+
+if __name__ == "__main__":
+    main()
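
Two details in this refactor are worth calling out: `prompts` moves to module level so the new main() and both worker processes can still reference it, and the LMCache server handle returned by run_lmcache_server is a subprocess that must be terminated and then reaped with wait(). A minimal runnable sketch of that subprocess lifecycle, with a sleeping placeholder standing in for the real `lmcache_server` command:

import subprocess
import time


def run_placeholder_server():
    # Stand-in for run_lmcache_server(port); sleeps instead of serving.
    return subprocess.Popen(["python3", "-c", "import time; time.sleep(60)"])


def main():
    server_proc = run_placeholder_server()
    time.sleep(1)  # the real example runs prefill/decode work here

    server_proc.terminate()  # ask the subprocess to exit (SIGTERM)
    server_proc.wait()       # reap it so no zombie process is left


if __name__ == "__main__":
    main()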

examples/online_serving/cohere_rerank_client.py

Lines changed: 38 additions & 24 deletions
@@ -2,32 +2,46 @@
 """
 Example of using the OpenAI entrypoint's rerank API which is compatible with
 the Cohere SDK: https://github.com/cohere-ai/cohere-python
+Note that `pip install cohere` is needed to run this example.
 
 run: vllm serve BAAI/bge-reranker-base
 """
+from typing import Union
+
 import cohere
+from cohere import Client, ClientV2
+
+model = "BAAI/bge-reranker-base"
+
+query = "What is the capital of France?"
+
+documents = [
+    "The capital of France is Paris", "Reranking is fun!",
+    "vLLM is an open-source framework for fast AI serving"
+]
+
+
+def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
+                  documents: list[str]) -> dict:
+    return client.rerank(model=model, query=query, documents=documents)
+
+
+def main():
+    # cohere v1 client
+    cohere_v1 = cohere.Client(base_url="http://localhost:8000",
+                              api_key="sk-fake-key")
+    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
+    print("-" * 50)
+    print("rerank_v1_result:\n", rerank_v1_result)
+    print("-" * 50)
+
+    # or the v2
+    cohere_v2 = cohere.ClientV2("sk-fake-key",
+                                base_url="http://localhost:8000")
+    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
+    print("rerank_v2_result:\n", rerank_v2_result)
+    print("-" * 50)
+
 
-# cohere v1 client
-co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
-rerank_v1_result = co.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
-    documents=[
-        "The capital of France is Paris", "Reranking is fun!",
-        "vLLM is an open-source framework for fast AI serving"
-    ])
-
-print(rerank_v1_result)
-
-# or the v2
-co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
-
-v2_rerank_result = co2.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
-    documents=[
-        "The capital of France is Paris", "Reranking is fun!",
-        "vLLM is an open-source framework for fast AI serving"
-    ])
-
-print(v2_rerank_result)
+if __name__ == "__main__":
+    main()
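
Both clients now go through the single cohere_rerank helper instead of duplicating the request. To consume the result, a short sketch (requires `vllm serve BAAI/bge-reranker-base` to be running, and assumes the cohere SDK's usual response shape, where each entry of response.results carries .index and .relevance_score):

import cohere

client = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
response = client.rerank(
    model="BAAI/bge-reranker-base",
    query="What is the capital of France?",
    documents=[
        "The capital of France is Paris", "Reranking is fun!",
        "vLLM is an open-source framework for fast AI serving"
    ])

for item in response.results:
    # item.index points back into the original documents list
    print(f"score={item.relevance_score:.4f}  doc_index={item.index}")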

examples/online_serving/jinaai_rerank_client.py

Lines changed: 16 additions & 9 deletions
@@ -23,12 +23,19 @@
         "The capital of France is Paris.", "Horses and cows are both animals"
     ]
 }
-response = requests.post(url, headers=headers, json=data)
-
-# Check the response
-if response.status_code == 200:
-    print("Request successful!")
-    print(json.dumps(response.json(), indent=2))
-else:
-    print(f"Request failed with status code: {response.status_code}")
-    print(response.text)
+
+
+def main():
+    response = requests.post(url, headers=headers, json=data)
+
+    # Check the response
+    if response.status_code == 200:
+        print("Request successful!")
+        print(json.dumps(response.json(), indent=2))
+    else:
+        print(f"Request failed with status code: {response.status_code}")
+        print(response.text)
+
+
+if __name__ == "__main__":
+    main()
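
The hunk above only shows the tail of the request payload, so for context here is a self-contained sketch of the full call. The endpoint path, headers, and query string are assumptions (they are not visible in this diff), and the documents list is shortened to the entries that are:

import json

import requests

url = "http://localhost:8000/rerank"  # assumed endpoint path
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
    "model": "BAAI/bge-reranker-base",  # assumed; use the model you serve
    "query": "What is the capital of France?",  # assumed example query
    "documents": [
        "The capital of France is Paris.", "Horses and cows are both animals"
    ]
}

response = requests.post(url, headers=headers, json=data)
print(json.dumps(response.json(), indent=2))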
examples/online_serving/openai_chat_completion_client.py

Lines changed: 42 additions & 31 deletions
@@ -1,38 +1,49 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""Example Python client for OpenAI Chat Completion using vLLM API server
+NOTE: start a supported chat completion model server with `vllm serve`, e.g.
+vllm serve meta-llama/Llama-2-7b-chat-hf
+"""
 from openai import OpenAI
 
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
 
-client = OpenAI(
-    # defaults to os.environ.get("OPENAI_API_KEY")
-    api_key=openai_api_key,
-    base_url=openai_api_base,
-)
-
-models = client.models.list()
-model = models.data[0].id
-
-chat_completion = client.chat.completions.create(
-    messages=[{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "Who won the world series in 2020?"
-    }, {
-        "role":
-        "assistant",
-        "content":
-        "The Los Angeles Dodgers won the World Series in 2020."
-    }, {
-        "role": "user",
-        "content": "Where was it played?"
-    }],
-    model=model,
-)
-
-print("Chat completion results:")
-print(chat_completion)
+messages = [{
+    "role": "system",
+    "content": "You are a helpful assistant."
+}, {
+    "role": "user",
+    "content": "Who won the world series in 2020?"
+}, {
+    "role": "assistant",
+    "content": "The Los Angeles Dodgers won the World Series in 2020."
+}, {
+    "role": "user",
+    "content": "Where was it played?"
+}]
+
+
+def main():
+    client = OpenAI(
+        # defaults to os.environ.get("OPENAI_API_KEY")
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
+
+    models = client.models.list()
+    model = models.data[0].id
+
+    chat_completion = client.chat.completions.create(
+        messages=messages,
+        model=model,
+    )
+
+    print("-" * 50)
+    print("Chat completion results:")
+    print(chat_completion)
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
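
Not part of this commit, but a natural companion: the same request with streaming enabled via the standard OpenAI SDK stream=True option. A minimal sketch against the server above:

from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

stream = client.chat.completions.create(
    messages=[{
        "role": "user",
        "content": "Where was the 2020 World Series played?"
    }],
    model=model,
    stream=True,  # yields ChatCompletionChunk objects as tokens arrive
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk's delta may carry no content
        print(delta, end="", flush=True)
print()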
