Skip to content

Commit 9aea8e3

Browse files
committed
UCT/IB/MLX5/GDAKI: Grouped checks required ctx to a single method.
1 parent 47d7345 commit 9aea8e3

File tree

1 file changed

+86
-77
lines changed

1 file changed

+86
-77
lines changed

src/uct/ib/mlx5/gdaki/gdaki.c

Lines changed: 86 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -125,53 +125,19 @@ static void uct_rc_gdaki_calc_dev_ep_layout(size_t num_channels, size_t wq_len,
125125
*pgsz_bitmap_p = (max_page_size << 1) - 1;
126126
}
127127

128-
static CUdevice uct_gdaki_push_primary_ctx(int retain_inactive_ctx)
129-
{
130-
CUdevice cuda_dev;
131-
ucs_status_t status;
132-
133-
status = uct_cuda_ctx_primary_push_first_active(&cuda_dev);
134-
if (status == UCS_OK) {
135-
return cuda_dev;
136-
}
137-
138-
if ((status != UCS_ERR_NO_DEVICE) || !retain_inactive_ctx) {
139-
return CU_DEVICE_INVALID;
140-
}
141-
142-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&cuda_dev, 0));
143-
if (status != UCS_OK) {
144-
return CU_DEVICE_INVALID;
145-
}
146-
147-
status = uct_cuda_ctx_primary_push(cuda_dev, 1, UCS_LOG_LEVEL_ERROR);
148-
if (status != UCS_OK) {
149-
return CU_DEVICE_INVALID;
150-
}
151-
152-
return cuda_dev;
153-
}
154-
155128
static int uct_gdaki_check_umem_dmabuf(const uct_ib_md_t *md)
156129
{
157-
ucs_status_t status = UCS_ERR_UNSUPPORTED;
130+
ucs_status_t ret = 0;
158131
#if HAVE_DECL_MLX5DV_UMEM_MASK_DMABUF
159132
struct mlx5dv_devx_umem_in umem_in = {};
160133
struct mlx5dv_devx_umem *umem;
161134
uct_cuda_copy_md_dmabuf_t dmabuf;
162135
CUdeviceptr buff;
163-
CUdevice cuda_dev;
164136

165-
cuda_dev = uct_gdaki_push_primary_ctx(md->config.gda_retain_inactive_ctx);
166-
if (cuda_dev == CU_DEVICE_INVALID) {
137+
if (UCT_CUDADRV_FUNC_LOG_ERR(cuMemAlloc(&buff, 1)) != UCS_OK) {
167138
return 0;
168139
}
169140

170-
status = UCT_CUDADRV_FUNC_LOG_ERR(cuMemAlloc(&buff, 1));
171-
if (status != UCS_OK) {
172-
goto out_ctx_pop_and_release;
173-
}
174-
175141
dmabuf = uct_cuda_copy_md_get_dmabuf((void*)buff, 1,
176142
UCS_SYS_DEVICE_ID_UNKNOWN);
177143

@@ -183,20 +149,18 @@ static int uct_gdaki_check_umem_dmabuf(const uct_ib_md_t *md)
183149
umem_in.dmabuf_fd = dmabuf.fd;
184150

185151
umem = mlx5dv_devx_umem_reg_ex(md->dev.ibv_context, &umem_in);
186-
if (umem == NULL) {
187-
status = UCS_ERR_NO_MEMORY;
188-
goto out_free;
152+
if (umem != NULL) {
153+
mlx5dv_devx_umem_dereg(umem);
154+
ret = 1;
155+
} else {
156+
ret = 0;
189157
}
190158

191-
mlx5dv_devx_umem_dereg(umem);
192-
out_free:
193159
ucs_close_fd(&dmabuf.fd);
194-
cuMemFree(buff);
195-
out_ctx_pop_and_release:
196-
uct_cuda_ctx_primary_pop_and_release(cuda_dev);
160+
(void)UCT_CUDADRV_FUNC_LOG_WARN(cuMemFree(buff));
197161
#endif
198162

199-
return status == UCS_OK;
163+
return ret;
200164
}
201165

202166
static int uct_gdaki_is_dmabuf_supported(const uct_ib_md_t *md)
@@ -1159,32 +1123,22 @@ static ucs_status_t uct_gdaki_md_check_uar(uct_ib_mlx5_md_t *md)
11591123
{
11601124
struct mlx5dv_devx_uar *uar;
11611125
ucs_status_t status;
1162-
CUdevice cuda_dev;
11631126
unsigned flags;
11641127

11651128
status = uct_ib_mlx5_devx_alloc_uar(md, 0, &uar);
11661129
if (status != UCS_OK) {
1167-
goto out;
1168-
}
1169-
1170-
cuda_dev = uct_gdaki_push_primary_ctx(
1171-
md->super.config.gda_retain_inactive_ctx);
1172-
if (cuda_dev == CU_DEVICE_INVALID) {
1173-
goto out_free_uar;
1130+
return status;
11741131
}
11751132

11761133
flags = CU_MEMHOSTREGISTER_PORTABLE | CU_MEMHOSTREGISTER_DEVICEMAP |
11771134
CU_MEMHOSTREGISTER_IOMEMORY;
11781135
status = UCT_CUDADRV_FUNC_LOG_DEBUG(
11791136
cuMemHostRegister(uar->reg_addr, UCT_IB_MLX5_BF_REG_SIZE, flags));
11801137
if (status == UCS_OK) {
1181-
UCT_CUDADRV_FUNC_LOG_DEBUG(cuMemHostUnregister(uar->reg_addr));
1138+
UCT_CUDADRV_FUNC_LOG_WARN(cuMemHostUnregister(uar->reg_addr));
11821139
}
11831140

1184-
uct_cuda_ctx_primary_pop_and_release(cuda_dev);
1185-
out_free_uar:
11861141
mlx5dv_devx_free_uar(uar);
1187-
out:
11881142
return status;
11891143
}
11901144

@@ -1382,6 +1336,80 @@ uct_gdaki_dev_matrix_init(const uct_ib_md_t *ib_md, size_t *dmat_length_p)
13821336
return dmat;
13831337
}
13841338

1339+
static CUdevice uct_gdaki_push_primary_ctx(int retain_inactive_ctx)
1340+
{
1341+
CUdevice cuda_dev;
1342+
ucs_status_t status;
1343+
1344+
status = uct_cuda_ctx_primary_push_first_active(&cuda_dev);
1345+
if (status == UCS_OK) {
1346+
return cuda_dev;
1347+
}
1348+
1349+
if ((status != UCS_ERR_NO_DEVICE) || !retain_inactive_ctx) {
1350+
if (status == UCS_ERR_NO_DEVICE) {
1351+
ucs_diag("no active primary CUDA context on any device. Please set "
1352+
"UCX_IB_GDA_RETAIN_INACTIVE_CTX=yes to retain inactive "
1353+
"context.");
1354+
}
1355+
return CU_DEVICE_INVALID;
1356+
}
1357+
1358+
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&cuda_dev, 0));
1359+
if (status != UCS_OK) {
1360+
return CU_DEVICE_INVALID;
1361+
}
1362+
1363+
status = uct_cuda_ctx_primary_push(cuda_dev, 1, UCS_LOG_LEVEL_ERROR);
1364+
if (status != UCS_OK) {
1365+
return CU_DEVICE_INVALID;
1366+
}
1367+
1368+
return cuda_dev;
1369+
}
1370+
1371+
static int
1372+
uct_gdaki_check_cuda_ctx_dependent_features(uct_ib_mlx5_md_t *ib_mlx5_md)
1373+
{
1374+
uct_ib_md_t *ib_md = &ib_mlx5_md->super;
1375+
CUdevice cuda_dev;
1376+
char dmabuf_str[8];
1377+
int ret;
1378+
1379+
cuda_dev = uct_gdaki_push_primary_ctx(ib_md->config.gda_retain_inactive_ctx);
1380+
if (cuda_dev == CU_DEVICE_INVALID) {
1381+
return 0;
1382+
}
1383+
1384+
if ((ib_md->config.gda_dmabuf_enable != UCS_NO) &&
1385+
uct_gdaki_is_dmabuf_supported(ib_md)) {
1386+
ib_mlx5_md->flags |= UCT_IB_MLX5_MD_FLAG_REG_DMABUF_UMEM;
1387+
ucs_debug("%s: using dmabuf for gda transport",
1388+
uct_ib_device_name(&ib_md->dev));
1389+
} else if ((ib_md->config.gda_dmabuf_enable != UCS_YES) &&
1390+
uct_gdaki_is_peermem_loaded(ib_md)) {
1391+
ucs_debug("%s: using peermem for gda transport",
1392+
uct_ib_device_name(&ib_md->dev));
1393+
} else {
1394+
ucs_config_sprintf_ternary_auto(dmabuf_str, sizeof(dmabuf_str),
1395+
&ib_md->config.gda_dmabuf_enable, NULL);
1396+
ucs_diag("%s: GPU-direct RDMA is not available (GDA_DMABUF_ENABLE=%s)",
1397+
uct_ib_device_name(&ib_md->dev), dmabuf_str);
1398+
ret = 0;
1399+
goto out;
1400+
}
1401+
1402+
if (uct_gdaki_is_uar_supported(ib_mlx5_md)) {
1403+
ret = 1;
1404+
} else {
1405+
ret = 0;
1406+
}
1407+
1408+
out:
1409+
uct_cuda_ctx_primary_pop_and_release(cuda_dev);
1410+
return ret;
1411+
}
1412+
13851413
static ucs_status_t
13861414
uct_gdaki_query_tl_devices(uct_md_h tl_md,
13871415
uct_tl_device_resource_t **tl_devices_p,
@@ -1399,7 +1427,6 @@ uct_gdaki_query_tl_devices(uct_md_h tl_md,
13991427
ucs_sys_device_t dev;
14001428
int i;
14011429
uct_gdaki_dev_matrix_elem_t *ibdesc;
1402-
char dmabuf_str[8];
14031430

14041431
UCS_INIT_ONCE(&dmat_once) {
14051432
dmat = uct_gdaki_dev_matrix_init(ib_md, &dmat_length);
@@ -1410,20 +1437,7 @@ uct_gdaki_query_tl_devices(uct_md_h tl_md,
14101437
goto out;
14111438
}
14121439

1413-
if ((ib_md->config.gda_dmabuf_enable != UCS_NO) &&
1414-
uct_gdaki_is_dmabuf_supported(ib_md)) {
1415-
ib_mlx5_md->flags |= UCT_IB_MLX5_MD_FLAG_REG_DMABUF_UMEM;
1416-
ucs_debug("%s: using dmabuf for gda transport",
1417-
uct_ib_device_name(&ib_md->dev));
1418-
} else if ((ib_md->config.gda_dmabuf_enable != UCS_YES) &&
1419-
uct_gdaki_is_peermem_loaded(ib_md)) {
1420-
ucs_debug("%s: using peermem for gda transport",
1421-
uct_ib_device_name(&ib_md->dev));
1422-
} else {
1423-
ucs_config_sprintf_ternary_auto(dmabuf_str, sizeof(dmabuf_str),
1424-
&ib_md->config.gda_dmabuf_enable, NULL);
1425-
ucs_diag("%s: GPU-direct RDMA is not available (GDA_DMABUF_ENABLE=%s)",
1426-
uct_ib_device_name(&ib_md->dev), dmabuf_str);
1440+
if (!uct_gdaki_check_cuda_ctx_dependent_features(ib_mlx5_md)) {
14271441
status = UCS_ERR_NO_DEVICE;
14281442
goto out;
14291443
}
@@ -1450,11 +1464,6 @@ uct_gdaki_query_tl_devices(uct_md_h tl_md,
14501464
goto out;
14511465
}
14521466

1453-
if (!uct_gdaki_is_uar_supported(ib_mlx5_md)) {
1454-
status = UCS_ERR_NO_DEVICE;
1455-
goto err;
1456-
}
1457-
14581467
num_tl_devices = 0;
14591468
ucs_for_each_bit(i, ibdesc->cuda_map) {
14601469
status = UCT_CUDADRV_FUNC_LOG_ERR(cuDeviceGet(&device, i));

0 commit comments

Comments
 (0)