Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions kernels/gemm/bf16_b200/bf16_b200_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
__shared__ uint32_t tmem_addr;
__shared__ clc::handle clc_handle[C::CLC_PIPE_DEPTH];
__shared__ semaphore tmem_provisioned, tmem_finished, schedule_arrived[C::CLC_PIPE_DEPTH], schedule_finished[C::CLC_PIPE_DEPTH];
- __shared__ semaphore inputs_arrived[C::LOAD_PIPE_DEPTH], inputs_finished[C::LOAD_PIPE_DEPTH], outputs_arrived[C::NUM_CONSUMERS], outputs_finished[C::MMA_PIPE_DEPTH];
+ __shared__ semaphore inputs_arrived[C::LOAD_PIPE_DEPTH], inputs_finished[C::LOAD_PIPE_DEPTH], outputs_arrived[C::MMA_PIPE_DEPTH][C::NUM_CONSUMERS], outputs_finished[C::MMA_PIPE_DEPTH];
uint32_t bitfield = 0xFFFF0000; // ***_finished phase bits start as 1s, ***_arrived phase bits start as 0s

if (threadIdx.x == 32) {
Expand All @@ -104,11 +104,11 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
init_semaphore(inputs_finished[i], 0, C::NUM_CONSUMERS);
}
#pragma unroll
- for (int i = 0; i < C::NUM_CONSUMERS; i++) {
- init_semaphore(outputs_arrived[i], 0, 1);
- }
- #pragma unroll
for (int i = 0; i < C::MMA_PIPE_DEPTH; i++) {
+ #pragma unroll
+ for (int j = 0; j < C::NUM_CONSUMERS; j++) {
+ init_semaphore(outputs_arrived[i][j], 0, 1);
+ }
init_semaphore(outputs_finished[i], 0, C::CLUSTER_SIZE*C::NUM_CONSUMERS);
}
}
Expand Down Expand Up @@ -175,7 +175,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
update_phasebit<0>(bitfield, input_ring);
input_ring=ring_advance<C::LOAD_PIPE_DEPTH>(input_ring);
}
- detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[warpgroup::warpid()]);
+ detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[task_iter%C::MMA_PIPE_DEPTH][warpgroup::warpid()]);
if (!schedule.success) break;
}
}
Expand Down Expand Up @@ -205,7 +205,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
warpgroup::sync(warpgroup::groupid()+1);
warpgroup::tma::cluster::arrive(schedule_finished[task_iter%C::CLC_PIPE_DEPTH], 0);
if (schedule.success) next_tile_coord = get_swizzled_2d_idx<C::SUPERGROUP_SIZE>(rblks, cblks, schedule.x/C::CLUSTER_SIZE);
- wait(outputs_arrived[warpgroup::groupid()], task_iter%2);
+ wait(outputs_arrived[task_iter%C::MMA_PIPE_DEPTH][warpgroup::groupid()], (task_iter/C::MMA_PIPE_DEPTH)%2);
if constexpr (C::OVERLAP_MMA_EPI) {
rt_bf<C::Mb/8, C::Nb/C::EPI_PIPE_DEPTH> d_reg;
#pragma unroll
Expand Down
14 changes: 7 additions & 7 deletions kernels/gemm/fp8_b200/fp8_b200_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
__shared__ uint32_t tmem_addr;
__shared__ clc::handle clc_handle[C::CLC_PIPE_DEPTH];
__shared__ semaphore tmem_provisioned, tmem_finished, schedule_arrived[C::CLC_PIPE_DEPTH], schedule_finished[C::CLC_PIPE_DEPTH];
- __shared__ semaphore inputs_arrived[C::LOAD_PIPE_DEPTH], inputs_finished[C::LOAD_PIPE_DEPTH], outputs_arrived[C::NUM_CONSUMERS], outputs_finished[C::MMA_PIPE_DEPTH];
+ __shared__ semaphore inputs_arrived[C::LOAD_PIPE_DEPTH], inputs_finished[C::LOAD_PIPE_DEPTH], outputs_arrived[C::MMA_PIPE_DEPTH][C::NUM_CONSUMERS], outputs_finished[C::MMA_PIPE_DEPTH];
uint32_t bitfield = 0xFFFF0000; // ***_finished phase bits start as 1s, ***_arrived phase bits start as 0s

if (threadIdx.x == 32) {
Expand All @@ -104,11 +104,11 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
init_semaphore(inputs_finished[i], 0, C::NUM_CONSUMERS);
}
#pragma unroll
- for (int i = 0; i < C::NUM_CONSUMERS; i++) {
- init_semaphore(outputs_arrived[i], 0, 1);
- }
- #pragma unroll
for (int i = 0; i < C::MMA_PIPE_DEPTH; i++) {
+ #pragma unroll
+ for (int j = 0; j < C::NUM_CONSUMERS; j++) {
+ init_semaphore(outputs_arrived[i][j], 0, 1);
+ }
init_semaphore(outputs_finished[i], 0, C::CLUSTER_SIZE*C::NUM_CONSUMERS);
}
}
Expand Down Expand Up @@ -175,7 +175,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
update_phasebit<0>(bitfield, input_ring);
input_ring=ring_advance<C::LOAD_PIPE_DEPTH>(input_ring);
}
- detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[warpgroup::warpid()]);
+ detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[task_iter%C::MMA_PIPE_DEPTH][warpgroup::warpid()]);
if (!schedule.success) break;
}
}
Expand Down Expand Up @@ -205,7 +205,7 @@ __global__ void kernel(const __grid_constant__ globals<C> g) {
warpgroup::sync(warpgroup::groupid()+1);
warpgroup::tma::cluster::arrive(schedule_finished[task_iter%C::CLC_PIPE_DEPTH], 0);
if (schedule.success) next_tile_coord = get_swizzled_2d_idx<C::SUPERGROUP_SIZE>(rblks, cblks, schedule.x/C::CLUSTER_SIZE);
- wait(outputs_arrived[warpgroup::groupid()], task_iter%2);
+ wait(outputs_arrived[task_iter%C::MMA_PIPE_DEPTH][warpgroup::groupid()], (task_iter/C::MMA_PIPE_DEPTH)%2);
if constexpr (C::OVERLAP_MMA_EPI) {
rt_bf<C::Mb/8, C::Nb/C::EPI_PIPE_DEPTH> d_reg;
#pragma unroll
Expand Down
9 changes: 5 additions & 4 deletions kernels/parallel/ag_gemm/ag_gemm_b200.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ __device__ inline void comp_sm(const globals &g) {

__shared__ semaphore inputs_arrived[G::PIPELINE_STAGES],
inputs_finished[G::PIPELINE_STAGES],
- outputs_arrived,
+ outputs_arrived[G::MMA_PIPE_DEPTH],
outputs_finished[G::MMA_PIPE_DEPTH];
int input_ring = 0;
int mma_ring = 0;
Expand All @@ -141,9 +141,9 @@ __device__ inline void comp_sm(const globals &g) {
init_semaphore(inputs_arrived[i], 0, 1);
init_semaphore(inputs_finished[i], 0, 1);
}
- init_semaphore(outputs_arrived, 0, 1);
#pragma unroll
for (int i = 0; i < G::MMA_PIPE_DEPTH; i++) {
+ init_semaphore(outputs_arrived[i], 0, 1);
init_semaphore(outputs_finished[i], 0, C::CLUSTER_SIZE);
}
}
Expand Down Expand Up @@ -200,7 +200,7 @@ __device__ inline void comp_sm(const globals &g) {
else mma2_ABt(d_tt[mma_ring], A_smem[input_ring], B_smem[input_ring], inputs_finished[input_ring]);
input_ring=ring_advance<G::PIPELINE_STAGES>(input_ring);
}
- kittens::detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived);
+ kittens::detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[mma_ring]);
mma_ring=ring_advance<G::MMA_PIPE_DEPTH>(mma_ring);
}
}
Expand Down Expand Up @@ -236,7 +236,8 @@ __device__ inline void comp_sm(const globals &g) {
col_idx = idx_in_shard / num_peer_devices;
}

