Skip to content

[V1][Performance] Implement custom serialization for MultiModalKwargs [Rebased] #16432

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
7b6b7ba
Implement efficient serialization of MultiModalKwargs
p88h Apr 10, 2025
4bdd16e
Apply suggestions from code review
p88h Apr 11, 2025
e5931af
Additional fixes after code review
p88h Apr 11, 2025
6641584
Fix some broken bits & reformat
p88h Apr 11, 2025
a94df99
Add custom support for MultiModalFieldConfig, less pickle
p88h Apr 11, 2025
57467e2
Too many stars. Test for other field types.
p88h Apr 11, 2025
d993e42
Set zero-copy threshold to 256MB. Also copy out tensors.
p88h Apr 11, 2025
3401429
Make mypy happy, and also simplify field type restore
p88h Apr 11, 2025
252d8a0
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 11, 2025
5902d6e
Merge branch 'main' into serialize-multimodal-kwargs
p88h Apr 12, 2025
57e1922
style fix
p88h Apr 12, 2025
2c0e9a8
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 13, 2025
176ba06
Copy memory when sending, zero copy when receiving
p88h Apr 13, 2025
3461ce6
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 15, 2025
578aab8
Add threshold env var, re-do field serialization, cleanup
p88h Apr 15, 2025
91a4500
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 15, 2025
3d4e380
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 15, 2025
c61c87a
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 15, 2025
936c95e
remove asdict() which involves object deep copy.
p88h Apr 15, 2025
7cf5492
Bring back zero-copy, plus more review updates
p88h Apr 15, 2025
12c9d8b
Apply suggestions from code review
p88h Apr 15, 2025
8bda83c
fix review edits
p88h Apr 15, 2025
678cba1
revert encode_into changes
p88h Apr 15, 2025
f8d26df
Apply suggestions from code review
p88h Apr 16, 2025
bce2f07
Small fixes
p88h Apr 16, 2025
d7cb694
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 16, 2025
7511262
style
p88h Apr 16, 2025
97188e6
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 16, 2025
48ab2d9
remove unnecessary comment
p88h Apr 16, 2025
a60333e
Merge branch 'vllm-project:main' into serialize-multimodal-kwargs
p88h Apr 16, 2025
281f0f1
Accommodate floats in NestedTensors
njhill Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
from collections import UserDict
from dataclasses import dataclass
from typing import Optional

import msgspec
import numpy as np
import torch

from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder


Expand Down Expand Up @@ -50,7 +56,7 @@ def test_encode_decode():
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
)

encoder = MsgpackEncoder()
encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(MyType)

encoded = encoder.encode(obj)
Expand Down Expand Up @@ -78,6 +84,97 @@ def test_encode_decode():
assert_equal(decoded2, obj)


class MyRequest(msgspec.Struct):
    """Minimal request wrapper so that MultiModalKwargs can be round-tripped
    through a typed decode — MsgpackDecoder requires a target type."""
    mm: Optional[list[MultiModalKwargs]]


def test_multimodal_kwargs():
    """Round-trip a MultiModalKwargs built from a plain (modality-less) dict
    of nested tensors and verify size and content of the encoding."""
    tensors = {
        "foo":
        torch.zeros(20000, dtype=torch.float16),
        "bar": [torch.zeros(i * 1000, dtype=torch.int8) for i in range(3)],
        "baz": [
            torch.rand((256), dtype=torch.float16),
            [
                torch.rand((1, 12), dtype=torch.float32),
                torch.rand((3, 5, 7), dtype=torch.float64),
            ], [torch.rand((4, 4), dtype=torch.float16)]
        ],
    }

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest(mm=[MultiModalKwargs(tensors)])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    assert len(encoded) == 6

    total_len = sum(memoryview(frame).cast("B").nbytes for frame in encoded)

    # expected total encoding length, should be 44536, +-20 for minor changes
    assert 44516 <= total_len <= 44556

    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
    assert all(nested_equal(tensors[key], decoded[key]) for key in tensors)


