Skip to content

Commit 6fcc767

Browse files
p88hDarkLight1337njhill
authored andcommitted
[V1][Performance] Implement custom serializaton for MultiModalKwargs [Rebased] (vllm-project#16432)
Signed-off-by: Staszek Pasko <[email protected]> Signed-off-by: Nick Hill <[email protected]> Co-authored-by: Cyrus Leung <[email protected]> Co-authored-by: Nick Hill <[email protected]>
1 parent c4ab734 commit 6fcc767

File tree

3 files changed

+212
-6
lines changed

3 files changed

+212
-6
lines changed

tests/v1/test_serial_utils.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# SPDX-License-Identifier: Apache-2.0
22
from collections import UserDict
33
from dataclasses import dataclass
4+
from typing import Optional
45

6+
import msgspec
57
import numpy as np
68
import torch
79

10+
from vllm.multimodal.inputs import (MultiModalBatchedField,
11+
MultiModalFieldElem, MultiModalKwargs,
12+
MultiModalKwargsItem,
13+
MultiModalSharedField, NestedTensors)
814
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
915

1016

@@ -50,7 +56,7 @@ def test_encode_decode():
5056
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
5157
)
5258

53-
encoder = MsgpackEncoder()
59+
encoder = MsgpackEncoder(size_threshold=256)
5460
decoder = MsgpackDecoder(MyType)
5561

5662
encoded = encoder.encode(obj)
@@ -78,6 +84,97 @@ def test_encode_decode():
7884
assert_equal(decoded2, obj)
7985

8086

87+
class MyRequest(msgspec.Struct):
88+
mm: Optional[list[MultiModalKwargs]]
89+
90+
91+
def test_multimodal_kwargs():
92+
d = {
93+
"foo":
94+
torch.zeros(20000, dtype=torch.float16),
95+
"bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
96+
"baz": [
97+
torch.rand((256), dtype=torch.float16),
98+
[
99+
torch.rand((1, 12), dtype=torch.float32),
100+
torch.rand((3, 5, 7), dtype=torch.float64),
101+
], [torch.rand((4, 4), dtype=torch.float16)]
102+
],
103+
}
104+
105+
# pack mm kwargs into a mock request so that it can be decoded properly
106+
req = MyRequest(mm=[MultiModalKwargs(d)])
107+
108+
encoder = MsgpackEncoder()
109+
decoder = MsgpackDecoder(MyRequest)
110+
111+
encoded = encoder.encode(req)
112+
113+
assert len(encoded) == 6
114+
115+
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
116+
117+
# expected total encoding length, should be 44536, +-20 for minor changes
118+
assert total_len >= 44516 and total_len <= 44556
119+
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
120+
assert all(nested_equal(d[k], decoded[k]) for k in d)
121+
122+
123+
def test_multimodal_items_by_modality():
124+
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
125+
dtype=torch.int16),
126+
MultiModalBatchedField())
127+
e2 = MultiModalFieldElem(
128+
"video",
129+
"v0",
130+
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
131+
MultiModalBatchedField(),
132+
)
133+
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
134+
dtype=torch.int32),
135+
MultiModalSharedField(4))
136+
e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000,
137+
dtype=torch.int32),
138+
MultiModalBatchedField())
139+
audio = MultiModalKwargsItem.from_elems([e1])
140+
video = MultiModalKwargsItem.from_elems([e2])
141+
image = MultiModalKwargsItem.from_elems([e3, e4])
142+
mm = MultiModalKwargs.from_items([audio, video, image])
143+
144+
# pack mm kwargs into a mock request so that it can be decoded properly
145+
req = MyRequest([mm])
146+
147+
encoder = MsgpackEncoder()
148+
decoder = MsgpackDecoder(MyRequest)
149+
150+
encoded = encoder.encode(req)
151+
152+
assert len(encoded) == 8
153+
154+
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
155+
156+
# expected total encoding length, should be 14255, +-20 for minor changes
157+
assert total_len >= 14235 and total_len <= 14275
158+
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
159+
160+
# check all modalities were recovered and do some basic sanity checks
161+
assert len(decoded.modalities) == 3
162+
images = decoded.get_items("image")
163+
assert len(images) == 1
164+
assert len(images[0].items()) == 2
165+
assert list(images[0].keys()) == ["i0", "i1"]
166+
167+
# check the tensor contents and layout in the main dict
168+
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
169+
170+
171+
def nested_equal(a: NestedTensors, b: NestedTensors):
172+
if isinstance(a, torch.Tensor):
173+
return torch.equal(a, b)
174+
else:
175+
return all(nested_equal(x, y) for x, y in zip(a, b))
176+
177+
81178
def assert_equal(obj1: MyType, obj2: MyType):
82179
assert torch.equal(obj1.tensor1, obj2.tensor1)
83180
assert obj1.a_string == obj2.a_string