- wait(outputs_arrived, mma_ring);
+ wait(outputs_arrived[mma_ring], get_phasebit<0>(bitfield, mma_ring));
+ update_phasebit<0>(bitfield, mma_ring);
rt_bf<G::ROW_BLOCK/8, G::COL_BLOCK/G::EPI_PIPE_DEPTH> C_reg;
#pragma unroll
for(int i = 0; i < G::EPI_PIPE_DEPTH; i++) {
Expand Down
9 changes: 5 additions & 4 deletions kernels/parallel/ag_gemm_fp8/ag_gemm_fp8_b200.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ __device__ inline void comp_sm(const globals &g) {

__shared__ semaphore inputs_arrived[G::PIPELINE_STAGES],
inputs_finished[G::PIPELINE_STAGES],
- outputs_arrived,
+ outputs_arrived[G::MMA_PIPE_DEPTH],
outputs_finished[G::MMA_PIPE_DEPTH];
int input_ring = 0;
int mma_ring = 0;
Expand All @@ -141,9 +141,9 @@ __device__ inline void comp_sm(const globals &g) {
init_semaphore(inputs_arrived[i], 0, 1);
init_semaphore(inputs_finished[i], 0, 1);
}
- init_semaphore(outputs_arrived, 0, 1);
#pragma unroll
for (int i = 0; i < G::MMA_PIPE_DEPTH; i++) {
+ init_semaphore(outputs_arrived[i], 0, 1);
init_semaphore(outputs_finished[i], 0, C::CLUSTER_SIZE);
}
}
Expand Down Expand Up @@ -200,7 +200,7 @@ __device__ inline void comp_sm(const globals &g) {
else mma2_ABt(d_tt[mma_ring], A_smem[input_ring], B_smem[input_ring], inputs_finished[input_ring]);
input_ring=ring_advance<G::PIPELINE_STAGES>(input_ring);
}
- kittens::detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived);
+ kittens::detail::tcgen05::commit<C::CLUSTER_SIZE>(outputs_arrived[mma_ring]);
mma_ring=ring_advance<G::MMA_PIPE_DEPTH>(mma_ring);
}
}
Expand Down Expand Up @@ -236,7 +236,8 @@ __device__ inline void comp_sm(const globals &g) {
col_idx = idx_in_shard / num_peer_devices;
}

- wait(outputs_arrived, mma_ring);
+ wait(outputs_arrived[mma_ring], get_phasebit<0>(bitfield, mma_ring));
+ update_phasebit<0>(bitfield, mma_ring);
rt_bf<G::ROW_BLOCK/8, G::COL_BLOCK/G::EPI_PIPE_DEPTH> C_reg;
#pragma unroll
for(int i = 0; i < G::EPI_PIPE_DEPTH; i++) {
Expand Down