11import ast
22from functools import lru_cache
33from pathlib import Path
4- from typing import Set , Union
4+ from typing import FrozenSet , List , Set , Union
55
66from polylith .imports import SYMBOLS , extract_api , list_imports , parse_module
77
8+ PACKAGE_INTERFACE = "__init__.py"
9+ ALL_STATEMENT = "__all__"
10+
811
912def target_names (t : ast .AST ) -> Set [str ]:
1013 if isinstance (t , ast .Name ):
@@ -53,13 +56,25 @@ def extract_public_variables(path: Path) -> Set[str]:
5356
5457
5558def is_the_all_statement (target : ast .expr ) -> bool :
56- return isinstance (target , ast .Name ) and target .id == "__all__"
59+ return isinstance (target , ast .Name ) and target .id == ALL_STATEMENT
5760
5861
5962def is_string_constant (expression : ast .AST ) -> bool :
6063 return isinstance (expression , ast .Constant ) and isinstance (expression .value , str )
6164
6265
66+ def attribute_expr_to_parts (expr : ast .AST ) -> List [str ]:
67+ if isinstance (expr , ast .Name ):
68+ return [expr .id ]
69+
70+ if isinstance (expr , ast .Attribute ):
71+ parent = attribute_expr_to_parts (expr .value )
72+
73+ return [* parent , expr .attr ] if parent else []
74+
75+ return []
76+
77+
6378def find_the_all_variable (statement : ast .stmt ) -> Union [Set [str ], None ]:
6479 if not isinstance (statement , ast .Assign ):
6580 return None
@@ -76,12 +91,69 @@ def find_the_all_variable(statement: ast.stmt) -> Union[Set[str], None]:
7691 return {e .value for e in statement .value .elts if isinstance (e , ast .Constant )}
7792
7893
79- def extract_the_all_variable (path : Path ) -> Set [str ]:
94+ def find_the_all_pointer (statement : ast .stmt ) -> Union [str , None ]:
95+ if not isinstance (statement , ast .Assign ):
96+ return None
97+
98+ if not any (is_the_all_statement (t ) for t in statement .targets ):
99+ return None
100+
101+ parts = attribute_expr_to_parts (statement .value )
102+
103+ if not parts :
104+ return None
105+
106+ * module_path , rest = parts
107+
108+ if rest != ALL_STATEMENT :
109+ return None
110+
111+ return "." .join (module_path )
112+
113+
114+ def resolve_local_module_path (package_dir : Path , module_ref : str ) -> Union [Path , None ]:
115+ parts = tuple (p for p in module_ref .split ("." ) if p )
116+
117+ if not parts :
118+ return None
119+
120+ module_file = package_dir .joinpath (* parts ).with_suffix (".py" )
121+
122+ if module_file .exists ():
123+ return module_file
124+
125+ module_init = package_dir .joinpath (* parts , PACKAGE_INTERFACE )
126+
127+ return module_init if module_init .exists () else None
128+
129+
130+ def _extract_the_all_variable (path : Path , visited : FrozenSet [Path ]) -> Set [str ]:
131+ if path in visited :
132+ return set ()
133+
134+ visited = visited | frozenset ({path })
135+
80136 tree = parse (path )
81137
82- res = [find_the_all_variable (s ) for s in tree .body ]
138+ literals = [find_the_all_variable (s ) for s in tree .body ]
139+ literal = next ((r for r in literals if r is not None ), None )
140+
141+ if literal is not None :
142+ return literal
143+
144+ pointers = (find_the_all_pointer (s ) for s in tree .body )
145+ pointer = next ((p for p in pointers if p is not None ), None )
146+
147+ if not pointer :
148+ return set ()
83149
84- return next ((r for r in res if r is not None ), set ())
150+ resolved = resolve_local_module_path (path .parent , pointer )
151+
152+ return _extract_the_all_variable (resolved , visited ) if resolved else set ()
153+
154+
155+ def extract_the_all_variable (path : Path ) -> Set [str ]:
156+ return _extract_the_all_variable (path , frozenset ())
85157
86158
87159def extract_imported_api (path : Path ) -> Set [str ]:
@@ -98,7 +170,7 @@ def fetch_api_for_path(path: Path) -> Set[str]:
98170
99171
100172def fetch_api (paths : Set [Path ]) -> dict :
101- interface_paths = [Path (p / "__init__.py" ) for p in paths ]
173+ interface_paths = [Path (p / PACKAGE_INTERFACE ) for p in paths ]
102174
103175 interfaces = [p for p in interface_paths if p .exists ()]
104176
0 commit comments