Skip to content

Commit b295704

Browse files
committed
Check struct and class changes in ABI checking tool
1 parent a18022c commit b295704

File tree

1 file changed

+275
-24
lines changed

1 file changed

+275
-24
lines changed

toolshed/check_cython_abi.py

Lines changed: 275 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66
"""
77
Tool to check for Cython ABI changes in a given package.
88
9-
There are different types of ABI changes, only one of which is covered by this tool:
9+
Cython must be installed in your venv to run this script.
10+
11+
There are different types of ABI changes, some of which are covered by this tool:
1012
1113
- cdef function signatures (capsule strings) — covered here
12-
- cdef class struct size (tp_basicsize) — not covered
13-
- cdef class vtable layout / method reordering — not covered, and this one fails as silent UB rather than an import-time error
14-
- Fused specialization ordering — partially covered (reorders manifest as capsule-name deltas, but the mapping is non-obvious)
14+
- cdef class struct size (tp_basicsize) — covered here
15+
- cdef struct / ctypedef struct field layout — covered here (via .pxd parsing)
16+
- cdef class vtable layout / method reordering — not covered, and this one fails
17+
as silent UB rather than an import-time error
18+
- Fused specialization ordering — partially covered (reorders manifest as
19+
capsule-name deltas, but the mapping is non-obvious)
1520
1621
The workflow is basically:
1722
@@ -21,22 +26,28 @@
2126
package is installed), where `package_name` is the import path to the package,
2227
e.g. `cuda.bindings`:
2328
24-
python check_cython_abi.py generate <package_name> <dir>
29+
python check_cython_abi.py generate <package_name> <dir>
2530
2631
3) Checkout a version with the changes to be tested, and build and install.
2732
2833
4) Check the ABI against the previously generated files by running:
2934
30-
python check_cython_abi.py check <package_name> <dir>
35+
python check_cython_abi.py check <package_name> <dir>
3136
"""
3237

3338
import ctypes
3439
import importlib
3540
import json
3641
import sys
3742
import sysconfig
43+
from io import StringIO
3844
from pathlib import Path
3945

46+
from Cython.Compiler import Parsing
47+
from Cython.Compiler.Scanning import FileSourceDescriptor, PyrexScanner
48+
from Cython.Compiler.Symtab import ModuleScope
49+
from Cython.Compiler.TreeFragment import StringParseContext
50+
4051
EXT_SUFFIX = sysconfig.get_config_var("EXT_SUFFIX")
4152
ABI_SUFFIX = ".abi.json"
4253

@@ -66,12 +77,12 @@ def import_from_path(root_package: str, root_dir: Path, path: Path) -> object:
6677

6778

