Skip to content

Commit 77757b0

Browse files
committed
Fix addmv handler - ignore input when beta is 0
Summary: According to the PyTorch documentation, `input` should be ignored when `beta` is 0. !ci_branch_mk2

Reviewers: #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, #pytorch, dariuszs

Reviewed By: #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, #pytorch, dariuszs

Subscribers: dariuszs

JIRA Issues: AFS-357

Differential Revision: https://phabricator.sourcevertex.net/D84820
1 parent 25f77c2 commit 77757b0

2 files changed

Lines changed: 77 additions & 7 deletions

File tree

poptorch/source/popart_canonicalization/BlasOps.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,16 +112,38 @@ torch::jit::Node *baddbmmHandler(torch::jit::Graph *graph,
112112

113113
torch::jit::Node *addmvHandler(torch::jit::Graph *graph,
                               torch::jit::Node *node) {
  // Canonicalize aten::addmv(input, mat, vec, beta, alpha) into
  //   beta * input + alpha * (mat @ vec).
  // Per the PyTorch docs, `input` is ignored when beta == 0 (so NaN/inf in
  // it must not propagate), and likewise mat/vec are ignored when
  // alpha == 0 — hence the dead-term elimination below.
  auto *input = node->input(0);
  auto *mat = node->input(1);
  auto *vec = node->input(2);
  auto *beta = node->input(3);
  auto *alpha = node->input(4);

  // alpha/beta are graph constants (presumably guaranteed by the
  // canonicalization pipeline — constantToFloat is a project helper);
  // fold them so zero terms can be dropped at canonicalization time.
  const auto scale_alpha = constantToFloat(alpha->node());
  const auto scale_beta = constantToFloat(beta->node());

  // Both scales zero: the whole expression collapses to a zero tensor
  // shaped like `input`.
  if (scale_alpha == 0 && scale_beta == 0) {
    return createConstantFloatLike(graph, input, {0}, {shapeFromTensor(input)});
  }

  // alpha * (mat @ vec) — built only when alpha contributes.
  torch::jit::Node *matmul_term = nullptr;
  if (scale_alpha != 0) {
    auto *product = createMatmul(graph, {mat, vec})->output();
    matmul_term = createMul(graph, {product, alpha});
  }

  // beta * input — built only when beta contributes; add the matmul term
  // when both are present. At least one term exists here because the
  // both-zero case returned early above.
  torch::jit::Node *result = nullptr;
  if (scale_beta != 0) {
    auto *input_term = createMul(graph, {input, beta});
    result = matmul_term == nullptr
                 ? input_term
                 : createAdd(graph,
                             {matmul_term->output(), input_term->output()});
  } else {
    result = matmul_term;
  }

  return result;
}
126148
} // namespace
127149

tests/blas_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,51 @@ def forward(self, x1, x2, x3):
221221
ipu_result = poptorch.inferenceModel(model)(t1, t2, t3)
222222

223223
helpers.assert_allclose(expected=cpu_result, actual=ipu_result)
224+
225+
226+
@pytest.mark.parametrize("input_shape", [(20, 10)])
@pytest.mark.parametrize("beta", [0, .5])
@pytest.mark.parametrize("alpha", [0, 1.5])
@pytest.mark.parametrize("use_out", [True, False])
def test_addmv(input_shape, beta, alpha, use_out):
    """Check torch.addmv parity between CPU and IPU, covering the zero
    alpha/beta cases where operands must be ignored entirely."""
    torch.manual_seed(42)

    rows, cols = input_shape
    matrix = torch.randn(input_shape)
    vector = torch.randn(cols)
    bias = torch.randn(rows)

    # Poison the operands that a zero coefficient makes dead: per the
    # PyTorch docs they must not be read, so these NaNs must not leak
    # into the result.
    if beta == 0:
        bias[0] = float('nan')
    if alpha == 0:
        matrix[0, 0] = float('nan')
        vector[0] = float('nan')

    out_buffer = torch.empty(rows) if use_out else None

    class AddmvModel(torch.nn.Module):
        def __init__(self, beta, alpha):
            super().__init__()
            self.beta = beta
            self.alpha = alpha

        def forward(self, inp, mat, vec, out=None):
            result = torch.addmv(inp,
                                 mat,
                                 vec,
                                 beta=self.beta,
                                 alpha=self.alpha,
                                 out=out)
            if self.beta == 0 and self.alpha == 0:
                # Avoid empty compute graph
                result += torch.zeros_like(inp)
            return result

    model = AddmvModel(beta, alpha)
    cpu_result = model(bias, matrix, vector, out=out_buffer)
    ipu_result = poptorch.inferenceModel(model)(bias, matrix, vector,
                                                out_buffer)

    helpers.assert_allclose(expected=cpu_result, actual=ipu_result)
    if use_out:
        # The preallocated out= buffer must hold the same values.
        helpers.assert_allclose(expected=cpu_result, actual=out_buffer)

0 commit comments

Comments
 (0)