@@ -93,6 +93,22 @@ def __str__(self) -> str:
93
93
return self .raw_value
94
94
95
95
96
+ from typing import Generic , TypeVar , Iterable
97
+
98
+ T = TypeVar ("T" )
99
+
100
+
101
+ class NonEmptyList (List [T ], Generic [T ]):
102
+ """
103
+ When using the List[T] annotation we assume that it is allowed to have zero elements in it.
104
+ If you need to indicate, that a List is guaranteed to be non-empty, use this type annotation instead.
105
+ Note that this is also enforced during parsing, i.e., parsing an empty list declared as being non-empty
106
+ will result in an error.
107
+ """
108
+
109
+ pass
110
+
111
+
96
112
class TypeBase :
97
113
def construct (self , args : Dict [str , any ]):
98
114
raise NotImplementedError (self )
@@ -244,7 +260,7 @@ def _resolve(self):
244
260
245
261
def _reflect_list (item_cls , globals , locals , symbol_table : SymbolTable , allow_empty : bool ) -> ClassType :
246
262
item_cls = _unwrap_forward_ref (item_cls , globals , locals )
247
- _reflect (item_cls , globals , locals , symbol_table , False )
263
+ _reflect (item_cls , globals , locals , symbol_table )
248
264
return ListType (UnresolvedType (item_cls ), allow_empty )
249
265
250
266
@@ -261,40 +277,38 @@ def _reflect_class(cls, globals, locals, symbol_table: SymbolTable) -> ClassType
261
277
except KeyError :
262
278
pass
263
279
force_name = field .metadata .get ("force_name" , None )
264
- allow_empty_list = field .metadata .get ("allow_empty" , False )
265
280
field_type = _unwrap_forward_ref (field_type , globals , locals )
266
281
attrs .append (ClassType .Attribute (field .name , UnresolvedType (field_type ), required , force_name ))
267
- _reflect (field_type , globals , locals , symbol_table , allow_empty_list )
282
+ _reflect (field_type , globals , locals , symbol_table )
268
283
269
284
subclasses : List [TypeBase ] = []
270
285
for subclass in _collect_subclasses (cls ):
271
286
subclasses .append (UnresolvedType (subclass ))
272
- _reflect (subclass , globals , locals , symbol_table , False )
287
+ _reflect (subclass , globals , locals , symbol_table )
273
288
return ClassType (cls , attrs , static_attrs , subclasses )
274
289
275
290
276
- def _reflect (
277
- cls : any , globals , locals , symbol_table : SymbolTable , allow_empty_list : bool
278
- ) -> Tuple [TypeBase , SymbolTable ]:
291
+ def _reflect (cls : any , globals , locals , symbol_table : SymbolTable ) -> Tuple [TypeBase , SymbolTable ]:
279
292
key = str (cls )
280
293
try :
281
294
return symbol_table .symbols [key ]
282
295
except KeyError :
283
296
# Avoid infinite recursion if _reflect_unsafe calls itself again
284
297
symbol_table .symbols [key ] = None
285
- result = _reflect_unsafe (cls , globals , locals , symbol_table , allow_empty_list )
298
+ result = _reflect_unsafe (cls , globals , locals , symbol_table )
286
299
symbol_table .symbols [key ] = result
287
300
return result
288
301
289
302
290
- def _reflect_unsafe (
291
- cls : any , globals , locals , symbol_table : SymbolTable , allow_empty_list : bool
292
- ) -> Tuple [TypeBase , SymbolTable ]:
303
+ def _reflect_unsafe (cls : any , globals , locals , symbol_table : SymbolTable ) -> Tuple [TypeBase , SymbolTable ]:
293
304
origin = getattr (cls , "__origin__" , None )
294
305
if origin :
295
306
if origin is list :
296
307
item_type = cls .__args__ [0 ]
297
- return _reflect_list (item_type , globals , locals , symbol_table , allow_empty_list )
308
+ return _reflect_list (item_type , globals , locals , symbol_table , True )
309
+ elif origin is NonEmptyList :
310
+ item_type = cls .__args__ [0 ]
311
+ return _reflect_list (item_type , globals , locals , symbol_table , False )
298
312
else :
299
313
if cls is None :
300
314
return NoneType ()
@@ -327,20 +341,20 @@ def _reflect_unsafe(
327
341
328
342
def reflect (cls : any , globals = {}, locals = {}) -> Tuple [TypeBase , SymbolTable ]:
329
343
symbol_table = SymbolTable ()
330
- type = _reflect (cls , globals , locals , symbol_table , False )
344
+ type = _reflect (cls , globals , locals , symbol_table )
331
345
symbol_table ._resolve ()
332
346
return type , symbol_table
333
347
334
348
335
349
def reflect_function (fn : callable , globals = {}, locals = {}) -> FunctionType :
336
350
symbol_table = SymbolTable ()
337
351
return_type = fn .__annotations__ .get ("return" , None )
338
- r_return_type = _reflect (return_type , globals , locals , symbol_table , False )
352
+ r_return_type = _reflect (return_type , globals , locals , symbol_table )
339
353
args : List [FunctionType .Argument ] = []
340
354
for key , value in fn .__annotations__ .items ():
341
355
if key in ["return" ]:
342
356
continue
343
357
required , arg_type = _unwrap_optional (value )
344
- r_arg_type = _reflect (arg_type , globals , locals , symbol_table , False )
358
+ r_arg_type = _reflect (arg_type , globals , locals , symbol_table )
345
359
args .append (FunctionType .Argument (key , r_arg_type , required ))
346
360
return FunctionType (fn , r_return_type , args )
0 commit comments