Skip to content

Commit fd6ae4c

Browse files
authored
Tensor-parallel: Fix delayed AllReduce on Gemma-4 MoE (#22129)
* Fix delayed AllReduce on Gemma-4 MoE

  Skip forward past nodes that don't consume the current one, and allow a chain of MULs.

* Check for all sources before skipping nodes

* Address review comments
1 parent fb19f94 commit fd6ae4c

1 file changed

Lines changed: 38 additions & 4 deletions

File tree

ggml/src/ggml-backend-meta.cpp

Lines changed: 38 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1683,6 +1683,36 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
16831683

16841684
ggml_tensor * node = cgraph->nodes[id];
16851685
int32_t n_used = ggml_node_get_use_count(cgraph, id);
1686+
1687+
// Skip MIRRORED nodes that don't consume node
1688+
auto skip_unrelated = [&]() {
1689+
while (id + 1 < cgraph->n_nodes) {
1690+
ggml_tensor * next = cgraph->nodes[id+1];
1691+
if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1692+
break;
1693+
}
1694+
bool safe = true;
1695+
for (int s = 0; s < GGML_MAX_SRC; s++) {
1696+
if (next->src[s] == nullptr) {
1697+
continue;
1698+
}
1699+
if (next->src[s] == node) {
1700+
safe = false;
1701+
break;
1702+
}
1703+
if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
1704+
safe = false;
1705+
break;
1706+
}
1707+
}
1708+
if (!safe) {
1709+
break;
1710+
}
1711+
id++;
1712+
}
1713+
};
1714+
1715+
skip_unrelated();
16861716
if (id + 1 >= cgraph->n_nodes) {
16871717
return idr;
16881718
}
@@ -1697,17 +1727,21 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend,
16971727
n_used = ggml_node_get_use_count(cgraph, id);
16981728
}
16991729
}
1700-
if (id + 1 >= cgraph->n_nodes) {
1701-
return idr;
1702-
}
1703-
{
1730+
// Chain of MULs with MIRRORED src[1]
1731+
while (true) {
1732+
skip_unrelated();
1733+
if (id + 1 >= cgraph->n_nodes) {
1734+
return idr;
1735+
}
17041736
ggml_tensor * next = cgraph->nodes[id+1];
17051737
if (next->op == GGML_OP_MUL && next->src[0] == node &&
17061738
ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) {
17071739
node = next;
17081740
id++;
17091741
idr = id;
17101742
n_used = ggml_node_get_use_count(cgraph, id);
1743+
} else {
1744+
break;
17111745
}
17121746
}
17131747

0 commit comments

Comments
 (0)