6879
def so_path_to_abi_path(so_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Return the ABI file path under ``abi_dir`` that corresponds to ``so_path``.

    The directory of ``so_path`` relative to ``build_dir`` is mirrored below
    ``abi_dir``; the file name is the short stem of the extension followed by
    ``ABI_SUFFIX``.
    """
    stem = short_stem(so_path.name)
    relative_dir = so_path.parent.relative_to(build_dir)
    return abi_dir / relative_dir / f"{stem}{ABI_SUFFIX}"
7182

7283

7384
def abi_path_to_so_path(abi_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Return the extension module path under ``build_dir`` that corresponds to ``abi_path``.

    The directory of ``abi_path`` relative to ``abi_dir`` is mirrored below
    ``build_dir``; the file name is the short stem of the ABI file followed by
    ``EXT_SUFFIX``.
    """
    stem = short_stem(abi_path.name)
    relative_dir = abi_path.parent.relative_to(abi_dir)
    return build_dir / relative_dir / f"{stem}{EXT_SUFFIX}"
7687

7788

@@ -80,16 +91,244 @@ def is_cython_module(module: object) -> bool:
8091
return hasattr(module, "__pyx_capi__")
8192

8293

83-
def module_to_json(module: object) -> dict:
84-
"""
85-
Converts extracts information about a Cython-compiled .so into JSON-serializable information.
94+
######################################################################################
95+
# STRUCTS
96+
97+
98+
def get_cdef_classes(module: object) -> dict:
    """Collect ``__basicsize__`` for every type object that belongs to ``module``.

    An attribute is recorded when it is a type whose ``__module__`` equals the
    module's ``__name__``. The result maps the attribute name to
    ``{"basicsize": <int>}``, with keys in sorted order.
    """
    owner = module.__name__
    sizes = {}
    for attr in sorted(dir(module)):
        candidate = getattr(module, attr, None)
        if not isinstance(candidate, type):
            continue
        # Skip types merely imported into the module's namespace.
        if getattr(candidate, "__module__", None) != owner:
            continue
        if not hasattr(candidate, "__basicsize__"):
            continue
        sizes[attr] = {"basicsize": candidate.__basicsize__}
    return sizes
107+
108+
109+
def _format_base_type_name(bt: object) -> str:
110+
"""Format a Cython base type AST node into a type name string."""
111+
cls = type(bt).__name__
112+
if cls == "CSimpleBaseTypeNode":
113+
return bt.name
114+
if cls == "CComplexBaseTypeNode":
115+
inner = _format_base_type_name(bt.base_type)
116+
return _unwrap_declarator(inner, bt.declarator)[0]
117+
return cls
118+
119+
120+
def _unwrap_declarator(type_str: str, decl: object) -> tuple[str, str]:
121+
"""Unwrap nested Cython declarator nodes to get (type_string, field_name)."""
122+
cls = type(decl).__name__
123+
if cls == "CNameDeclaratorNode":
124+
return type_str, decl.name
125+
if cls == "CPtrDeclaratorNode":
126+
return _unwrap_declarator(f"{type_str}*", decl.base)
127+
if cls == "CReferenceDeclaratorNode":
128+
return _unwrap_declarator(f"{type_str}&", decl.base)
129+
if cls == "CArrayDeclaratorNode":
130+
dim = getattr(decl, "dimension", None)
131+
size = getattr(dim, "value", "") if dim is not None else ""
132+
return _unwrap_declarator(f"{type_str}[{size}]", decl.base)
133+
return type_str, ""
134+
135+
136+
def _extract_fields_from_cvardef(node: object) -> list:
    """Return the ``[type, name]`` pairs declared by a ``CVarDefNode``.

    One node may carry several declarators sharing a base type; each one that
    resolves to a non-empty name contributes a pair.
    """
    pairs = []
    for declarator in node.declarators:
        base_name = _format_base_type_name(node.base_type)
        field_type, field_name = _unwrap_declarator(base_name, declarator)
        if not field_name:
            # Declarators that do not resolve to a name are not fields.
            continue
        pairs.append([field_type, field_name])
    return pairs
144+
145+
146+
def _collect_cvardef_fields(node: object) -> list:
147+
"""Recursively collect CVarDefNode fields, skipping nested struct/class/func defs."""
148+
fields = []
149+
if type(node).__name__ == "CVarDefNode":
150+
fields.extend(_extract_fields_from_cvardef(node))
151+
skip = ("CStructOrUnionDefNode", "CClassDefNode", "CFuncDefNode")
152+
for attr_name in getattr(node, "child_attrs", []):
153+
child = getattr(node, attr_name, None)
154+
if child is None:
155+
continue
156+
if isinstance(child, list):
157+
for item in child:
158+
if hasattr(item, "child_attrs") and type(item).__name__ not in skip:
159+
fields.extend(_collect_cvardef_fields(item))
160+
elif hasattr(child, "child_attrs") and type(child).__name__ not in skip:
161+
fields.extend(_collect_cvardef_fields(child))
162+
return fields
163+
164+
165+
def _collect_structs_from_tree(node: object) -> dict:
166+
"""Walk a Cython AST and collect struct/class field definitions."""
167+
result = {}
168+
cls = type(node).__name__
169+
170+
if cls == "CStructOrUnionDefNode":
171+
fields = []
172+
for attr in node.attributes:
173+
if type(attr).__name__ == "CVarDefNode":
174+
fields.extend(_extract_fields_from_cvardef(attr))
175+
if fields:
176+
result[node.name] = {"fields": fields}
177+
178+
elif cls == "CClassDefNode":
179+
fields = _collect_cvardef_fields(node.body)
180+
if fields:
181+
result[node.class_name] = {"fields": fields}
182+
183+
for attr_name in getattr(node, "child_attrs", []):
184+
child = getattr(node, attr_name, None)
185+
if child is None:
186+
continue
187+
if isinstance(child, list):
188+
for item in child:
189+
if hasattr(item, "child_attrs"):
190+
result.update(_collect_structs_from_tree(item))
191+
elif hasattr(child, "child_attrs"):
192+
result.update(_collect_structs_from_tree(child))
193+
194+
return result
195+
196+
197+
class _PxdParseContext(StringParseContext):
    """Parse context that resolves includes via real paths and ignores unknown cimports.

    ``find_module`` is overridden so that it never searches for the requested
    module: it always returns a new ``ModuleScope`` with no parent, bound to
    this context.
    """

    def find_module(
        self,
        module_name,
        from_module=None,  # noqa: ARG002
        pos=None,  # noqa: ARG002
        need_pxd=1,  # noqa: ARG002
        absolute_fallback=True,  # noqa: ARG002
        relative_import=False,  # noqa: ARG002
    ):
        # Every argument except `module_name` is ignored (presumably kept to
        # mirror the base-class signature -- confirm against Cython).
        return ModuleScope(module_name, parent_module=None, context=self)
210+
211+
212+
def parse_pxd_structs(pxd_path: Path) -> dict:
213+
"""Parse struct and cdef class field definitions from a .pxd file.
214+
215+
Uses Cython's own parser (in .pxd mode) for reliable extraction.
216+
cimport lines in the top-level file are stripped since they are
217+
unresolvable without the full compilation context; included files
218+
are handled via a lenient context that returns dummy scopes.
219+
220+
Returns a dict mapping struct/class name to {"fields": [[type, name], ...]}.
86221
"""
87-
# Sort the dictionary by keys to make diffs in the JSON files smaller
88-
pyx_capi = module.__pyx_capi__
222+
text = pxd_path.read_text(encoding="utf-8")
223+
224+
# Strip cimport lines (unresolvable without full compilation context)
225+
lines = text.splitlines()
226+
cleaned = "\n".join("" if (" cimport " in ln or ln.lstrip().startswith("cimport ")) else ln for ln in lines)
227+
228+
name = pxd_path.stem
229+
context = _PxdParseContext(name, include_directories=[str(pxd_path.parent)])
230+
code_source = FileSourceDescriptor(str(pxd_path))
231+
scope = context.find_module(name, pos=(code_source, 1, 0), need_pxd=False)
232+
233+
scanner = PyrexScanner(
234+
StringIO(cleaned),
235+
code_source,
236+
source_encoding="UTF-8",
237+
scope=scope,
238+
context=context,
239+
initial_pos=(code_source, 1, 0),
240+
)
241+
tree = Parsing.p_module(scanner, pxd=1, full_module_name=name)
242+
tree.scope = scope
243+
244+
return _collect_structs_from_tree(tree)
245+
246+
247+
def get_structs(module: object) -> dict:
    """Collect struct/class ABI information for a compiled module.

    The ``basicsize`` of each class comes from the compiled module itself.
    When a ``.pxd`` file with the same short stem sits next to the extension,
    its parsed field layouts are merged in: names already present gain a
    ``"fields"`` entry, and names only found in the ``.pxd`` are added.

    Returns a dict sorted by struct/class name.
    """
    # Extract cdef class basicsize from compiled module (primary)
    structs = get_cdef_classes(module)

    # Modules without a file on disk have no neighboring .pxd to inspect.
    # Check before building a Path, since Path(None) raises TypeError.
    module_file = getattr(module, "__file__", None)
    if module_file is not None:
        so_path = Path(module_file)
        # Parse neighboring .pxd file for struct/class field layout (complement)
        pxd_path = so_path.parent / f"{short_stem(so_path.name)}.pxd"
        if pxd_path.is_file():
            for name, info in parse_pxd_structs(pxd_path).items():
                structs.setdefault(name, {}).update(info)

    return dict(sorted(structs.items()))
264+
265+
266+
def _report_field_changes(name: str, expected_fields: list, found_fields: list) -> None:
267+
"""Print detailed field-level differences for a struct."""
268+
expected_dict = {f[1]: f[0] for f in expected_fields}
269+
found_dict = {f[1]: f[0] for f in found_fields}
270+
271+
for field_name, field_type in expected_dict.items():
272+
if field_name not in found_dict:
273+
print(f" Struct {name}: removed field '{field_name}'")
274+
elif found_dict[field_name] != field_type:
275+
print(
276+
f" Struct {name}: field '{field_name}' type changed from '{field_type}' to '{found_dict[field_name]}'"
277+
)
278+
for field_name in found_dict:
279+
if field_name not in expected_dict:
280+
print(f" Struct {name}: added field '{field_name}'")
281+
282+
expected_common = [f[1] for f in expected_fields if f[1] in found_dict]
283+
found_common = [f[1] for f in found_fields if f[1] in expected_dict]
284+
if expected_common != found_common:
285+
print(f" Struct {name}: fields were reordered")
286+
287+
288+
def check_structs(expected: dict, found: dict) -> tuple[bool, bool]:
    """Compare recorded struct/class information against the current build.

    Returns ``(has_errors, has_allowed_changes)``. A missing struct/class, a
    changed or missing ``basicsize``, and changed or missing ``fields`` are
    errors. Structs/classes only present in ``found`` are allowed changes.
    Every difference is printed.
    """
    has_errors = False

    for name, baseline in expected.items():
        if name not in found:
            print(f" Missing struct/class: {name}")
            has_errors = True
            continue
        current = found[name]

        if "basicsize" in baseline:
            if "basicsize" not in current:
                print(f" Struct {name}: basicsize no longer available")
                has_errors = True
            else:
                before, after = baseline["basicsize"], current["basicsize"]
                if after != before:
                    print(
                        f" Struct {name}: basicsize changed from {before} to {after}"
                    )
                    has_errors = True

        if "fields" in baseline:
            if "fields" not in current:
                print(f" Struct {name}: field information no longer available")
                has_errors = True
            elif current["fields"] != baseline["fields"]:
                _report_field_changes(name, baseline["fields"], current["fields"])
                has_errors = True

    # Additions cannot break existing consumers, so they are only reported.
    added = [name for name in found if name not in expected]
    for name in added:
        print(f" Added struct/class: {name}")

    return has_errors, bool(added)
323+
324+
325+
######################################################################################
326+
# FUNCTIONS
327+
328+
329+
def get_functions(module: object) -> dict:
    """Map each ``__pyx_capi__`` entry of ``module`` to its capsule name.

    Keys are inserted in sorted order to keep diffs of the JSON output small.
    """
    capi = module.__pyx_capi__
    capsule_names = {}
    for key in sorted(capi):
        capsule_names[key] = get_capsule_name(capi[key])
    return capsule_names
93332

94333

95334
def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bool, bool]:
@@ -109,17 +348,29 @@ def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bo
109348
return has_errors, has_allowed_changes
110349

111350

351+
######################################################################################
352+
# MAIN
353+
354+
112355
def compare(expected: dict, found: dict) -> tuple[bool, bool]:
    """Run every ABI check on two module descriptions.

    ``expected`` is the recorded ABI data and ``found`` the data extracted from
    the current build; both are keyed by section name (``"functions"``,
    ``"structs"``).

    A section missing from ``expected`` (for example an ABI file written before
    that section existed) is skipped with a notice instead of raising
    ``KeyError``, since there is nothing to compare against.

    Returns ``(has_errors, has_allowed_changes)`` aggregated over all sections.
    """
    has_errors = False
    has_allowed_changes = False

    for func, name in [(check_functions, "functions"), (check_structs, "structs")]:
        if name not in expected:
            print(f" No '{name}' data in the expected ABI file; skipping {name} check")
            continue
        errors, allowed_changes = func(expected[name], found[name])
        has_errors |= errors
        has_allowed_changes |= allowed_changes

    return has_errors, has_allowed_changes
121365

122366

367+
def module_to_json(module: object) -> dict:
    """
    Build the JSON-serializable ABI description of a Cython-compiled module.

    The result has two sections: ``"functions"`` and ``"structs"``.
    """
    description = {}
    description["functions"] = get_functions(module)
    description["structs"] = get_structs(module)
    return description
372+
373+
123374
def check(package: str, abi_dir: Path) -> bool:
124375
build_dir = get_package_path(package)
125376

@@ -168,7 +419,7 @@ def check(package: str, abi_dir: Path) -> bool:
168419
return False
169420

170421

171-
def regenerate(package: str, abi_dir: Path) -> bool:
422+
def generate(package: str, abi_dir: Path) -> bool:
172423
if abi_dir.is_dir():
173424
print(f"ABI directory {abi_dir} already exists. Please remove it before regenerating.")
174425
return True
@@ -199,10 +450,10 @@ def regenerate(package: str, abi_dir: Path) -> bool:
199450

200451
subparsers = parser.add_subparsers()
201452

202-
regen_parser = subparsers.add_parser("generate", help="Regenerate the ABI files")
203-
regen_parser.set_defaults(func=regenerate)
204-
regen_parser.add_argument("package", help="Python package to collect data from")
205-
regen_parser.add_argument("dir", help="Output directory to save data to")
453+
gen_parser = subparsers.add_parser("generate", help="Regenerate the ABI files")
454+
gen_parser.set_defaults(func=generate)
455+
gen_parser.add_argument("package", help="Python package to collect data from")
456+
gen_parser.add_argument("dir", help="Output directory to save data to")
206457

207458
check_parser = subparsers.add_parser("check", help="Check the API against existing ABI files")
208459
check_parser.set_defaults(func=check)

0 commit comments

Comments
 (0)