def test_multimodal_items_by_modality():
    """Round-trip a MultiModalKwargs built from per-modality items and check
    that the item structure (not just the flat dict) is reconstructed."""
    audio_elem = MultiModalFieldElem("audio", "a0",
                                     torch.zeros(1000, dtype=torch.int16),
                                     MultiModalBatchedField())
    video_elem = MultiModalFieldElem(
        "video",
        "v0",
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalBatchedField(),
    )
    image_elem_shared = MultiModalFieldElem("image", "i0",
                                            torch.zeros(1000,
                                                        dtype=torch.int32),
                                            MultiModalSharedField(4))
    image_elem_batched = MultiModalFieldElem("image", "i1",
                                             torch.zeros(1000,
                                                         dtype=torch.int32),
                                             MultiModalBatchedField())
    items = [
        MultiModalKwargsItem.from_elems([audio_elem]),
        MultiModalKwargsItem.from_elems([video_elem]),
        MultiModalKwargsItem.from_elems(
            [image_elem_shared, image_elem_batched]),
    ]
    mm = MultiModalKwargs.from_items(items)

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest([mm])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    assert len(encoded) == 8

    total_len = sum(memoryview(frame).cast("B").nbytes for frame in encoded)

    # expected total encoding length, should be 14255, +-20 for minor changes
    assert 14235 <= total_len <= 14275

    decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded.modalities) == 3
    images = decoded.get_items("image")
    assert len(images) == 1
    assert len(images[0].items()) == 2
    assert list(images[0].keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    assert all(nested_equal(mm[key], decoded[key]) for key in mm)


def nested_equal(a: NestedTensors, b: NestedTensors) -> bool:
    """Return True iff two nested-tensor structures are structurally and
    element-wise equal.

    Tensors are compared with torch.equal (same shape and values); lists
    are compared recursively and must have the same length.
    """
    if isinstance(a, torch.Tensor):
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    if isinstance(b, torch.Tensor):
        # structure mismatch: list on one side, tensor on the other
        return False
    # zip() would silently truncate to the shorter sequence, letting a
    # missing/extra trailing element go unnoticed — compare lengths first.
    if len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))


def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.tensor1, obj2.tensor1)
assert obj1.a_string == obj2.a_string
Expand Down
11 changes: 11 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_USE_DEEP_GEMM: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256


def get_default_cache_root():
Expand Down Expand Up @@ -704,6 +705,16 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB":
lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")),

# Control the threshold for msgspec to use 'zero copy' for
# serialization/deserialization of tensors. Tensors below
# this limit will be encoded into the msgpack buffer, and
# tensors above will instead be sent via a separate message.
# While the sending side still actually copies the tensor
# in all cases, on the receiving side, tensors above this
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD":
lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")),
}

# end-env-vars-definition
Expand Down
100 changes: 95 additions & 5 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import dataclasses
import pickle
from collections.abc import Sequence
from inspect import isclass
Expand All @@ -12,12 +13,26 @@
import zmq
from msgspec import msgpack

from vllm import envs
from vllm.multimodal.inputs import (BaseMultiModalField,
MultiModalBatchedField,
MultiModalFieldConfig, MultiModalFieldElem,
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)

CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE = 512
# MultiModalField class serialization type map.
# These need to list all possible field types and match them
# to factory methods in `MultiModalFieldConfig`.
MMF_CLASS_TO_FACTORY = {
MultiModalFlatField: "flat",
MultiModalSharedField: "shared",
MultiModalBatchedField: "batched",
}

bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]

Expand All @@ -27,14 +42,20 @@ class MsgpackEncoder:

Note that unlike vanilla `msgspec` Encoders, this interface is generally
not thread-safe when encoding tensors / numpy arrays.

