2020#include <uct/ib/mlx5/rc/rc_mlx5.h>
2121#include <uct/cuda/cuda_copy/cuda_copy_md.h>
2222#include <uct/cuda/base/cuda_util.h>
23+ #include <uct/cuda/base/cuda_ctx.h>
2324
2425#include "gpunetio/common/doca_gpunetio_verbs_def.h"
2526
@@ -124,6 +125,33 @@ static void uct_rc_gdaki_calc_dev_ep_layout(size_t num_channels, size_t wq_len,
124125 * pgsz_bitmap_p = (max_page_size << 1 ) - 1 ;
125126}
126127
128+ static CUdevice uct_gdaki_push_primary_ctx (int retain_inactive_ctx )
129+ {
130+ CUdevice cuda_dev ;
131+ ucs_status_t status ;
132+
133+ status = uct_cuda_ctx_primary_push_first_active (& cuda_dev );
134+ if (status == UCS_OK ) {
135+ return cuda_dev ;
136+ }
137+
138+ if ((status != UCS_ERR_NO_DEVICE ) || !retain_inactive_ctx ) {
139+ return CU_DEVICE_INVALID ;
140+ }
141+
142+ status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGet (& cuda_dev , 0 ));
143+ if (status != UCS_OK ) {
144+ return CU_DEVICE_INVALID ;
145+ }
146+
147+ status = uct_cuda_ctx_primary_push (cuda_dev , 1 , UCS_LOG_LEVEL_ERROR );
148+ if (status != UCS_OK ) {
149+ return CU_DEVICE_INVALID ;
150+ }
151+
152+ return cuda_dev ;
153+ }
154+
127155static int uct_gdaki_check_umem_dmabuf (const uct_ib_md_t * md )
128156{
129157 ucs_status_t status = UCS_ERR_UNSUPPORTED ;
@@ -132,21 +160,16 @@ static int uct_gdaki_check_umem_dmabuf(const uct_ib_md_t *md)
132160 struct mlx5dv_devx_umem * umem ;
133161 uct_cuda_copy_md_dmabuf_t dmabuf ;
134162 CUdeviceptr buff ;
135- CUcontext cuda_ctx ;
163+ CUdevice cuda_dev ;
136164
137- status = UCT_CUDADRV_FUNC_LOG_ERR ( cuDevicePrimaryCtxRetain ( & cuda_ctx , 0 ) );
138- if (status != UCS_OK ) {
165+ cuda_dev = uct_gdaki_push_primary_ctx ( md -> config . gda_retain_inactive_ctx );
166+ if (cuda_dev == CU_DEVICE_INVALID ) {
139167 return 0 ;
140168 }
141169
142- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
143- if (status != UCS_OK ) {
144- goto out_ctx_release ;
145- }
146-
147170 status = UCT_CUDADRV_FUNC_LOG_ERR (cuMemAlloc (& buff , 1 ));
148171 if (status != UCS_OK ) {
149- goto out_ctx_pop ;
172+ goto out_ctx_pop_and_release ;
150173 }
151174
152175 dmabuf = uct_cuda_copy_md_get_dmabuf ((void * )buff , 1 ,
@@ -169,10 +192,8 @@ static int uct_gdaki_check_umem_dmabuf(const uct_ib_md_t *md)
169192out_free :
170193 ucs_close_fd (& dmabuf .fd );
171194 cuMemFree (buff );
172- out_ctx_pop :
173- UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
174- out_ctx_release :
175- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (0 ));
195+ out_ctx_pop_and_release :
196+ uct_cuda_ctx_primary_pop_and_release (cuda_dev );
176197#endif
177198
178199 return status == UCS_OK ;
@@ -1134,30 +1155,24 @@ static UCS_CLASS_DEFINE_NEW_FUNC(uct_rc_gdaki_iface_t, uct_iface_t, uct_md_h,
11341155
11351156static UCS_CLASS_DEFINE_DELETE_FUNC (uct_rc_gdaki_iface_t , uct_iface_t ) ;
11361157
1137- static ucs_status_t
1138- uct_gdaki_md_check_uar (uct_ib_mlx5_md_t * md , CUdevice cuda_dev )
1158+ static ucs_status_t uct_gdaki_md_check_uar (uct_ib_mlx5_md_t * md )
11391159{
11401160 struct mlx5dv_devx_uar * uar ;
11411161 ucs_status_t status ;
1142- CUcontext cuda_ctx ;
1162+ CUdevice cuda_dev ;
11431163 unsigned flags ;
11441164
11451165 status = uct_ib_mlx5_devx_alloc_uar (md , 0 , & uar );
11461166 if (status != UCS_OK ) {
11471167 goto out ;
11481168 }
11491169
1150- status = UCT_CUDADRV_FUNC_LOG_ERR (
1151- cuDevicePrimaryCtxRetain ( & cuda_ctx , cuda_dev ) );
1152- if (status != UCS_OK ) {
1170+ cuda_dev = uct_gdaki_push_primary_ctx (
1171+ md -> super . config . gda_retain_inactive_ctx );
1172+ if (cuda_dev == CU_DEVICE_INVALID ) {
11531173 goto out_free_uar ;
11541174 }
11551175
1156- status = UCT_CUDADRV_FUNC_LOG_ERR (cuCtxPushCurrent (cuda_ctx ));
1157- if (status != UCS_OK ) {
1158- goto out_ctx_release ;
1159- }
1160-
11611176 flags = CU_MEMHOSTREGISTER_PORTABLE | CU_MEMHOSTREGISTER_DEVICEMAP |
11621177 CU_MEMHOSTREGISTER_IOMEMORY ;
11631178 status = UCT_CUDADRV_FUNC_LOG_DEBUG (
@@ -1166,9 +1181,7 @@ uct_gdaki_md_check_uar(uct_ib_mlx5_md_t *md, CUdevice cuda_dev)
11661181 UCT_CUDADRV_FUNC_LOG_DEBUG (cuMemHostUnregister (uar -> reg_addr ));
11671182 }
11681183
1169- UCT_CUDADRV_FUNC_LOG_WARN (cuCtxPopCurrent (NULL ));
1170- out_ctx_release :
1171- UCT_CUDADRV_FUNC_LOG_WARN (cuDevicePrimaryCtxRelease (cuda_dev ));
1184+ uct_cuda_ctx_primary_pop_and_release (cuda_dev );
11721185out_free_uar :
11731186 mlx5dv_devx_free_uar (uar );
11741187out :
@@ -1196,7 +1209,7 @@ static int uct_gdaki_is_peermem_loaded(const uct_ib_md_t *md)
11961209 return peermem_loaded ;
11971210}
11981211
1199- static int uct_gdaki_is_uar_supported (uct_ib_mlx5_md_t * md , CUdevice cu_device )
1212+ static int uct_gdaki_is_uar_supported (uct_ib_mlx5_md_t * md )
12001213{
12011214 /**
12021215 * Save the result of UAR support in a global flag to avoid the overhead of
@@ -1209,7 +1222,7 @@ static int uct_gdaki_is_uar_supported(uct_ib_mlx5_md_t *md, CUdevice cu_device)
12091222 return uar_supported ;
12101223 }
12111224
1212- uar_supported = (uct_gdaki_md_check_uar (md , cu_device ) == UCS_OK );
1225+ uar_supported = (uct_gdaki_md_check_uar (md ) == UCS_OK );
12131226 if (uar_supported == 0 ) {
12141227 ucs_diag ("GDAKI not supported, please add NVreg_RegistryDwords="
12151228 "\"PeerMappingOverride=1;\" option for nvidia kernel driver" );
@@ -1437,18 +1450,18 @@ uct_gdaki_query_tl_devices(uct_md_h tl_md,
14371450 goto out ;
14381451 }
14391452
1453+ if (!uct_gdaki_is_uar_supported (ib_mlx5_md )) {
1454+ status = UCS_ERR_NO_DEVICE ;
1455+ goto err ;
1456+ }
1457+
14401458 num_tl_devices = 0 ;
14411459 ucs_for_each_bit (i , ibdesc -> cuda_map ) {
14421460 status = UCT_CUDADRV_FUNC_LOG_ERR (cuDeviceGet (& device , i ));
14431461 if (status != UCS_OK ) {
14441462 goto err ;
14451463 }
14461464
1447- if (!uct_gdaki_is_uar_supported (ib_mlx5_md , device )) {
1448- status = UCS_ERR_NO_DEVICE ;
1449- goto err ;
1450- }
1451-
14521465 dev = uct_cuda_get_sys_dev (device );
14531466
14541467 snprintf (tl_devices [num_tl_devices ].name ,
0 commit comments