Skip to content

Serialize tensors using int8 views #16866

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/v1/test_serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,15 @@ def test_multimodal_kwargs():

total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

# expected total encoding length, should be 44536, +-20 for minor changes
assert total_len >= 44516 and total_len <= 44556
# expected total encoding length, should be 44559, +-20 for minor changes
assert total_len >= 44539 and total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)


def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
dtype=torch.int16),
dtype=torch.bfloat16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",
Expand Down
34 changes: 30 additions & 4 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:

def enc_hook(self, obj: Any) -> Any:
if isinstance(obj, torch.Tensor):
return self._encode_ndarray(obj.numpy())
return self._encode_tensor(obj)

# Fall back to pickle for object or void kind ndarrays.
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
Expand Down Expand Up @@ -133,9 +133,26 @@ def _encode_ndarray(
# backing buffers that we've stashed in `aux_buffers`.
return obj.dtype.str, obj.shape, data

def _encode_tensor(
self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
# this creates a copy of the tensor
obj = obj.contiguous() if not obj.is_contiguous() else obj
# view the tensor as a 1D array of bytes
arr = obj.view([obj.numel()]).view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold:
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
else:
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(arr.data)
dt = str(obj.dtype)[6:] # remove 'torch.' prefix
return dt, obj.shape, data

def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
if isinstance(nt, torch.Tensor):
return self._encode_ndarray(nt.numpy())
return self._encode_tensor(nt)
if isinstance(nt, (int, float)):
# Although it violates NestedTensors type, MultiModalKwargs
# values are sometimes floats.
Expand Down Expand Up @@ -186,7 +203,7 @@ def dec_hook(self, t: type, obj: Any) -> Any:
if issubclass(t, np.ndarray):
return self._decode_ndarray(obj)
if issubclass(t, torch.Tensor):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
if issubclass(t, MultiModalKwargs):
if isinstance(obj, list):
return MultiModalKwargs.from_items(
Expand All @@ -205,6 +222,15 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
else bytearray(data)
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

def _decode_tensor(self, arr: Any) -> torch.Tensor:
    """Rebuild a tensor from a ``(dtype, shape, data)`` triple.

    ``dtype`` is looked up as an attribute of ``torch`` (e.g. 'float32'),
    ``shape`` is the shape of the returned tensor, and ``data`` is either
    an integer index into ``self.aux_buffers`` or the raw bytes themselves.
    """
    dtype_name, dims, payload = arr
    if isinstance(payload, int):
        # Out-of-band payload: use the referenced auxiliary buffer as-is.
        raw = self.aux_buffers[payload]
    else:
        # Inline payload: copy it into a bytearray, because Torch is
        # unhappy when the underlying memory is non-writeable.
        raw = bytearray(payload)
    # Wrap the bytes as a flat uint8 array, then reinterpret them with the
    # target dtype before restoring the shape.
    as_bytes = np.ndarray(buffer=raw, dtype=np.uint8, shape=[len(raw)])
    byte_tensor = torch.from_numpy(as_bytes)
    typed = byte_tensor.view(getattr(torch, dtype_name))
    return typed.view(dims)

def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = []
for item in obj:
Expand All @@ -228,7 +254,7 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
if not isinstance(obj, list):
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
if obj and isinstance(obj[0], str):
return torch.from_numpy(self._decode_ndarray(obj))
return self._decode_tensor(obj)
return [self._decode_nested_tensors(x) for x in obj]

def ext_hook(self, code: int, data: memoryview) -> Any:
Expand Down
Loading