
Commit 98dca8d

Deprecate sendnn_decoder in favor of sendnn with warmup_mode (#186)
If `TORCH_SENDNN_LOG` is set to WARNING instead of CRITICAL, there are logs stating:

> You're using a deprecated backend. Please use sendnn in conjunction with warmup_mode

This PR makes the change to sendnn. For backwards compatibility, sendnn_decoder is overwritten to sendnn and a warning is logged.

---------

Signed-off-by: Travis Johnson <[email protected]>
1 parent 153fd2a commit 98dca8d
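
For reference, a minimal sketch of the user-facing change described in the commit message. Only the VLLM_SPYRE_DYNAMO_BACKEND variable and its values come from this commit; the snippet itself is illustrative.

import os

# Before this commit, decoder models used the now-deprecated value:
#   os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "sendnn_decoder"
# After this commit, "sendnn" covers decoder models as well as embedding models;
# the old value still works but is rewritten to "sendnn" with a deprecation warning.
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "sendnn"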

File tree: 12 files changed, +129 -105 lines changed


docker/Dockerfile.amd64

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ ENV COMPILATION_MODE=offline_decoder \
     FLEX_UNLINK_DEVMEM=false \
     FLEX_RDMA_MODE_FULL=1 \
     TOKENIZERS_PARALLELISM=false \
-    TORCH_SENDNN_LOG=CRITICAL
+    TORCH_SENDNN_LOG=WARNING

 # Required configuration file
 COPY docker/.senlib.json /home/senuser

docs/user_guide/configuration.md

Lines changed: 2 additions & 2 deletions
@@ -11,8 +11,8 @@ To run inference on IBM Spyre Accelerators, the backend should be set as:

 | Model type | vLLM backend | `VLLM_SPYRE_DYNAMO_BACKEND` configuration | Notes |
 | --- | --- | --- | --- |
-| Decoder | v0 | sendnn_decoder | V0 support for decoder models is deprecated |
-| Decoder | v1 | sendnn_decoder | |
+| Decoder | v0 | sendnn | V0 support for decoder models is deprecated |
+| Decoder | v1 | sendnn | |
 | Embedding | v0 | sendnn | |
 | Embedding | v1 | N/A | Embedding models are not yet supported on V1 |

examples/offline_inference_spyre.ipynb

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@
 "0 / 1 : FLEX_RDMA_WORLD_RANK=0\n",
 "0 / 1 : FLEX_RDMA_WORLD_SIZE=1\n",
 "0 / 1 : Spyre: Enabled (0) (offset=0)\n",
-"0 / 1 : Dynamo Backend : sendnn_decoder\n",
+"0 / 1 : Dynamo Backend : sendnn\n",
 "0 / 1 : CPU Cores : 56 x 2 HW threads\n",
 "------------------------------------------------------------\n",
 "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit from 64 to 160 to accommodate prompt size of 64 and decode tokens of 5\n",

tests/e2e/test_spyre_basic.py

Lines changed: 53 additions & 0 deletions
@@ -87,6 +87,59 @@ def test_output(
                     hf_results=hf_results)


+@pytest.mark.parametrize("model", get_spyre_model_list())
+@pytest.mark.parametrize("prompts", [[
+    template.format("Provide a list of instructions "
+                    "for preparing chicken soup."),
+]])
+@pytest.mark.parametrize(
+    "warmup_shape", [(64, 20, 4)])  # (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("backend", ["sendnn_decoder"])
+@pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
+def test_output_sendnn_decoder(
+    model: str,
+    prompts: list[str],
+    warmup_shape: tuple[int, int, int],
+    backend: str,
+    vllm_version: str,
+) -> None:
+    '''
+    Tests the deprecated sendnn_decoder backend, which should fall-back to
+    sendnn
+    '''
+
+    max_new_tokens = warmup_shape[1]
+
+    vllm_sampling_params = SamplingParams(
+        max_tokens=max_new_tokens,
+        temperature=0,
+        logprobs=0,  # return logprobs of generated tokens only
+        ignore_eos=True)
+
+    vllm_results = generate_spyre_vllm_output(
+        model=model,
+        prompts=prompts,
+        warmup_shapes=[warmup_shape],
+        max_model_len=2048,
+        block_size=2048,
+        sampling_params=vllm_sampling_params,
+        tensor_parallel_size=1,
+        backend=backend,
+        vllm_version=vllm_version)
+
+    hf_results = generate_hf_output(model=model,
+                                    prompts=prompts,
+                                    max_new_tokens=max_new_tokens)
+
+    compare_results(model=model,
+                    prompts=prompts,
+                    warmup_shapes=[warmup_shape],
+                    tensor_parallel_size=1,
+                    backend=backend,
+                    vllm_results=vllm_results,
+                    hf_results=hf_results)
+
+
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 @pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)

tests/e2e/test_spyre_embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 ]])
 @pytest.mark.parametrize("warmup_shape",
                          [(64, 4), (64, 8), (128, 4),
-                          (128, 8)])  # (prompt_length/new_tokens/batch_size)
+                          (128, 8)])  # (prompt_length/batch_size)
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 @pytest.mark.parametrize(
     "vllm_version",

tests/e2e/test_spyre_online.py

Lines changed: 2 additions & 2 deletions
@@ -57,13 +57,13 @@ def test_openai_serving(remote_openai_server, model, warmup_shape, backend,


 @pytest.mark.parametrize("model", get_spyre_model_list(quantization="gptq"))
-@pytest.mark.parametrize("backend", ["sendnn_decoder"])
+@pytest.mark.parametrize("backend", ["sendnn"])
 @pytest.mark.parametrize("quantization", ["gptq"])
 @pytest.mark.parametrize("warmup_shape", [[(64, 20, 4)]])
 @pytest.mark.parametrize("vllm_version", VLLM_VERSIONS)
 def test_openai_serving_gptq(remote_openai_server, model, backend,
                              warmup_shape, vllm_version, quantization):
-    """Test online serving a GPTQ model with the sendnn_decoder backend only"""
+    """Test online serving a GPTQ model with the sendnn backend only"""

     client = remote_openai_server.get_client()
     completion = client.completions.create(model=model,

tests/spyre_util.py

Lines changed: 5 additions & 5 deletions
@@ -324,7 +324,7 @@ def compare_results(model: str, prompts: list[str],
         print(f" vLLM: {repr(vllm_result['text']):s}{err_msg}")
         print()

-        assert DISABLE_ASSERTS or backend == 'sendnn_decoder' or\
+        assert DISABLE_ASSERTS or backend == 'sendnn' or\
             hf_result['token_ids'] == vllm_result['token_ids']

         if len(hf_result['tokens']) > 0:
@@ -351,13 +351,13 @@
                       f"{vllm_logprob:14f} ",
                       end='')

-                if backend == 'sendnn_decoder':
+                if backend == 'sendnn':
                     rel_tol = ISCLOSE_REL_TOL_SPYRE
                 else:
                     rel_tol = ISCLOSE_REL_TOL_CPU

                 if hf_token_id != vllm_token_id:  # different tokens
-                    if backend == 'sendnn_decoder' and math.isclose(
+                    if backend == 'sendnn' and math.isclose(
                             hf_logprob, vllm_logprob, rel_tol=rel_tol):
                         # probably still OK
                         print('DIVERGING')
@@ -477,15 +477,15 @@ def get_spyre_model_dir_path() -> Path:
 # get model backends from env or default to all and add pytest markers
 def get_spyre_backend_list():
     user_backend_list = os.environ.get("VLLM_SPYRE_TEST_BACKEND_LIST",
-                                       "eager,inductor,sendnn_decoder,sendnn")
+                                       "eager,inductor,sendnn")

     backends = []
     for backend in user_backend_list.split(","):
         backend = backend.strip()
         marks = []
         if backend == "eager":
             marks = [pytest.mark.cpu]
-        elif backend == "sendnn_decoder":
+        elif backend == "sendnn":
             marks = [pytest.mark.spyre]

         backends.append(pytest.param(backend, marks=marks, id=backend))

tests/utils/test_spyre_backend_list.py

Lines changed: 2 additions & 3 deletions
@@ -8,12 +8,11 @@ def test_get_spyre_backend_list(monkeypatch):
     Ensure we return the backend list correctly
     '''
     with monkeypatch.context() as m:
-        m.setenv("VLLM_SPYRE_TEST_BACKEND_LIST",
-                 "eager,inductor,sendnn_decoder")
+        m.setenv("VLLM_SPYRE_TEST_BACKEND_LIST", "eager,inductor,sendnn")
         backend_list = get_spyre_backend_list()
         assert backend_list[0].values[0] == "eager"
         assert backend_list[1].values[0] == "inductor"
-        assert backend_list[2].values[0] == "sendnn_decoder"
+        assert backend_list[2].values[0] == "sendnn"

     with monkeypatch.context() as m:
         m.setenv("VLLM_SPYRE_TEST_BACKEND_LIST", "sendnn")

vllm_spyre/envs.py

Lines changed: 20 additions & 6 deletions
@@ -1,8 +1,10 @@
 import os
 from typing import TYPE_CHECKING, Any, Callable, Optional

+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
-    VLLM_SPYRE_DYNAMO_BACKEND: str = "sendnn_decoder"
+    VLLM_SPYRE_DYNAMO_BACKEND: str = "sendnn"
     VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[list[int]] = None
     VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[list[int]] = None
     VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[list[int]] = None
@@ -12,6 +14,19 @@
     VLLM_SPYRE_PERF_METRIC_LOGGING_DIR: str = "/tmp"
     VLLM_SPYRE_OVERRIDE_SIGNALS_HANDLER: bool = False

+logger = init_logger(__name__)
+
+
+def _backend_backwards_compat() -> str:
+    val = os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn")
+    if val == "sendnn_decoder":
+        logger.warning_once(
+            "Using 'sendnn_decoder' for "
+            "VLLM_SPYRE_DYNAMO_BACKEND is deprecated. Use 'sendnn' instead")
+        val = 'sendnn'
+    return val
+
+
 # --8<-- [start:env-vars-definition]
 environment_variables: dict[str, Callable[[], Any]] = {
     # Defines the prompt lengths the Spyre accelerator should be prepared
@@ -41,14 +56,13 @@

     # Defines the backend that torch.compile will use when using Spyre
     # Available options:
-    # - "sendnn_decoder": Compile for execution on Spyre hardware for
-    #                     decoder models
-    # - "sendnn": Compile for execution on Spyre hardware for
-    #             encoder models
+    # - "sendnn": Compile for execution on Spyre hardware
     # - "inductor": Compile for execution on CPU (for debug and testing)
     # - "eager": Skip compile entirely (for debug and testing)
+    #
+    # - "sendnn_decoder": Deprecated in favor of "sendnn"
     "VLLM_SPYRE_DYNAMO_BACKEND":
-    lambda: os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder"),
+    _backend_backwards_compat,

     # If set, use the V1 continuous batching implementation. Otherwise, static
     # batching mode will be enabled.
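
As a usage note, a minimal sketch of how the new fallback behaves. It assumes vLLM and the vllm-spyre plugin are installed and that the envs module resolves attribute access through the environment_variables table, as the diff above and the call sites below (envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND) suggest.

import os

# Set the deprecated value before the envs module reads it.
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "sendnn_decoder"

import vllm_spyre.envs as envs_spyre

# _backend_backwards_compat() accepts the deprecated value, logs
# "Using 'sendnn_decoder' for VLLM_SPYRE_DYNAMO_BACKEND is deprecated. Use 'sendnn' instead"
# once via logger.warning_once, and returns "sendnn".
assert envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn"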

vllm_spyre/model_executor/model_loader/spyre.py

Lines changed: 5 additions & 5 deletions
@@ -29,7 +29,7 @@
     print("WARNING: Disabled: dynamo_tracer")
     pass

-BACKEND_LIST = ['sendnn_decoder', 'inductor']
+BACKEND_LIST = ['sendnn', 'inductor']

 logger = init_logger(__name__)

@@ -88,7 +88,7 @@ def forward(
             self.model.past_key_value_states = None  # type: ignore

         extra_kwargs: dict[str, Any] = {}
-        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder":
+        if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn":
             # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash
             # cpu impl when padding too much
             extra_kwargs["attn_algorithm"] = "math"
@@ -153,7 +153,7 @@ def __init__(

         self.config: PretrainedConfig = model_config.hf_config
         self.dtype = torch.float16 if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == \
-            'sendnn_decoder' else torch.float32
+            'sendnn' else torch.float32

         # Actual FMS model
         self.model: nn.Module
@@ -177,7 +177,7 @@ def load_weights(
     ) -> None:

         if model_config.quantization == "gptq":
-            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+            if envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn":
                 from fms_mo.aiu_addons.gptq import (  # noqa: F401
                     gptq_aiu_adapter, gptq_aiu_linear)
                 linear_type = "gptq_aiu"
@@ -215,7 +215,7 @@ def load_weights(
                 revision=model_config.revision)

         # we can use fused weights unless running on Spyre
-        fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"
+        fused_weights = envs_spyre.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn"

         self.model = get_model(architecture="hf_configured",
                                variant=model_config.model,
