Skip to content

Commit 0e71603

Browse files
Add type hints to parsers
1 parent 80c817a commit 0e71603

File tree

3 files changed

+134
-90
lines changed

3 files changed

+134
-90
lines changed

schema_salad/codegen_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class LazyInitDef(NamedTuple):
2323

2424
name: str
2525
init: str
26+
instance_type: str | None = None
2627

2728

2829
class CodeGenBase:

schema_salad/python_codegen.py

Lines changed: 132 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,18 @@
1717
from .exceptions import SchemaException
1818
from .schema import shortname
1919

20-
_string_type_def: Final = TypeDef("strtype", "_PrimitiveLoader(str)")
21-
_int_type_def: Final = TypeDef("inttype", "_PrimitiveLoader(int)")
22-
_float_type_def: Final = TypeDef("floattype", "_PrimitiveLoader(float)")
23-
_bool_type_def: Final = TypeDef("booltype", "_PrimitiveLoader(bool)")
24-
_null_type_def: Final = TypeDef("None_type", "_PrimitiveLoader(type(None))")
25-
_any_type_def: Final = TypeDef("Any_type", "_AnyLoader()")
20+
_string_type_def: Final = TypeDef(name="strtype", init="_PrimitiveLoader(str)", instance_type="str")
21+
_int_type_def: Final = TypeDef(name="inttype", init="_PrimitiveLoader(int)", instance_type="int")
22+
_float_type_def: Final = TypeDef(
23+
name="floattype", init="_PrimitiveLoader(float)", instance_type="float"
24+
)
25+
_bool_type_def: Final = TypeDef(
26+
name="booltype", init="_PrimitiveLoader(bool)", instance_type="bool"
27+
)
28+
_null_type_def: Final = TypeDef(
29+
name="None_type", init="_PrimitiveLoader(type(None))", instance_type="None"
30+
)
31+
_any_type_def: Final = TypeDef(name="Any_type", init="_AnyLoader()", instance_type="Any")
2632

