1717from .exceptions import SchemaException
1818from .schema import shortname
1919
20- _string_type_def : Final = TypeDef ("strtype" , "_PrimitiveLoader(str)" )
21- _int_type_def : Final = TypeDef ("inttype" , "_PrimitiveLoader(int)" )
22- _float_type_def : Final = TypeDef ("floattype" , "_PrimitiveLoader(float)" )
23- _bool_type_def : Final = TypeDef ("booltype" , "_PrimitiveLoader(bool)" )
24- _null_type_def : Final = TypeDef ("None_type" , "_PrimitiveLoader(type(None))" )
25- _any_type_def : Final = TypeDef ("Any_type" , "_AnyLoader()" )
20+ _string_type_def : Final = TypeDef (name = "strtype" , init = "_PrimitiveLoader(str)" , instance_type = "str" )
21+ _int_type_def : Final = TypeDef (name = "inttype" , init = "_PrimitiveLoader(int)" , instance_type = "int" )
22+ _float_type_def : Final = TypeDef (
23+ name = "floattype" , init = "_PrimitiveLoader(float)" , instance_type = "float"
24+ )
25+ _bool_type_def : Final = TypeDef (
26+ name = "booltype" , init = "_PrimitiveLoader(bool)" , instance_type = "bool"
27+ )
28+ _null_type_def : Final = TypeDef (
29+ name = "None_type" , init = "_PrimitiveLoader(type(None))" , instance_type = "None"
30+ )
31+ _any_type_def : Final = TypeDef (name = "Any_type" , init = "_AnyLoader()" , instance_type = "Any" )
2632
2733prims : Final = {
2834 "http://www.w3.org/2001/XMLSchema#string" : _string_type_def ,
@@ -59,7 +65,7 @@ def fmt(text: str, indent: int) -> str:
5965 black .format_str (
6066 text ,
6167 mode = black .mode .Mode (
62- target_versions = {black .mode .TargetVersion .PY36 }, line_length = 88 - indent
68+ target_versions = {black .mode .TargetVersion .PY310 }, line_length = 88 - indent
6369 ),
6470 ),
6571 " " * indent ,
@@ -145,7 +151,9 @@ def begin_class(
145151 optional_fields : set [str ],
146152 ) -> None :
147153 classname = self .safe_name (classname )
148- self .current_class_is_abstract = abstract
154+ self .current_optional_fields = optional_fields
155+ self .current_fieldtypes : dict [str , TypeDef ] = {}
156+ self .current_class_is_abstract : bool = abstract
149157
150158 if extends :
151159 ext = ", " .join (self .safe_name (e ) for e in extends )
@@ -169,65 +177,17 @@ def begin_class(
169177 self .out .write (" pass\n \n \n " )
170178 return
171179
172- idfield_safe_name : Final = self .safe_name (idfield ) if idfield != "" else None
173- if idfield_safe_name is not None :
174- self .out .write (f" { idfield_safe_name } : str\n \n " )
180+ self . current_idfield : str = self .safe_name (idfield ) if idfield != "" else None
181+ if self . current_idfield is not None :
182+ self .out .write (f" { self . current_idfield } : str\n \n " )
175183
176- required_field_names : Final = [f for f in field_names if f not in optional_fields ]
177- optional_field_names : Final = [f for f in field_names if f in optional_fields ]
178-
179- safe_inits : Final [list [str ]] = [" self," ]
180- safe_inits .extend (
181- [f" { self .safe_name (f )} : Any," for f in required_field_names if f != "class" ]
182- )
183- safe_inits .extend (
184- [
185- f" { self .safe_name (f )} : Any | None = None,"
186- for f in optional_field_names
187- if f != "class"
188- ]
189- )
190- self .out .write (
191- " def __init__(\n "
192- + "\n " .join (safe_inits )
193- + "\n extension_fields: MutableMapping[str, Any] | None = None,"
194- + "\n loadingOptions: LoadingOptions | None = None,"
195- + "\n ) -> None:\n "
196- + """ if extension_fields:
197- self.extension_fields = extension_fields
198- else:
199- self.extension_fields = CommentedMap()
200- if loadingOptions:
201- self.loadingOptions = loadingOptions
202- else:
203- self.loadingOptions = LoadingOptions()
204- """
205- )
206- field_inits = ""
207- for name in field_names :
208- if name == "class" :
209- field_inits += """ self.class_: Final[str] = "{}"
210- """ .format (
211- classname
212- )
213- elif name == idfield_safe_name :
214- field_inits += """ self.{0} = {0} if {0} is not None else "_:" + str(_uuid__.uuid4())
215- """ .format (
216- self .safe_name (name )
217- )
218- else :
219- field_inits += """ self.{0} = {0}
220- """ .format (
221- self .safe_name (name )
222- )
223184 field_eqs : Final = []
224185 field_hashes : Final = []
225186 for name in field_names :
226187 field_eqs .append ("self.{0} == other.{0}" .format (self .safe_name (name )))
227188 field_hashes .append (f"self.{ self .safe_name (name )} " )
228189 field_eq : Final = " and\n " .join (field_eqs )
229190 field_hash : Final = ",\n " .join (field_hashes )
230- self .out .write (field_inits )
231191 self .out .write (
232192 "\n "
233193 + fmt (
@@ -341,7 +301,65 @@ def end_class(self, classname: str, field_names: list[str]) -> None:
341301"""
342302 )
343303
344- self .serializer .write (" return r\n \n " )
304+ self .serializer .write (" return r\n " )
305+
306+ required_field_names : Final = [
307+ f for f in field_names if f not in self .current_optional_fields
308+ ]
309+ optional_field_names : Final = [f for f in field_names if f in self .current_optional_fields ]
310+
311+ idfield_safe_name : Final = (
312+ self .safe_name (self .current_idfield ) if self .current_idfield != None else None
313+ )
314+ safe_inits : Final [list [str ]] = [" self," ]
315+ safe_inits .extend (
316+ [
317+ f" { self .safe_name (f )} : { self .current_fieldtypes [self .safe_name (f )].instance_type } ,"
318+ for f in required_field_names
319+ if f != "class"
320+ ]
321+ )
322+ safe_inits .extend (
323+ [
324+ f" { self .safe_name (f )} : { self .current_fieldtypes [self .safe_name (f )].instance_type } = None,"
325+ for f in optional_field_names
326+ if f != "class"
327+ ]
328+ )
329+ self .serializer .write (
330+ "\n def __init__(\n "
331+ + "\n " .join (safe_inits )
332+ + "\n extension_fields: MutableMapping[str, Any] | None = None,"
333+ + "\n loadingOptions: LoadingOptions | None = None,"
334+ + "\n ) -> None:\n "
335+ + """ if extension_fields:
336+ self.extension_fields = extension_fields
337+ else:
338+ self.extension_fields = CommentedMap()
339+ if loadingOptions:
340+ self.loadingOptions = loadingOptions
341+ else:
342+ self.loadingOptions = LoadingOptions()
343+ """
344+ )
345+ field_inits = ""
346+ for name in field_names :
347+ if name == "class" :
348+ field_inits += """ self.class_: Final[str] = "{}"
349+ """ .format (
350+ classname
351+ )
352+ elif name == idfield_safe_name :
353+ field_inits += """ self.{0} = {0} if {0} is not None else "_:" + str(_uuid__.uuid4())
354+ """ .format (
355+ self .safe_name (name )
356+ )
357+ else :
358+ field_inits += """ self.{0} = {0}
359+ """ .format (
360+ self .safe_name (name )
361+ )
362+ self .serializer .write (f"{ field_inits } \n " )
345363
346364 self .serializer .write (
347365 fmt (
@@ -384,40 +402,46 @@ def type_loader(
384402 """Parse the given type declaration and declare its components."""
385403 match type_declaration :
386404 case MutableSequence ():
387- sub_names1 : Final = list (
388- dict .fromkeys ([self .type_loader (i ).name for i in type_declaration ])
389- )
405+ sub_types1 : Final = [self .type_loader (i ) for i in type_declaration ]
406+ sub_names1 : Final = [t .name for t in sub_types1 ]
390407 return self .declare_type (
391408 TypeDef (
392- "union_of_{}" .format ("_or_" .join (sub_names1 )),
393- "_UnionLoader(({},))" .format (", " .join (sub_names1 )),
409+ name = "union_of_{}" .format ("_or_" .join (sub_names1 )),
410+ init = "_UnionLoader(({},))" .format (", " .join (sub_names1 )),
411+ instance_type = " | " .join ({t .instance_type or "" for t in sub_types1 }),
394412 )
395413 )
396414 case {"type" : "array" | "https://w3id.org/cwl/salad#array" , "items" : items }:
397415 i1 : Final = self .type_loader (items )
398416 return self .declare_type (
399417 TypeDef (
400- f"array_of_{ i1 .name } " ,
401- f"_ArrayLoader({ i1 .name } )" ,
418+ name = f"array_of_{ i1 .name } " ,
419+ init = f"_ArrayLoader({ i1 .name } )" ,
420+ instance_type = f"MutableSequence[{ i1 .instance_type } ]" ,
402421 )
403422 )
404423 case {"type" : "map" | "https://w3id.org/cwl/salad#map" , "values" : values , ** rest }:
405424 i2 : Final = self .type_loader (values )
406425 name = self .safe_name (str (rest ["name" ])) if "name" in rest else None
407426 anon_type = self .declare_type (
408427 TypeDef (
409- f"map_of_{ i2 .name } " ,
410- "_MapLoader({}, {}, {}, {})" .format (
428+ name = f"map_of_{ i2 .name } " ,
429+ init = "_MapLoader({}, {}, {}, {})" .format (
411430 i2 .name ,
412431 f"'{ name } '" , # noqa: B907
413432 f"'{ container } '" if container is not None else None , # noqa: B907
414433 no_link_check ,
415434 ),
435+ instance_type = f"MutableMapping[str, Any]" ,
416436 )
417437 )
418438 if "name" in rest :
419439 return self .declare_type (
420- TypeDef (self .safe_name (str (rest ["name" ])) + "Loader" , anon_type .name )
440+ TypeDef (
441+ name = self .safe_name (str (rest ["name" ])) + "Loader" ,
442+ init = anon_type .name ,
443+ instance_type = anon_type .instance_type ,
444+ )
421445 )
422446 else :
423447 return anon_type
@@ -438,26 +462,30 @@ def type_loader(
438462 docstring = f'\n """\n { formated_doc } \n """'
439463 else :
440464 docstring = ""
465+ sym_names : Final = [schema .avro_field_name (sym ) for sym in symbols ]
466+ sym_literals : Final = [f'"{ s } "' for s in sym_names ]
441467 return self .declare_type (
442468 TypeDef (
443- self .safe_name (name ) + "Loader" ,
444- '_EnumLoader(("{}",), "{}"){}' .format (
445- '", "' .join (schema . avro_field_name ( sym ) for sym in symbols ),
469+ name = self .safe_name (name ) + "Loader" ,
470+ init = '_EnumLoader(("{}",), "{}"){}' .format (
471+ '", "' .join (sym_names ),
446472 self .safe_name (name ),
447473 docstring ,
448474 ),
475+ instance_type = f"Literal[{ ', ' .join (sym_literals )} ]" ,
449476 )
450477 )
451478
452479 case {"type" : "record" | "https://w3id.org/cwl/salad#record" , "name" : name , ** rest }:
453480 return self .declare_type (
454481 TypeDef (
455- self .safe_name (name ) + "Loader" ,
456- "_RecordLoader({}, {}, {})" .format (
482+ name = self .safe_name (name ) + "Loader" ,
483+ init = "_RecordLoader({}, {}, {})" .format (
457484 self .safe_name (name ),
458485 f"'{ container } '" if container is not None else None , # noqa: B907
459486 no_link_check ,
460487 ),
488+ instance_type = self .safe_name (name ),
461489 abstract = bool (rest .get ("abstract" , False )),
462490 )
463491 )
@@ -469,15 +497,23 @@ def type_loader(
469497 }:
470498 # Declare the named loader to handle recursive union definitions
471499 loader_name = self .safe_name (name ) + "Loader"
472- loader_type = TypeDef (loader_name , f"_UnionLoader((), '{ loader_name } ')" )
500+ loader_type = TypeDef (
501+ name = loader_name ,
502+ init = f"_UnionLoader((), '{ loader_name } ')" ,
503+ instance_type = self .safe_name (name ),
504+ )
473505 self .declare_type (loader_type )
474506 # Parse inner types
475- sub_names2 : Final = list (dict .fromkeys ([self .type_loader (i ).name for i in names ]))
507+ sub_types2 : Final = [self .type_loader (i ) for i in names ]
508+ sub_names2 : Final = list (dict .fromkeys (t .name for t in sub_types2 ))
476509 # Register lazy initialization for the loader
477510 self .add_lazy_init (
478511 LazyInitDef (
479- loader_name ,
480- "{}.add_loaders(({},))" .format (loader_name , ", " .join (sub_names2 )),
512+ name = loader_name ,
513+ init = "{}.add_loaders(({},))" .format (loader_name , ", " .join (sub_names2 )),
514+ instance_type = f'{ self .safe_name (name )} : TypeAlias = "'
515+ + " | " .join ({t .instance_type for t in sub_types2 })
516+ + '"' ,
481517 )
482518 )
483519 return loader_type
@@ -490,8 +526,9 @@ def type_loader(
490526 case "Expression" | "https://w3id.org/cwl/cwl#Expression" as decl :
491527 return self .declare_type (
492528 TypeDef (
493- self .safe_name (decl ) + "Loader" ,
494- "_ExpressionLoader(str)" ,
529+ name = self .safe_name (decl ) + "Loader" ,
530+ init = "_ExpressionLoader(str)" ,
531+ instance_type = "str" ,
495532 )
496533 )
497534 case str (decl ):
@@ -545,6 +582,7 @@ def declare_field(
545582 ) -> None :
546583 if self .current_class_is_abstract :
547584 return
585+ self .current_fieldtypes [self .safe_name (name )] = fieldtype
548586
549587 if optional :
550588 self .out .write (f""" { self .safe_name (name )} = None\n """ )
@@ -728,11 +766,12 @@ def uri_loader(
728766 """Construct the TypeDef for the given URI loader."""
729767 return self .declare_type (
730768 TypeDef (
731- f"uri_{ inner .name } _{ scoped_id } _{ vocab_term } _{ ref_scope } _{ no_link_check } " ,
732- f"_URILoader({ inner .name } , { scoped_id } , { vocab_term } , { ref_scope } , { no_link_check } )" ,
769+ name = f"uri_{ inner .name } _{ scoped_id } _{ vocab_term } _{ ref_scope } _{ no_link_check } " ,
770+ init = f"_URILoader({ inner .name } , { scoped_id } , { vocab_term } , { ref_scope } , { no_link_check } )" ,
733771 is_uri = True ,
734772 scoped_id = scoped_id ,
735773 ref_scope = ref_scope ,
774+ instance_type = inner .instance_type ,
736775 )
737776 )
738777
@@ -742,27 +781,30 @@ def idmap_loader(
742781 """Construct the TypeDef for the given mapped ID loader."""
743782 return self .declare_type (
744783 TypeDef (
745- f"idmap_{ self .safe_name (field )} _{ inner .name } " ,
746- f"_IdMapLoader({ inner .name } , '{ map_subject } ', '{ map_predicate } ')" , # noqa: B907
784+ name = f"idmap_{ self .safe_name (field )} _{ inner .name } " ,
785+ init = f"_IdMapLoader({ inner .name } , '{ map_subject } ', '{ map_predicate } ')" , # noqa: B907
786+ instance_type = inner .instance_type ,
747787 )
748788 )
749789
750790 def typedsl_loader (self , inner : TypeDef , ref_scope : int | None ) -> TypeDef :
751791 """Construct the TypeDef for the given DSL loader."""
752792 return self .declare_type (
753793 TypeDef (
754- f"typedsl_{ self .safe_name (inner .name )} _{ ref_scope } " ,
755- f"_TypeDSLLoader({ self .safe_name (inner .name )} , { ref_scope } , " # noqa: B907
794+ name = f"typedsl_{ self .safe_name (inner .name )} _{ ref_scope } " ,
795+ init = f"_TypeDSLLoader({ self .safe_name (inner .name )} , { ref_scope } , " # noqa: B907
756796 f"'{ self .salad_version } ')" , # noqa: B907
797+ instance_type = inner .instance_type ,
757798 )
758799 )
759800
760801 def secondaryfilesdsl_loader (self , inner : TypeDef ) -> TypeDef :
761802 """Construct the TypeDef for secondary files."""
762803 return self .declare_type (
763804 TypeDef (
764- f"secondaryfilesdsl_{ inner .name } " ,
765- f"_UnionLoader((_SecondaryDSLLoader({ inner .name } ), { inner .name } ,))" ,
805+ name = f"secondaryfilesdsl_{ inner .name } " ,
806+ init = f"_UnionLoader((_SecondaryDSLLoader({ inner .name } ), { inner .name } ,))" ,
807+ instance_type = inner .instance_type ,
766808 )
767809 )
768810
@@ -787,6 +829,7 @@ def epilogue(self, root_loader: TypeDef) -> None:
787829 if self .lazy_inits :
788830 for lazy_init in self .lazy_inits .values ():
789831 self .out .write (fmt (f"{ lazy_init .init } \n " , 0 ))
832+ self .out .write (fmt (f"{ lazy_init .instance_type } " , 0 ))
790833 self .out .write ("\n " )
791834
792835 self .out .write (
0 commit comments