vllm/envs.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
108108
VLLM_USE_DEEP_GEMM: bool = False
109109
VLLM_XGRAMMAR_CACHE_MB: int = 0
110+
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
110111

111112

112113
def get_default_cache_root():
@@ -704,6 +705,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
704705
# It can be changed with this variable if needed for some reason.
705706
"VLLM_XGRAMMAR_CACHE_MB":
706707
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),
708+
709+
# Control the threshold for msgspec to use 'zero copy' for
710+
# serialization/deserialization of tensors. Tensors below
711+
# this limit will be encoded into the msgpack buffer, and
712+
# tensors above will instead be sent via a separate message.
713+
# While the sending side still actually copies the tensor
714+
# in all cases, on the receiving side, tensors above this
715+
# limit will actually be zero-copy decoded.
716+
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
717+
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
707718
}
708719

709720
# end-env-vars-definition

vllm/v1/serial_utils.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import dataclasses
34
import pickle
45
from collections.abc import Sequence
56
from inspect import isclass
@@ -12,12 +13,26 @@
1213
import zmq
1314
from msgspec import msgpack
1415

16+
from vllm import envs
17+
from vllm.multimodal.inputs import (BaseMultiModalField,
18+
MultiModalBatchedField,
19+
MultiModalFieldConfig, MultiModalFieldElem,
20+
MultiModalFlatField, MultiModalKwargs,
21+
MultiModalKwargsItem,
22+
MultiModalSharedField, NestedTensors)
23+
1524
CUSTOM_TYPE_PICKLE = 1
1625
CUSTOM_TYPE_CLOUDPICKLE = 2
1726
CUSTOM_TYPE_RAW_VIEW = 3
1827

19-
# TODO calibrate this size
20-
MIN_NOCOPY_BUF_SIZE = 512
28+
# MultiModalField class serialization type map.
29+
# These need to list all possible field types and match them
30+
# to factory methods in `MultiModalFieldConfig`.
31+
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
32+
MultiModalFlatField: "flat",
33+
MultiModalSharedField: "shared",
34+
MultiModalBatchedField: "batched",
35+
}
2136

2237
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
2338