By default, arrays below 256B are serialized inline. Larger arrays will
get sent via dedicated messages. Note that this is a per-tensor limit.
"""

def __init__(self):
def __init__(self, size_threshold: Optional[int] = None):
    """Create an encoder.

    size_threshold: per-tensor byte limit below which arrays are inlined
    into the msgpack buffer; defaults to VLLM_MSGPACK_ZERO_COPY_THRESHOLD.
    """
    self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
    # Local stash of buffers reachable from the `enc_hook` callback;
    # msgspec offers no other way to pass custom data into the hook.
    self.aux_buffers: Optional[list[bytestr]] = None
    # Fall back to the env-configured threshold when none is supplied.
    self.size_threshold = (envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
                           if size_threshold is None else size_threshold)

def encode(self, obj: Any) -> Sequence[bytestr]:
try:
Expand Down Expand Up @@ -65,6 +86,25 @@ def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
return self._encode_ndarray(obj)

if isinstance(obj, MultiModalKwargs):
mm: MultiModalKwargs = obj
if not mm.modalities:
# just return the main dict if there are no modalities.
return dict(mm)

# ignore the main dict, it will be re-indexed.
# Encode a list of MultiModalKwargsItems as plain dicts
# + special handling for .field.
# Any tensors *not* indexed by modality will be ignored.
return [[{
"modality": elem.modality,
"key": elem.key,
"data": self._encode_nested_tensors(elem.data),
"field": self._encode_mm_field(elem.field),
} for elem in item.values()]
for itemlist in mm._items_by_modality.values()
for item in itemlist]

if isinstance(obj, FunctionType):
# `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods.
Expand All @@ -77,8 +117,9 @@ def _encode_ndarray(
self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# If the array is non-contiguous, we need to copy it first
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
if not obj.shape or obj.nbytes < self.size_threshold:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
Expand All @@ -92,6 +133,22 @@ def _encode_ndarray(
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data

def _encode_nested_tensors(self, obj: Any) -> NestedTensors:
    """Recursively encode nested tensors, replacing each tensor leaf with
    its ndarray encoding and preserving the surrounding list structure."""
    if not isinstance(obj, torch.Tensor):
        return [self._encode_nested_tensors(item) for item in obj]
    return self._encode_ndarray(obj.numpy())

def _encode_mm_field(self, field: BaseMultiModalField):
    """Encode a field as (factory_name, *field_values) so the decoder can
    rebuild it via the matching MultiModalFieldConfig factory method."""
    factory_name = MMF_CLASS_TO_FACTORY.get(type(field))
    if not factory_name:
        raise TypeError(f"Unsupported field type: {field.__class__}")
    # Capture the dataclass attribute values in declaration order; they are
    # replayed positionally when the field is reconstructed.
    values = [getattr(field, f.name) for f in dataclasses.fields(field)]
    return (factory_name, *values)


class MsgpackDecoder:
"""Decoder with custom torch tensor and numpy array serialization.
Expand Down Expand Up @@ -126,13 +183,46 @@ def dec_hook(self, t: type, obj: Any) -> Any:
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
self._decode_mm_items(obj))
return MultiModalKwargs({
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
return obj

def _decode_ndarray(self, arr: Any) -> np.ndarray:
    """Rebuild an ndarray from its (dtype, shape, data) encoding, where
    `data` is either an index into the out-of-band buffers or inline bytes."""
    dtype, shape, data = arr
    if isinstance(data, int):
        # Out-of-band buffer: reuse it directly (zero-copy).
        buffer = self.aux_buffers[data]
    else:
        # Inline representation: copy into a writable bytearray, otherwise
        # Torch is unhappy since the returned memory is non-writeable.
        buffer = bytearray(data)
    return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
    """Decode a list of encoded MultiModalKwargsItems (lists of plain
    element dicts) back into MultiModalKwargsItem objects.

    Each element dict has its "data" nested tensors decoded and its
    "field" tuple replayed through the corresponding factory method on
    MultiModalFieldConfig.
    """
    # NOTE: renamed from `all`, which shadowed the builtin all().
    decoded_items = []
    for item in obj:
        elems = []
        for v in item:
            v["data"] = self._decode_nested_tensors(v["data"])
            # Reconstruct the field processor using MultiModalFieldConfig.
            # The leading None fills the factory's first positional arg,
            # which is not needed to rebuild the field itself —
            # NOTE(review): confirm against the factory signatures.
            factory_meth_name, *field_args = v["field"]
            factory_meth = getattr(MultiModalFieldConfig,
                                   factory_meth_name)
            v["field"] = factory_meth(None, *field_args).field
            elems.append(MultiModalFieldElem(**v))
        decoded_items.append(MultiModalKwargsItem.from_elems(elems))
    return decoded_items

def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
    """Recursively rebuild NestedTensors from their encoded list form."""
    if not isinstance(obj, list):
        raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
    # An encoded ndarray is a (dtype_str, shape, data) triple, so a leading
    # string distinguishes a tensor leaf from a nested list.
    is_encoded_array = bool(obj) and isinstance(obj[0], str)
    if is_encoded_array:
        return torch.from_numpy(self._decode_ndarray(obj))
    return [self._decode_nested_tensors(elem) for elem in obj]

def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_RAW_VIEW:
return data
Expand Down