66"""
77Tool to check for Cython ABI changes in a given package.
88
Cython must be installed in your venv to run this script.

There are different types of ABI changes, some of which are covered by this tool:

- cdef function signatures (capsule strings) — covered here
- cdef class struct size (tp_basicsize) — covered here
- cdef struct / ctypedef struct field layout — covered here (via .pxd parsing)
- cdef class vtable layout / method reordering — not covered, and this one fails
  as silent UB rather than an import-time error
- Fused specialization ordering — partially covered (reorders manifest as
  capsule-name deltas, but the mapping is non-obvious)
1520
1621The workflow is basically:
1722
2126 package is installed), where `package_name` is the import path to the package,
2227 e.g. `cuda.bindings`:
2328
24- python check_cython_abi.py generate <package_name> <dir>
29+ python check_cython_abi.py generate <package_name> <dir>
2530
26313) Checkout a version with the changes to be tested, and build and install.
2732
28334) Check the ABI against the previously generated files by running:
2934
30- python check_cython_abi.py check <package_name> <dir>
35+ python check_cython_abi.py check <package_name> <dir>
3136"""
3237
3338import ctypes
3439import importlib
3540import json
3641import sys
3742import sysconfig
43+ from io import StringIO
3844from pathlib import Path
3945
46+ from Cython .Compiler import Parsing
47+ from Cython .Compiler .Scanning import FileSourceDescriptor , PyrexScanner
48+ from Cython .Compiler .Symtab import ModuleScope
49+ from Cython .Compiler .TreeFragment import StringParseContext
50+
4051EXT_SUFFIX = sysconfig .get_config_var ("EXT_SUFFIX" )
4152ABI_SUFFIX = ".abi.json"
4253
@@ -66,12 +77,12 @@ def import_from_path(root_package: str, root_dir: Path, path: Path) -> object:
6677
6778
def so_path_to_abi_path(so_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Map a compiled extension path to the path of its ABI description file.

    The directory of ``so_path`` relative to ``build_dir`` is mirrored under
    ``abi_dir``; the file name is ``short_stem(so_path.name)`` followed by
    ``ABI_SUFFIX``.

    Raises:
        ValueError: if ``so_path`` is not located under ``build_dir``
            (raised by ``Path.relative_to``).
    """
    # No padding inside the f-string: any stray blank would become part of the
    # file name on disk.
    abi_name = f"{short_stem(so_path.name)}{ABI_SUFFIX}"
    return abi_dir / so_path.parent.relative_to(build_dir) / abi_name
7182
7283
def abi_path_to_so_path(abi_path: Path, build_dir: Path, abi_dir: Path) -> Path:
    """Map an ABI description file back to the path of its compiled extension.

    The directory of ``abi_path`` relative to ``abi_dir`` is mirrored under
    ``build_dir``; the file name is ``short_stem(abi_path.name)`` followed by
    ``EXT_SUFFIX``.

    Raises:
        ValueError: if ``abi_path`` is not located under ``abi_dir``
            (raised by ``Path.relative_to``).
    """
    # No padding inside the f-string: any stray blank would become part of the
    # file name being looked up.
    so_name = f"{short_stem(abi_path.name)}{EXT_SUFFIX}"
    return build_dir / abi_path.parent.relative_to(abi_dir) / so_name
7687
7788
@@ -80,16 +91,244 @@ def is_cython_module(module: object) -> bool:
8091 return hasattr (module , "__pyx_capi__" )
8192
8293
83- def module_to_json (module : object ) -> dict :
84- """
85- Converts extracts information about a Cython-compiled .so into JSON-serializable information.
94+ ######################################################################################
95+ # STRUCTS
96+
97+
def get_cdef_classes(module: object) -> dict[str, dict[str, int]]:
    """Collect ``__basicsize__`` for every class that ``module`` itself defines.

    A module attribute is reported when it is a class whose ``__module__``
    equals ``module.__name__``; classes merely imported into the module are
    therefore ignored. Entries are keyed by the attribute name under which the
    class is exposed and are produced in sorted name order.

    Returns:
        ``{attribute_name: {"basicsize": <int>}}``
    """
    module_name = module.__name__
    sizes = {}
    for attr_name in sorted(dir(module)):
        candidate = getattr(module, attr_name, None)
        if not isinstance(candidate, type):
            continue
        if getattr(candidate, "__module__", None) != module_name:
            continue
        # Every type object exposes __basicsize__, so no extra hasattr check
        # is needed here.
        sizes[attr_name] = {"basicsize": candidate.__basicsize__}
    return sizes
107+
108+
def _format_base_type_name(bt: object) -> str:
    """Format a Cython base type AST node into a type name string.

    Dispatch is on the node's class name so that no Cython node classes have
    to be imported:

    - ``CSimpleBaseTypeNode``: the node's ``name``. For the built-in C
      arithmetic names the sign / length modifiers are included as well,
      because Cython's parser stores e.g. ``unsigned int``, ``long`` and
      ``long long`` all under ``name == "int"`` and keeps the difference in
      the ``signed`` / ``longness`` attributes. Without them a
      layout-changing edit such as ``int`` -> ``long long`` would be
      invisible in the recorded field list.
    - ``CComplexBaseTypeNode``: the inner base type with the node's
      declarator (pointer / array / reference) applied.
    - anything else: the node's class name, as a stable placeholder.
    """
    cls = type(bt).__name__
    if cls == "CSimpleBaseTypeNode":
        name = bt.name
        # Only decorate the plain C arithmetic names; typedef-like names such
        # as size_t are already unambiguous on their own.
        if name not in ("char", "int", "float", "double"):
            return name
        words = []
        # Cython convention: signed is 0 for "unsigned", 2 for an explicit
        # "signed", 1 when nothing was written.
        signed = getattr(bt, "signed", 1)
        if signed == 0:
            words.append("unsigned")
        elif signed == 2:
            words.append("signed")
        # Cython convention: longness is -1 for "short", otherwise the number
        # of "long" keywords.
        longness = getattr(bt, "longness", 0)
        if longness < 0:
            words.append("short")
        else:
            words.extend(["long"] * longness)
        words.append(name)
        if getattr(bt, "complex", 0):
            words.append("complex")
        return " ".join(words)
    if cls == "CComplexBaseTypeNode":
        inner = _format_base_type_name(bt.base_type)
        return _unwrap_declarator(inner, bt.declarator)[0]
    return cls
118+
119+
def _unwrap_declarator(type_str: str, decl: object) -> tuple[str, str]:
    """Unwrap nested Cython declarator nodes to get ``(type_string, field_name)``.

    Starting from the outermost declarator, each pointer, reference or array
    layer appends ``*``, ``&`` or ``[size]`` to ``type_str`` and the walk
    continues with the layer's ``base``. The walk ends at a
    ``CNameDeclaratorNode``, whose ``name`` is returned.

    Returns:
        ``(type_string, name)``. ``name`` is ``""`` when an unrecognised
        declarator kind is met; callers treat that as "no field".
    """
    # NOTE(review): function declarators (cdef methods, function-pointer
    # members) fall into the "unrecognised" case and therefore yield no name.
    # Presumably intentional for methods; confirm for function-pointer fields.
    suffixes = {"CPtrDeclaratorNode": "*", "CReferenceDeclaratorNode": "&"}
    while True:
        cls = type(decl).__name__
        if cls == "CNameDeclaratorNode":
            return type_str, decl.name
        if cls in suffixes:
            type_str = f"{type_str}{suffixes[cls]}"
        elif cls == "CArrayDeclaratorNode":
            dim = getattr(decl, "dimension", None)
            # Only dimension nodes exposing .value contribute a size; any
            # other dimension is recorded as an empty "[]".
            size = getattr(dim, "value", "") if dim is not None else ""
            type_str = f"{type_str}[{size}]"
        else:
            return type_str, ""
        decl = decl.base
134+
135+
def _extract_fields_from_cvardef(node: object) -> list[list[str]]:
    """Extract ``[type, name]`` pairs from a ``CVarDefNode``.

    One variable definition may declare several names that share a base type
    (``int a, *b``); each declarator is unwrapped separately against that
    shared base type. Declarators for which no name can be determined are
    left out.
    """
    # The base type is the same for every declarator, so format it once.
    base_type = _format_base_type_name(node.base_type)
    unwrapped = (_unwrap_declarator(base_type, declarator) for declarator in node.declarators)
    return [[type_str, name] for type_str, name in unwrapped if name]
144+
145+
def _collect_cvardef_fields(node: object) -> list[list[str]]:
    """Recursively collect ``[type, name]`` fields below ``node``.

    Fields come from ``CVarDefNode`` instances found by walking the attributes
    listed in each node's ``child_attrs``. Nested struct/union, class and
    function definitions are not descended into, so their members are not
    attributed to the enclosing definition.
    """
    skipped = ("CStructOrUnionDefNode", "CClassDefNode", "CFuncDefNode")

    fields = []
    if type(node).__name__ == "CVarDefNode":
        fields.extend(_extract_fields_from_cvardef(node))

    # child_attrs may be absent (non-node objects, None) or unset (None).
    for attr_name in getattr(node, "child_attrs", None) or ():
        child = getattr(node, attr_name, None)
        # A child attribute holds either one node or a list of nodes; handle
        # both through the same loop.
        children = child if isinstance(child, list) else (child,)
        for item in children:
            if hasattr(item, "child_attrs") and type(item).__name__ not in skipped:
                fields.extend(_collect_cvardef_fields(item))
    return fields
163+
164+
def _collect_structs_from_tree(node: object) -> dict:
    """Walk a Cython AST and collect struct/union/class field definitions.

    - ``CStructOrUnionDefNode``: fields are taken from the ``CVarDefNode``
      entries directly in ``node.attributes``; keyed by ``node.name``.
    - ``CClassDefNode``: fields are gathered from ``node.body``; keyed by
      ``node.class_name``.

    Definitions without any field are omitted. The walk then continues into
    every attribute named in ``child_attrs``; when the same name is found more
    than once the definition visited last wins.

    Returns:
        ``{name: {"fields": [[type, name], ...]}}``
    """
    found = {}
    kind = type(node).__name__

    if kind == "CStructOrUnionDefNode":
        fields = []
        # A declaration without a body (forward / opaque struct) carries
        # attributes=None rather than an empty list.
        for member in node.attributes or ():
            if type(member).__name__ == "CVarDefNode":
                fields.extend(_extract_fields_from_cvardef(member))
        if fields:
            found[node.name] = {"fields": fields}
    elif kind == "CClassDefNode":
        fields = _collect_cvardef_fields(node.body)
        if fields:
            found[node.class_name] = {"fields": fields}

    for attr_name in getattr(node, "child_attrs", None) or ():
        child = getattr(node, attr_name, None)
        # A child attribute holds either one node or a list of nodes.
        children = child if isinstance(child, list) else (child,)
        for item in children:
            if hasattr(item, "child_attrs"):
                found.update(_collect_structs_from_tree(item))

    return found
195+
196+
class _PxdParseContext(StringParseContext):
    """Lenient parse context for parsing a single .pxd file in isolation.

    ``find_module`` is overridden so that every module lookup succeeds: it
    hands back a new, empty ``ModuleScope`` for whatever name is requested
    instead of searching for the module.
    """

    def find_module(
        self,
        module_name,
        from_module=None,  # noqa: ARG002
        pos=None,  # noqa: ARG002
        need_pxd=1,  # noqa: ARG002
        absolute_fallback=True,  # noqa: ARG002
        relative_import=False,  # noqa: ARG002
    ):
        """Return a fresh ``ModuleScope`` named ``module_name``.

        All other parameters are accepted for signature compatibility with
        the base class and are ignored.
        """
        return ModuleScope(module_name, parent_module=None, context=self)
210+
211+
def _blank_out_cimports(text: str) -> str:
    """Replace every cimport statement in ``text`` by empty lines.

    A statement is recognised by a line that contains `` cimport `` or starts
    with ``cimport ``. Continuation lines of that statement (inside an open
    parenthesis, or after a trailing backslash) are blanked as well. Lines are
    blanked rather than removed so that line numbers reported by the parser
    still match the file on disk.
    """
    kept = []
    open_parens = 0
    backslash = False
    for line in text.splitlines():
        continuing = open_parens > 0 or backslash
        starting = " cimport " in line or line.lstrip().startswith("cimport ")
        if not (continuing or starting):
            kept.append(line)
            continue
        # Ignore anything after '#' when tracking continuation state.
        code = line.split("#", 1)[0]
        open_parens = max(open_parens + code.count("(") - code.count(")"), 0)
        backslash = code.rstrip().endswith("\\")
        kept.append("")
    return "\n".join(kept)


def parse_pxd_structs(pxd_path: Path) -> dict:
    """Parse struct and cdef class field definitions from a .pxd file.

    Uses Cython's own parser (in .pxd mode) for reliable extraction.
    cimport statements in the top-level file are stripped since they are
    unresolvable without the full compilation context; module lookups that
    still occur are answered by a lenient context that returns dummy scopes.

    Args:
        pxd_path: path of the .pxd file; read as UTF-8.

    Returns:
        A dict mapping struct/class name to ``{"fields": [[type, name], ...]}``.
    """
    text = pxd_path.read_text(encoding="utf-8")
    cleaned = _blank_out_cimports(text)

    name = pxd_path.stem
    context = _PxdParseContext(name, include_directories=[str(pxd_path.parent)])
    code_source = FileSourceDescriptor(str(pxd_path))
    scope = context.find_module(name, pos=(code_source, 1, 0), need_pxd=False)

    # The scanner reads the cleaned text but reports positions against the
    # real file via code_source.
    scanner = PyrexScanner(
        StringIO(cleaned),
        code_source,
        source_encoding="UTF-8",
        scope=scope,
        context=context,
        initial_pos=(code_source, 1, 0),
    )
    tree = Parsing.p_module(scanner, pxd=1, full_module_name=name)
    tree.scope = scope

    return _collect_structs_from_tree(tree)
245+
246+
def get_structs(module: object) -> dict:
    """Gather struct/class layout information for a compiled module.

    The ``basicsize`` of each class comes from the imported module itself.
    If a ``.pxd`` file named after the module sits next to the extension
    file, the field lists parsed from it are merged in: they are added to an
    existing entry of the same name, or stored as a new entry.

    Returns:
        A dict sorted by name; each value may hold ``"basicsize"``,
        ``"fields"`` or both.
    """
    # Primary source: sizes read from the compiled module.
    structs = get_cdef_classes(module)

    # Complement: field layout from a neighbouring .pxd file. The module path
    # is checked *before* building a Path, since Path(None) raises TypeError.
    module_file = getattr(module, "__file__", None)
    if module_file:
        so_path = Path(module_file)
        pxd_path = so_path.parent / f"{short_stem(so_path.name)}.pxd"
        if pxd_path.is_file():
            for name, info in parse_pxd_structs(pxd_path).items():
                structs.setdefault(name, {}).update(info)

    return dict(sorted(structs.items()))
264+
265+
def _report_field_changes(name: str, expected_fields: list, found_fields: list) -> None:
    """Print detailed field-level differences for a struct.

    Both field lists hold ``[type, name]`` pairs. Reported, in this order:
    per expected field a removal or a type change, then fields that are new,
    then one line if the fields present on both sides appear in a different
    relative order.
    """
    expected_types = {field_name: field_type for field_type, field_name in expected_fields}
    found_types = {field_name: field_type for field_type, field_name in found_fields}

    for field_name, old_type in expected_types.items():
        if field_name not in found_types:
            print(f" Struct {name}: removed field '{field_name}'")
            continue
        new_type = found_types[field_name]
        if new_type != old_type:
            print(f" Struct {name}: field '{field_name}' type changed from '{old_type}' to '{new_type}'")

    for field_name in found_types:
        if field_name not in expected_types:
            print(f" Struct {name}: added field '{field_name}'")

    # Compare order using only the names both sides know, so additions and
    # removals alone are not reported as a reorder.
    expected_order = [field_name for _, field_name in expected_fields if field_name in found_types]
    found_order = [field_name for _, field_name in found_fields if field_name in expected_types]
    if expected_order != found_order:
        print(f" Struct {name}: fields were reordered")
286+
287+
def check_structs(expected: dict, found: dict) -> tuple[bool, bool]:
    """Compare recorded struct/class information against the current build.

    Args:
        expected: entries loaded from the stored ABI file.
        found: entries extracted from the installed package.

    Each entry may carry ``"basicsize"`` and/or ``"fields"``; only the keys
    present in the expected entry are compared. A missing entry, a missing or
    different ``basicsize`` and missing or different ``fields`` are errors.
    Entries that exist only in ``found`` are allowed changes. Every finding is
    printed.

    Returns:
        ``(has_errors, has_allowed_changes)``
    """
    has_errors = False

    for name, expected_info in expected.items():
        if name not in found:
            print(f" Missing struct/class: {name}")
            has_errors = True
            continue
        found_info = found[name]

        if "basicsize" in expected_info:
            old_size = expected_info["basicsize"]
            if "basicsize" not in found_info:
                print(f" Struct {name}: basicsize no longer available")
                has_errors = True
            elif found_info["basicsize"] != old_size:
                print(f" Struct {name}: basicsize changed from {old_size} to {found_info['basicsize']}")
                has_errors = True

        if "fields" in expected_info:
            old_fields = expected_info["fields"]
            if "fields" not in found_info:
                print(f" Struct {name}: field information no longer available")
                has_errors = True
            elif found_info["fields"] != old_fields:
                _report_field_changes(name, old_fields, found_info["fields"])
                has_errors = True

    has_allowed_changes = False
    for name in found:
        if name not in expected:
            print(f" Added struct/class: {name}")
            has_allowed_changes = True

    return has_errors, has_allowed_changes
323+
324+
325+ ######################################################################################
326+ # FUNCTIONS
327+
328+
def get_functions(module: object) -> dict:
    """Map each name in ``module.__pyx_capi__`` to its capsule name.

    Keys are inserted in sorted order so that the JSON written from the
    result stays stable and diffs stay small.
    """
    pyx_capi = module.__pyx_capi__
    return {symbol: get_capsule_name(pyx_capi[symbol]) for symbol in sorted(pyx_capi)}
93332
94333
95334def check_functions (expected : dict [str , str ], found : dict [str , str ]) -> tuple [bool , bool ]:
@@ -109,17 +348,29 @@ def check_functions(expected: dict[str, str], found: dict[str, str]) -> tuple[bo
109348 return has_errors , has_allowed_changes
110349
111350
351+ ######################################################################################
352+ # MAIN
353+
354+
def compare(expected: dict, found: dict) -> tuple[bool, bool]:
    """Run every ABI check on one module's recorded vs. current data.

    Args:
        expected: data loaded from a stored ABI file.
        found: data extracted from the installed module.

    Both dicts are organised by section (``"functions"``, ``"structs"``). A
    section missing on either side is treated as empty, so ABI files that
    were written before a section existed can still be checked instead of
    failing with ``KeyError``.

    Returns:
        ``(has_errors, has_allowed_changes)``, OR-ed over all sections.
    """
    has_errors = False
    has_allowed_changes = False

    checkers = ((check_functions, "functions"), (check_structs, "structs"))
    for checker, section in checkers:
        errors, allowed_changes = checker(expected.get(section, {}), found.get(section, {}))
        has_errors |= errors
        has_allowed_changes |= allowed_changes

    return has_errors, has_allowed_changes
121365
122366
def module_to_json(module: object) -> dict:
    """Extract JSON-serializable ABI information from a Cython-compiled module.

    Returns:
        A dict with two sections: ``"functions"`` (from ``get_functions``)
        and ``"structs"`` (from ``get_structs``).
    """
    return {
        "functions": get_functions(module),
        "structs": get_structs(module),
    }
372+
373+
123374def check (package : str , abi_dir : Path ) -> bool :
124375 build_dir = get_package_path (package )
125376
@@ -168,7 +419,7 @@ def check(package: str, abi_dir: Path) -> bool:
168419 return False
169420
170421
171- def regenerate (package : str , abi_dir : Path ) -> bool :
422+ def generate (package : str , abi_dir : Path ) -> bool :
172423 if abi_dir .is_dir ():
173424 print (f"ABI directory { abi_dir } already exists. Please remove it before regenerating." )
174425 return True
@@ -199,10 +450,10 @@ def regenerate(package: str, abi_dir: Path) -> bool:
199450
200451 subparsers = parser .add_subparsers ()
201452
202- regen_parser = subparsers .add_parser ("generate" , help = "Regenerate the ABI files" )
203- regen_parser .set_defaults (func = regenerate )
204- regen_parser .add_argument ("package" , help = "Python package to collect data from" )
205- regen_parser .add_argument ("dir" , help = "Output directory to save data to" )
453+ gen_parser = subparsers .add_parser ("generate" , help = "Regenerate the ABI files" )
454+ gen_parser .set_defaults (func = generate )
455+ gen_parser .add_argument ("package" , help = "Python package to collect data from" )
456+ gen_parser .add_argument ("dir" , help = "Output directory to save data to" )
206457
207458 check_parser = subparsers .add_parser ("check" , help = "Check the API against existing ABI files" )
208459 check_parser .set_defaults (func = check )
0 commit comments