2733
prims: Final = {
2834
"http://www.w3.org/2001/XMLSchema#string": _string_type_def,
@@ -59,7 +65,7 @@ def fmt(text: str, indent: int) -> str:
5965
black.format_str(
6066
text,
6167
mode=black.mode.Mode(
62-
target_versions={black.mode.TargetVersion.PY36}, line_length=88 - indent
68+
target_versions={black.mode.TargetVersion.PY310}, line_length=88 - indent
6369
),
6470
),
6571
" " * indent,
@@ -145,7 +151,9 @@ def begin_class(
145151
optional_fields: set[str],
146152
) -> None:
147153
classname = self.safe_name(classname)
148-
self.current_class_is_abstract = abstract
154+
self.current_optional_fields = optional_fields
155+
self.current_fieldtypes: dict[str, TypeDef] = {}
156+
self.current_class_is_abstract: bool = abstract
149157

150158
if extends:
151159
ext = ", ".join(self.safe_name(e) for e in extends)
@@ -169,65 +177,17 @@ def begin_class(
169177
self.out.write(" pass\n\n\n")
170178
return
171179

172-
idfield_safe_name: Final = self.safe_name(idfield) if idfield != "" else None
173-
if idfield_safe_name is not None:
174-
self.out.write(f" {idfield_safe_name}: str\n\n")
180+
self.current_idfield: str = self.safe_name(idfield) if idfield != "" else None
181+
if self.current_idfield is not None:
182+
self.out.write(f" {self.current_idfield}: str\n\n")
175183

176-
required_field_names: Final = [f for f in field_names if f not in optional_fields]
177-
optional_field_names: Final = [f for f in field_names if f in optional_fields]
178-
179-
safe_inits: Final[list[str]] = [" self,"]
180-
safe_inits.extend(
181-
[f" {self.safe_name(f)}: Any," for f in required_field_names if f != "class"]
182-
)
183-
safe_inits.extend(
184-
[
185-
f" {self.safe_name(f)}: Any | None = None,"
186-
for f in optional_field_names
187-
if f != "class"
188-
]
189-
)
190-
self.out.write(
191-
" def __init__(\n"
192-
+ "\n".join(safe_inits)
193-
+ "\n extension_fields: MutableMapping[str, Any] | None = None,"
194-
+ "\n loadingOptions: LoadingOptions | None = None,"
195-
+ "\n ) -> None:\n"
196-
+ """ if extension_fields:
197-
self.extension_fields = extension_fields
198-
else:
199-
self.extension_fields = CommentedMap()
200-
if loadingOptions:
201-
self.loadingOptions = loadingOptions
202-
else:
203-
self.loadingOptions = LoadingOptions()
204-
"""
205-
)
206-
field_inits = ""
207-
for name in field_names:
208-
if name == "class":
209-
field_inits += """ self.class_: Final[str] = "{}"
210-
""".format(
211-
classname
212-
)
213-
elif name == idfield_safe_name:
214-
field_inits += """ self.{0} = {0} if {0} is not None else "_:" + str(_uuid__.uuid4())
215-
""".format(
216-
self.safe_name(name)
217-
)
218-
else:
219-
field_inits += """ self.{0} = {0}
220-
""".format(
221-
self.safe_name(name)
222-
)
223184
field_eqs: Final = []
224185
field_hashes: Final = []
225186
for name in field_names:
226187
field_eqs.append("self.{0} == other.{0}".format(self.safe_name(name)))
227188
field_hashes.append(f"self.{self.safe_name(name)}")
228189
field_eq: Final = " and\n ".join(field_eqs)
229190
field_hash: Final = ",\n ".join(field_hashes)
230-
self.out.write(field_inits)
231191
self.out.write(
232192
"\n"
233193
+ fmt(
@@ -341,7 +301,65 @@ def end_class(self, classname: str, field_names: list[str]) -> None:
341301
"""
342302
)
343303

344-
self.serializer.write(" return r\n\n")
304+
self.serializer.write(" return r\n")
305+
306+
required_field_names: Final = [
307+
f for f in field_names if f not in self.current_optional_fields
308+
]
309+
optional_field_names: Final = [f for f in field_names if f in self.current_optional_fields]
310+
311+
idfield_safe_name: Final = (
312+
self.safe_name(self.current_idfield) if self.current_idfield != None else None
313+
)
314+
safe_inits: Final[list[str]] = [" self,"]
315+
safe_inits.extend(
316+
[
317+
f" {self.safe_name(f)}: {self.current_fieldtypes[self.safe_name(f)].instance_type},"
318+
for f in required_field_names
319+
if f != "class"
320+
]
321+
)
322+
safe_inits.extend(
323+
[
324+
f" {self.safe_name(f)}: {self.current_fieldtypes[self.safe_name(f)].instance_type} = None,"
325+
for f in optional_field_names
326+
if f != "class"
327+
]
328+
)
329+
self.serializer.write(
330+
"\n def __init__(\n"
331+
+ "\n".join(safe_inits)
332+
+ "\n extension_fields: MutableMapping[str, Any] | None = None,"
333+
+ "\n loadingOptions: LoadingOptions | None = None,"
334+
+ "\n ) -> None:\n"
335+
+ """ if extension_fields:
336+
self.extension_fields = extension_fields
337+
else:
338+
self.extension_fields = CommentedMap()
339+
if loadingOptions:
340+
self.loadingOptions = loadingOptions
341+
else:
342+
self.loadingOptions = LoadingOptions()
343+
"""
344+
)
345+
field_inits = ""
346+
for name in field_names:
347+
if name == "class":
348+
field_inits += """ self.class_: Final[str] = "{}"
349+
""".format(
350+
classname
351+
)
352+
elif name == idfield_safe_name:
353+
field_inits += """ self.{0} = {0} if {0} is not None else "_:" + str(_uuid__.uuid4())
354+
""".format(
355+
self.safe_name(name)
356+
)
357+
else:
358+
field_inits += """ self.{0} = {0}
359+
""".format(
360+
self.safe_name(name)
361+
)
362+
self.serializer.write(f"{field_inits}\n")
345363

346364
self.serializer.write(
347365
fmt(
@@ -384,40 +402,46 @@ def type_loader(
384402
"""Parse the given type declaration and declare its components."""
385403
match type_declaration:
386404
case MutableSequence():
387-
sub_names1: Final = list(
388-
dict.fromkeys([self.type_loader(i).name for i in type_declaration])
389-
)
405+
sub_types1: Final = [self.type_loader(i) for i in type_declaration]
406+
sub_names1: Final = [t.name for t in sub_types1]
390407
return self.declare_type(
391408
TypeDef(
392-
"union_of_{}".format("_or_".join(sub_names1)),
393-
"_UnionLoader(({},))".format(", ".join(sub_names1)),
409+
name="union_of_{}".format("_or_".join(sub_names1)),
410+
init="_UnionLoader(({},))".format(", ".join(sub_names1)),
411+
instance_type=" | ".join({t.instance_type or "" for t in sub_types1}),
394412
)
395413
)
396414
case {"type": "array" | "https://w3id.org/cwl/salad#array", "items": items}:
397415
i1: Final = self.type_loader(items)
398416
return self.declare_type(
399417
TypeDef(
400-
f"array_of_{i1.name}",
401-
f"_ArrayLoader({i1.name})",
418+
name=f"array_of_{i1.name}",
419+
init=f"_ArrayLoader({i1.name})",
420+
instance_type=f"MutableSequence[{i1.instance_type}]",
402421
)
403422
)
404423
case {"type": "map" | "https://w3id.org/cwl/salad#map", "values": values, **rest}:
405424
i2: Final = self.type_loader(values)
406425
name = self.safe_name(str(rest["name"])) if "name" in rest else None
407426
anon_type = self.declare_type(
408427
TypeDef(
409-
f"map_of_{i2.name}",
410-
"_MapLoader({}, {}, {}, {})".format(
428+
name=f"map_of_{i2.name}",
429+
init="_MapLoader({}, {}, {}, {})".format(
411430
i2.name,
412431
f"'{name}'", # noqa: B907
413432
f"'{container}'" if container is not None else None, # noqa: B907
414433
no_link_check,
415434
),
435+
instance_type=f"MutableMapping[str, Any]",
416436
)
417437
)
418438
if "name" in rest:
419439
return self.declare_type(
420-
TypeDef(self.safe_name(str(rest["name"])) + "Loader", anon_type.name)
440+
TypeDef(
441+
name=self.safe_name(str(rest["name"])) + "Loader",
442+
init=anon_type.name,
443+
instance_type=anon_type.instance_type,
444+
)
421445
)
422446
else:
423447
return anon_type
@@ -438,26 +462,30 @@ def type_loader(
438462
docstring = f'\n"""\n{formated_doc}\n"""'
439463
else:
440464
docstring = ""
465+
sym_names: Final = [schema.avro_field_name(sym) for sym in symbols]
466+
sym_literals: Final = [f'"{s}"' for s in sym_names]
441467
return self.declare_type(
442468
TypeDef(
443-
self.safe_name(name) + "Loader",
444-
'_EnumLoader(("{}",), "{}"){}'.format(
445-
'", "'.join(schema.avro_field_name(sym) for sym in symbols),
469+
name=self.safe_name(name) + "Loader",
470+
init='_EnumLoader(("{}",), "{}"){}'.format(
471+
'", "'.join(sym_names),
446472
self.safe_name(name),
447473
docstring,
448474
),
475+
instance_type=f"Literal[{', '.join(sym_literals)}]",
449476
)
450477
)
451478

452479
case {"type": "record" | "https://w3id.org/cwl/salad#record", "name": name, **rest}:
453480
return self.declare_type(
454481
TypeDef(
455-
self.safe_name(name) + "Loader",
456-
"_RecordLoader({}, {}, {})".format(
482+
name=self.safe_name(name) + "Loader",
483+
init="_RecordLoader({}, {}, {})".format(
457484
self.safe_name(name),
458485
f"'{container}'" if container is not None else None, # noqa: B907
459486
no_link_check,
460487
),
488+
instance_type=self.safe_name(name),
461489
abstract=bool(rest.get("abstract", False)),
462490
)
463491
)
@@ -469,15 +497,23 @@ def type_loader(
469497
}:
470498
# Declare the named loader to handle recursive union definitions
471499
loader_name = self.safe_name(name) + "Loader"
472-
loader_type = TypeDef(loader_name, f"_UnionLoader((), '{loader_name}')")
500+
loader_type = TypeDef(
501+
name=loader_name,
502+
init=f"_UnionLoader((), '{loader_name}')",
503+
instance_type=self.safe_name(name),
504+
)
473505
self.declare_type(loader_type)
474506
# Parse inner types
475-
sub_names2: Final = list(dict.fromkeys([self.type_loader(i).name for i in names]))
507+
sub_types2: Final = [self.type_loader(i) for i in names]
508+
sub_names2: Final = list(dict.fromkeys(t.name for t in sub_types2))
476509
# Register lazy initialization for the loader
477510
self.add_lazy_init(
478511
LazyInitDef(
479-
loader_name,
480-
"{}.add_loaders(({},))".format(loader_name, ", ".join(sub_names2)),
512+
name=loader_name,
513+
init="{}.add_loaders(({},))".format(loader_name, ", ".join(sub_names2)),
514+
instance_type=f'{self.safe_name(name)}: TypeAlias = "'
515+
+ " | ".join({t.instance_type for t in sub_types2})
516+
+ '"',
481517
)
482518
)
483519
return loader_type
@@ -490,8 +526,9 @@ def type_loader(
490526
case "Expression" | "https://w3id.org/cwl/cwl#Expression" as decl:
491527
return self.declare_type(
492528
TypeDef(
493-
self.safe_name(decl) + "Loader",
494-
"_ExpressionLoader(str)",
529+
name=self.safe_name(decl) + "Loader",
530+
init="_ExpressionLoader(str)",
531+
instance_type="str",
495532
)
496533
)
497534
case str(decl):
@@ -545,6 +582,7 @@ def declare_field(
545582
) -> None:
546583
if self.current_class_is_abstract:
547584
return
585+
self.current_fieldtypes[self.safe_name(name)] = fieldtype
548586

549587
if optional:
550588
self.out.write(f""" {self.safe_name(name)} = None\n""")
@@ -728,11 +766,12 @@ def uri_loader(
728766
"""Construct the TypeDef for the given URI loader."""
729767
return self.declare_type(
730768
TypeDef(
731-
f"uri_{inner.name}_{scoped_id}_{vocab_term}_{ref_scope}_{no_link_check}",
732-
f"_URILoader({inner.name}, {scoped_id}, {vocab_term}, {ref_scope}, {no_link_check})",
769+
name=f"uri_{inner.name}_{scoped_id}_{vocab_term}_{ref_scope}_{no_link_check}",
770+
init=f"_URILoader({inner.name}, {scoped_id}, {vocab_term}, {ref_scope}, {no_link_check})",
733771
is_uri=True,
734772
scoped_id=scoped_id,
735773
ref_scope=ref_scope,
774+
instance_type=inner.instance_type,
736775
)
737776
)
738777

@@ -742,27 +781,30 @@ def idmap_loader(
742781
"""Construct the TypeDef for the given mapped ID loader."""
743782
return self.declare_type(
744783
TypeDef(
745-
f"idmap_{self.safe_name(field)}_{inner.name}",
746-
f"_IdMapLoader({inner.name}, '{map_subject}', '{map_predicate}')", # noqa: B907
784+
name=f"idmap_{self.safe_name(field)}_{inner.name}",
785+
init=f"_IdMapLoader({inner.name}, '{map_subject}', '{map_predicate}')", # noqa: B907
786+
instance_type=inner.instance_type,
747787
)
748788
)
749789

750790
def typedsl_loader(self, inner: TypeDef, ref_scope: int | None) -> TypeDef:
751791
"""Construct the TypeDef for the given DSL loader."""
752792
return self.declare_type(
753793
TypeDef(
754-
f"typedsl_{self.safe_name(inner.name)}_{ref_scope}",
755-
f"_TypeDSLLoader({self.safe_name(inner.name)}, {ref_scope}, " # noqa: B907
794+
name=f"typedsl_{self.safe_name(inner.name)}_{ref_scope}",
795+
init=f"_TypeDSLLoader({self.safe_name(inner.name)}, {ref_scope}, " # noqa: B907
756796
f"'{self.salad_version}')", # noqa: B907
797+
instance_type=inner.instance_type,
757798
)
758799
)
759800

760801
def secondaryfilesdsl_loader(self, inner: TypeDef) -> TypeDef:
761802
"""Construct the TypeDef for secondary files."""
762803
return self.declare_type(
763804
TypeDef(
764-
f"secondaryfilesdsl_{inner.name}",
765-
f"_UnionLoader((_SecondaryDSLLoader({inner.name}), {inner.name},))",
805+
name=f"secondaryfilesdsl_{inner.name}",
806+
init=f"_UnionLoader((_SecondaryDSLLoader({inner.name}), {inner.name},))",
807+
instance_type=inner.instance_type,
766808
)
767809
)
768810

@@ -787,6 +829,7 @@ def epilogue(self, root_loader: TypeDef) -> None:
787829
if self.lazy_inits:
788830
for lazy_init in self.lazy_inits.values():
789831
self.out.write(fmt(f"{lazy_init.init}\n", 0))
832+
self.out.write(fmt(f"{lazy_init.instance_type}", 0))
790833
self.out.write("\n")
791834

792835
self.out.write(

schema_salad/python_codegen_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from itertools import chain
1818
from mypy_extensions import trait
1919
from typing import Any, Final, Generic, TypeAlias, TypeVar, cast
20-
from typing import ClassVar # pylint: disable=unused-import # noqa: F401
20+
from typing import ClassVar, Literal # pylint: disable=unused-import # noqa: F401
2121
from urllib.parse import quote, urldefrag, urlparse, urlsplit, urlunsplit
2222
from urllib.request import pathname2url
2323

0 commit comments

Comments
 (0)