Skip to content

Commit f29b935

Browse files
committed
cuda.core.system: Add MIG-related APIs
1 parent 6ec277e commit f29b935

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

cuda_core/cuda/core/system/_device.pyx

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ include "_fan.pxi"
3232
include "_field_values.pxi"
3333
include "_inforom.pxi"
3434
include "_memory.pxi"
35+
include "_mig.pxi"
3536
include "_pci_info.pxi"
3637
include "_performance.pxi"
3738
include "_repair_status.pxi"
@@ -132,12 +133,19 @@ cdef class Device:
132133
board serial identifier.
133134

134135
In the upstream NVML C++ API, the UUID includes a ``gpu-`` or ``mig-``
135-
prefix. That is not included in ``cuda.core.system``.
136+
prefix. If you a `uuid` without that prefix (for example, to interact
137+
with CUDA), use the `uuid_without_prefix` property.
136138
"""
137-
# NVML UUIDs have a `GPU-` or `MIG-` prefix. We remove that here.
139+
return nvml.device_get_uuid(self._handle)
138140

139-
# TODO: If the user cares about the prefix, we will expose that in the
140-
# future using the MIG-related APIs in NVML.
141+
@property
142+
def uuid_without_prefix(self) -> str:
143+
"""
144+
Retrieves the globally unique immutable UUID associated with this
145+
device, as a 5 part hexadecimal string, that augments the immutable,
146+
board serial identifier.
147+
"""
148+
# NVML UUIDs have a `GPU-` or `MIG-` prefix. We remove that here.
141149
return nvml.device_get_uuid(self._handle)[4:]
142150

143151
@property
@@ -265,7 +273,7 @@ cdef class Device:
265273
# search all the devices for one with a matching UUID.
266274

267275
for cuda_device in CudaDevice.get_all_devices():
268-
if cuda_device.uuid == self.uuid:
276+
if cuda_device.uuid == self.uuid_without_prefix:
269277
return cuda_device
270278

271279
raise RuntimeError("No corresponding CUDA device found for this NVML device.")
@@ -280,6 +288,8 @@ cdef class Device:
280288
int
281289
The number of available devices.
282290
"""
291+
initialize()
292+
283293
return nvml.device_get_count_v2()
284294

285295
@classmethod
@@ -292,6 +302,8 @@ cdef class Device:
292302
Iterator of Device
293303
An iterator over available devices.
294304
"""
305+
initialize()
306+
295307
for device_id in range(nvml.device_get_count_v2()):
296308
yield cls(index=device_id)
297309

@@ -317,6 +329,18 @@ cdef class Device:
317329
"""
318330
return AddressingMode(nvml.device_get_addressing_mode(self._handle).value)
319331

332+
#########################################################################
333+
# MIG (MULTI-INSTANCE GPU) DEVICES
334+
335+
@property
336+
def mig(self) -> MigInfo:
337+
"""
338+
Accessor for MIG (Multi-Instance GPU) information.
339+
340+
For Ampere™ or newer fully supported devices.
341+
"""
342+
return MigInfo(self)
343+
320344
#########################################################################
321345
# AFFINITY
322346

@@ -853,6 +877,7 @@ __all__ = [
853877
"InforomInfo",
854878
"InforomObject",
855879
"MemoryInfo",
880+
"MigInfo",
856881
"PcieUtilCounter",
857882
"PciInfo",
858883
"Pstates",

cuda_core/docs/source/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ Types
256256
system.GpuTopologyLevel
257257
system.InforomInfo
258258
system.MemoryInfo
259+
system.MigInfo
259260
system.PciInfo
260261
system.RepairStatus
261262
system.Temperature

cuda_core/tests/system/test_system_device.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_to_cuda_device():
5757
cuda_device = device.to_cuda_device()
5858

5959
assert isinstance(cuda_device, CudaDevice)
60-
assert cuda_device.uuid == device.uuid
60+
assert cuda_device.uuid == device.uuid_without_prefix
6161

6262
# Technically, this test will only work with PCI devices, but are there
6363
# non-PCI devices we need to support?
@@ -227,9 +227,9 @@ def test_device_serial():
227227
assert len(serial) > 0
228228

229229

230-
def test_device_uuid():
230+
def test_device_uuid_without_prefix():
231231
for device in system.Device.get_all_devices():
232-
uuid = device.uuid
232+
uuid = device.uuid_without_prefix
233233
assert isinstance(uuid, str)
234234

235235
# Expands to GPU-8hex-4hex-4hex-4hex-12hex, where 8hex means 8 consecutive
@@ -729,3 +729,30 @@ def test_pstates():
729729
assert isinstance(utilization.percentage, int)
730730
assert isinstance(utilization.inc_threshold, int)
731731
assert isinstance(utilization.dec_threshold, int)
732+
733+
734+
@pytest.mark.skipif(helpers.IS_WSL or helpers.IS_WINDOWS, reason="MIG not supported on WSL or Windows")
735+
def test_mig():
736+
for device in system.Device.get_all_devices():
737+
with unsupported_before(device, None):
738+
mig = device.mig
739+
740+
assert isinstance(mig.is_mig_device, bool)
741+
if mig.is_mig_device:
742+
assert isinstance(mig.mode, bool)
743+
assert isinstance(mig.pending_mode, bool)
744+
745+
device_count = mig.get_device_count()
746+
assert isinstance(device_count, int)
747+
assert device_count >= 0
748+
749+
for mig_device in mig.get_all_devices():
750+
assert isinstance(mig_device, system.Device)
751+
752+
753+
def test_uuid():
754+
for device in system.Device.get_all_devices():
755+
uuid = device.uuid
756+
assert isinstance(uuid, str)
757+
assert uuid.startswith(("GPU-", "MIG-"))
758+
assert uuid == device.uuid

0 commit comments

Comments
 (0)