@@ -75,7 +75,7 @@ def enc_hook(self, obj: Any) -> Any:
75
75
# ignore the main dict, it will be re-indexed.
76
76
# pass a list of MultiModalKwargsItem, then see below
77
77
# Any tensors *not* indexed by modality will be ignored.
78
- return mm ._items_by_modality .values ()
78
+ return list ( mm ._items_by_modality .values () )
79
79
# just return the main dict if there are no modalities
80
80
return dict (mm )
81
81
@@ -150,10 +150,12 @@ def dec_hook(self, t: type, obj: Any) -> Any:
150
150
return torch .from_numpy (self ._decode_ndarray (obj ))
151
151
if issubclass (t , MultiModalKwargs ):
152
152
if isinstance (obj , list ):
153
- return MultiModalKwargs .from_items (self ._decode_mm_items (obj ))
154
- return MultiModalKwargs (
155
- {k : self ._decode_nested (v )
156
- for k in obj .items ()})
153
+ return MultiModalKwargs .from_items (
154
+ self ._decode_mm_items (obj ))
155
+ return MultiModalKwargs ({
156
+ k : self ._decode_nested_tensors (v )
157
+ for k , v in obj .items ()
158
+ })
157
159
return obj
158
160
159
161
def _decode_ndarray (self , arr : Any ) -> np .ndarray :
@@ -166,7 +168,7 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
166
168
for item in chain .from_iterable (obj ):
167
169
elems = []
168
170
for v in item .values ():
169
- v ['data' ] = self ._decode_nested (v ['data' ])
171
+ v ['data' ] = self ._decode_nested_tensors (v ['data' ])
170
172
v ['field' ] = pickle .loads (v ['field' ])
171
173
elems .append (MultiModalFieldElem (** v ))
172
174
all .append (MultiModalKwargsItem .from_elems (elems ))
0 commit comments