# SPDX-License-Identifier: Apache-2.0

+ import dataclasses
import pickle
from collections.abc import Sequence
from inspect import isclass
import zmq
from msgspec import msgpack

+ from vllm import envs
+ from vllm.multimodal.inputs import (BaseMultiModalField,
+                                     MultiModalBatchedField,
+                                     MultiModalFieldConfig, MultiModalFieldElem,
+                                     MultiModalFlatField, MultiModalKwargs,
+                                     MultiModalKwargsItem,
+                                     MultiModalSharedField, NestedTensors)
+
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3

- # TODO calibrate this size
- MIN_NOCOPY_BUF_SIZE = 512
+ # MultiModalField class serialization type map.
+ # These need to list all possible field types and match them
+ # to factory methods in `MultiModalFieldConfig`.
+ MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
+     MultiModalFlatField: "flat",
+     MultiModalSharedField: "shared",
+     MultiModalBatchedField: "batched",
+ }

bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]

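The mapping above exists so that a field object can be reduced to a plain tuple on the wire and rebuilt on the receiving side. The following is a minimal sketch of that round-trip contract, mirroring the `_encode_mm_field` and `_decode_mm_items` logic added later in this diff; the helper names and the `vllm.v1.serial_utils` import path are assumptions for illustration, not part of the change.

import dataclasses

from vllm.multimodal.inputs import BaseMultiModalField, MultiModalFieldConfig
from vllm.v1.serial_utils import MMF_CLASS_TO_FACTORY  # assumed module path


def field_to_wire(field: BaseMultiModalField) -> tuple:
    # Map the concrete field class to its factory name ("flat" / "shared" /
    # "batched") and capture its dataclass attributes in declaration order.
    name = MMF_CLASS_TO_FACTORY[type(field)]
    return (name, *(getattr(field, f.name) for f in dataclasses.fields(field)))


def field_from_wire(wire: tuple) -> BaseMultiModalField:
    # Rebuild via the matching MultiModalFieldConfig factory; the modality
    # argument is not needed to reconstruct the field itself, so None is
    # passed, exactly as _decode_mm_items does below.
    name, *args = wire
    return getattr(MultiModalFieldConfig, name)(None, *args).field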
@@ -27,14 +42,20 @@ class MsgpackEncoder:

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
+
+     By default, arrays below 256B are serialized inline. Larger arrays will
+     get sent via dedicated messages. Note that this is a per-tensor limit.
    """

-     def __init__(self):
+     def __init__(self, size_threshold: Optional[int] = None):
+         if size_threshold is None:
+             size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None
+         self.size_threshold = size_threshold

    def encode(self, obj: Any) -> Sequence[bytestr]:
        try:
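As a usage sketch of the new constructor parameter (assuming the encoder is importable from vllm.v1.serial_utils; the variable names below are illustrative): the threshold decides, per tensor, whether the raw bytes are inlined into the msgpack payload or shipped as separate zero-copy frames.

import numpy as np

from vllm.v1.serial_utils import MsgpackEncoder  # assumed module path

# Explicit per-tensor threshold; when omitted, the encoder falls back to
# envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD (256 bytes, per the docstring above).
encoder = MsgpackEncoder(size_threshold=1024)

small = np.zeros(16, dtype=np.float32)    # 64 B   -> inlined as CUSTOM_TYPE_RAW_VIEW
large = np.zeros(4096, dtype=np.float32)  # 16 KiB -> index into the aux buffers

frames = encoder.encode({"small": small, "large": large})
# frames[0] is the msgpack message itself; the remaining entries are the
# stashed aux buffers for arrays at or above the threshold.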
@@ -65,6 +86,25 @@ def enc_hook(self, obj: Any) -> Any:
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

+         if isinstance(obj, MultiModalKwargs):
+             mm: MultiModalKwargs = obj
+             if not mm.modalities:
+                 # just return the main dict if there are no modalities.
+                 return dict(mm)
+
+             # ignore the main dict, it will be re-indexed.
+             # Encode a list of MultiModalKwargsItems as plain dicts
+             # + special handling for .field.
+             # Any tensors *not* indexed by modality will be ignored.
+             return [[{
+                 "modality": elem.modality,
+                 "key": elem.key,
+                 "data": self._encode_nested_tensors(elem.data),
+                 "field": self._encode_mm_field(elem.field),
+             } for elem in item.values()]
+                     for itemlist in mm._items_by_modality.values()
+                     for item in itemlist]
+
        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
@@ -77,8 +117,9 @@ def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        assert self.aux_buffers is not None
+         # If the array is non-contiguous, we need to copy it first
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
-         if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
+         if not obj.shape or obj.nbytes < self.size_threshold:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
@@ -92,6 +133,26 @@ def _encode_ndarray(
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data

+     def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
+         if isinstance(nt, torch.Tensor):
+             return self._encode_ndarray(nt.numpy())
+         if isinstance(nt, (int, float)):
+             # Although it violates NestedTensors type, MultiModalKwargs
+             # values are sometimes floats.
+             return nt
+         return [self._encode_nested_tensors(x) for x in nt]
+
+     def _encode_mm_field(self, field: BaseMultiModalField):
+         # Figure out the factory name for the field type.
+         name = MMF_CLASS_TO_FACTORY.get(field.__class__)
+         if not name:
+             raise TypeError(f"Unsupported field type: {field.__class__}")
+         # We just need to copy all of the field values in order,
+         # which will then be used to reconstruct the field.
+         field_values = (getattr(field, f.name)
+                         for f in dataclasses.fields(field))
+         return name, *field_values
+

class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.
@@ -126,13 +187,50 @@ def dec_hook(self, t: type, obj: Any) -> Any:
            return self._decode_ndarray(obj)
        if issubclass(t, torch.Tensor):
            return torch.from_numpy(self._decode_ndarray(obj))
+         if issubclass(t, MultiModalKwargs):
+             if isinstance(obj, list):
+                 return MultiModalKwargs.from_items(
+                     self._decode_mm_items(obj))
+             return MultiModalKwargs({
+                 k: self._decode_nested_tensors(v)
+                 for k, v in obj.items()
+             })
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        dtype, shape, data = arr
-         buffer = self.aux_buffers[data] if isinstance(data, int) else data
+         # Copy from inline representation, otherwise Torch is unhappy since
+         # the returned memory is non-writeable.
+         buffer = self.aux_buffers[data] if isinstance(data, int) \
+             else bytearray(data)
        return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)

+     def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
+         decoded_items = []
+         for item in obj:
+             elems = []
+             for v in item:
+                 v["data"] = self._decode_nested_tensors(v["data"])
+                 # Reconstruct the field processor using MultiModalFieldConfig
+                 factory_meth_name, *field_args = v["field"]
+                 factory_meth = getattr(MultiModalFieldConfig,
+                                        factory_meth_name)
+                 v["field"] = factory_meth(None, *field_args).field
+                 elems.append(MultiModalFieldElem(**v))
+             decoded_items.append(MultiModalKwargsItem.from_elems(elems))
+         return decoded_items
+
+     def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
+         if isinstance(obj, (int, float)):
+             # Although it violates NestedTensors type, MultiModalKwargs
+             # values are sometimes floats.
+             return obj
+         if not isinstance(obj, list):
+             raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
+         if obj and isinstance(obj[0], str):
+             return torch.from_numpy(self._decode_ndarray(obj))
+         return [self._decode_nested_tensors(x) for x in obj]
+
    def ext_hook(self, code: int, data: memoryview) -> Any:
        if code == CUSTOM_TYPE_RAW_VIEW:
            return data
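Taken together, the encoder and decoder changes round-trip MultiModalKwargs across the message boundary. Below is a minimal end-to-end sketch, assuming the classes are importable from vllm.v1.serial_utils, that MsgpackDecoder is constructed with the expected top-level type (the `t` that dec_hook receives), and that decode accepts the frame sequence produced by encode.

import torch

from vllm.multimodal.inputs import MultiModalKwargs
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder  # assumed path

encoder = MsgpackEncoder()
decoder = MsgpackDecoder(MultiModalKwargs)

# A modality-less MultiModalKwargs takes the plain-dict path in enc_hook and
# dec_hook; kwargs built via MultiModalKwargs.from_items take the list path
# and also round-trip their per-item fields.
mm = MultiModalKwargs({"pixel_values": torch.rand(1, 3, 224, 224)})

frames = encoder.encode(mm)      # msgpack payload plus zero-copy tensor frames
restored = decoder.decode(frames)
assert torch.equal(restored["pixel_values"], mm["pixel_values"])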