@@ -27,14 +42,20 @@ class MsgpackEncoder:
2742
2843
Note that unlike vanilla `msgspec` Encoders, this interface is generally
2944
not thread-safe when encoding tensors / numpy arrays.
45+
46+
By default, arrays below 256B are serialized inline Larger will get sent
47+
via dedicated messages. Note that this is a per-tensor limit.
3048
"""
3149

32-
def __init__(self):
50+
def __init__(self, size_threshold: Optional[int] = None):
51+
if size_threshold is None:
52+
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
3353
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
3454
# This is used as a local stash of buffers that we can then access from
3555
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
3656
# pass custom data to the hook otherwise.
3757
self.aux_buffers: Optional[list[bytestr]] = None
58+
self.size_threshold = size_threshold
3859

3960
def encode(self, obj: Any) -> Sequence[bytestr]:
4061
try:
@@ -65,6 +86,25 @@ def enc_hook(self, obj: Any) -> Any:
6586
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
6687
return self._encode_ndarray(obj)
6788

89+
if isinstance(obj, MultiModalKwargs):
90+
mm: MultiModalKwargs = obj
91+
if not mm.modalities:
92+
# just return the main dict if there are no modalities.
93+
return dict(mm)
94+
95+
# ignore the main dict, it will be re-indexed.
96+
# Encode a list of MultiModalKwargsItems as plain dicts
97+
# + special handling for .field.
98+
# Any tensors *not* indexed by modality will be ignored.
99+
return [[{
100+
"modality": elem.modality,
101+
"key": elem.key,
102+
"data": self._encode_nested_tensors(elem.data),
103+
"field": self._encode_mm_field(elem.field),
104+
} for elem in item.values()]
105+
for itemlist in mm._items_by_modality.values()
106+
for item in itemlist]
107+
68108
if isinstance(obj, FunctionType):
69109
# `pickle` is generally faster than cloudpickle, but can have
70110
# problems serializing methods.
@@ -77,8 +117,9 @@ def _encode_ndarray(
77117
self, obj: np.ndarray
78118
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
79119
assert self.aux_buffers is not None
120+
# If the array is non-contiguous, we need to copy it first
80121
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
81-
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
122+
if not obj.shape or obj.nbytes < self.size_threshold:
82123
# Encode small arrays and scalars inline. Using this extension type
83124
# ensures we can avoid copying when decoding.
84125
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
@@ -92,6 +133,26 @@ def _encode_ndarray(
92133
# backing buffers that we've stashed in `aux_buffers`.
93134
return obj.dtype.str, obj.shape, data
94135

136+
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
137+
if isinstance(nt, torch.Tensor):
138+
return self._encode_ndarray(nt.numpy())
139+
if isinstance(nt, (int, float)):
140+
# Although it violates NestedTensors type, MultiModalKwargs
141+
# values are sometimes floats.
142+
return nt
143+
return [self._encode_nested_tensors(x) for x in nt]
144+
145+
def _encode_mm_field(self, field: BaseMultiModalField):
146+
# Figure out the factory name for the field type.
147+
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
148+
if not name:
149+
raise TypeError(f"Unsupported field type: {field.__class__}")
150+
# We just need to copy all of the field values in order
151+
# which will be then used to reconstruct the field.
152+
field_values = (getattr(field, f.name)
153+
for f in dataclasses.fields(field))
154+
return name, *field_values
155+
95156

96157
class MsgpackDecoder:
97158
"""Decoder with custom torch tensor and numpy array serialization.
@@ -126,13 +187,50 @@ def dec_hook(self, t: type, obj: Any) -> Any:
126187
return self._decode_ndarray(obj)
127188
if issubclass(t, torch.Tensor):
128189
return torch.from_numpy(self._decode_ndarray(obj))
190+
if issubclass(t, MultiModalKwargs):
191+
if isinstance(obj, list):
192+
return MultiModalKwargs.from_items(
193+
self._decode_mm_items(obj))
194+
return MultiModalKwargs({
195+
k: self._decode_nested_tensors(v)
196+
for k, v in obj.items()
197+
})
129198
return obj
130199

131200
def _decode_ndarray(self, arr: Any) -> np.ndarray:
132201
dtype, shape, data = arr
133-
buffer = self.aux_buffers[data] if isinstance(data, int) else data
202+
# Copy from inline representation, otherwise Torch is unhappy since
203+
# the returned memory is non-writeable.
204+
buffer = self.aux_buffers[data] if isinstance(data, int) \
205+
else bytearray(data)
134206
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
135207

208+
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
209+
decoded_items = []
210+
for item in obj:
211+
elems = []
212+
for v in item:
213+
v["data"] = self._decode_nested_tensors(v["data"])
214+
# Reconstruct the field processor using MultiModalFieldConfig
215+
factory_meth_name, *field_args = v["field"]
216+
factory_meth = getattr(MultiModalFieldConfig,
217+
factory_meth_name)
218+
v["field"] = factory_meth(None, *field_args).field
219+
elems.append(MultiModalFieldElem(**v))
220+
decoded_items.append(MultiModalKwargsItem.from_elems(elems))
221+
return decoded_items
222+
223+
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
224+
if isinstance(obj, (int, float)):
225+
# Although it violates NestedTensors type, MultiModalKwargs
226+
# values are sometimes floats.
227+
return obj
228+
if not isinstance(obj, list):
229+
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
230+
if obj and isinstance(obj[0], str):
231+
return torch.from_numpy(self._decode_ndarray(obj))
232+
return [self._decode_nested_tensors(x) for x in obj]
233+
136234
def ext_hook(self, code: int, data: memoryview) -> Any:
137235
if code == CUSTOM_TYPE_RAW_VIEW:
138236
return data

0 commit comments

Comments
 (0)