Skip to content

Commit 0291458

Browse files
PR feedback refactor
1. Changed fixture to provide a function that empties the stack of contexts. The function has hidden max_iters bound. If exceeded, a RuntimeError is raised 2. Modified _device_unset_current utility function to return a boolean. True is returned is a context was popped, False if the stack was already empty.
1 parent 7e45902 commit 0291458

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

cuda_core/tests/conftest.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,29 +31,45 @@ def init_cuda():
3131
device = Device()
3232
device.set_current()
3333
yield
34-
_device_unset_current()
34+
_ = _device_unset_current()
3535

3636

37-
def _device_unset_current():
37+
def _device_unset_current() -> bool:
38+
"""Pop current CUDA context.
39+
40+
Returns True if context was popped, False it the stack was empty.
41+
"""
3842
ctx = handle_return(driver.cuCtxGetCurrent())
3943
if int(ctx) == 0:
4044
# no active context, do nothing
41-
return
45+
return False
4246
handle_return(driver.cuCtxPopCurrent())
4347
if hasattr(_device._tls, "devices"):
4448
del _device._tls.devices
49+
return True
4550

4651

4752
@pytest.fixture(scope="function")
4853
def deinit_cuda():
4954
# TODO: rename this to e.g. deinit_context
5055
yield
51-
_device_unset_current()
56+
_ = _device_unset_current()
5257

5358

5459
@pytest.fixture(scope="function")
55-
def deinit_context_function():
56-
return _device_unset_current
60+
def deinit_all_contexts_function():
61+
def pop_all_contexts():
62+
max_iters = 256
63+
for _ in range(max_iters):
64+
if _device_unset_current():
65+
# context was popped, continue until stack is empty
66+
continue
67+
# no active context, we are ready
68+
break
69+
else:
70+
raise RuntimeError(f"Number of iterations popping current CUDA contexts, exceded {max_iters}")
71+
72+
return pop_all_contexts
5773

5874

5975
# samples relying on cffi could fail as the modules cannot be imported

cuda_core/tests/test_module.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ class ExpectedStruct(ctypes.Structure):
234234

235235

236236
@skipif_testing_with_compute_sanitizer
237-
def test_num_args_error_handling(deinit_context_function, cuda12_prerequisite_check):
237+
def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisite_check):
238238
if not cuda12_prerequisite_check:
239239
pytest.skip("Test requires CUDA 12")
240240
src = "__global__ void foo(int a) { }"
@@ -244,14 +244,9 @@ def test_num_args_error_handling(deinit_context_function, cuda12_prerequisite_ch
244244
name_expressions=("foo",),
245245
)
246246
krn = mod.get_kernel("foo")
247-
# Unset current context using function from conftest
248-
while True:
249-
deinit_context_function()
250-
ctx = handle_return(driver.cuCtxGetCurrent())
251-
if int(ctx) == 0:
252-
# no active context, we are ready
253-
break
254-
# with no context, cuKernelGetParamInfo would report
247+
# empty driver's context stack using function from conftest
248+
deinit_all_contexts_function()
249+
# with no current context, cuKernelGetParamInfo would report
255250
# exception which we expect to handle by raising
256251
with pytest.raises(CUDAError):
257252
# assignment resolves linter error "B018: useless expression"

0 commit comments

Comments
 (0)