Fix conditions.

This commit is contained in:
Doru Bercea 2020-01-10 12:30:12 -05:00
parent a3995b61e7
commit 96551ef71e
1 changed files with 56 additions and 59 deletions

View File

@ -334,9 +334,7 @@ void ONNXMatMulOp::inferShapes() {
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else {
// Special cases for when at least one matrix has more than two dimensions.
if (lhsShape.size() > 2 && rhsShape.size() == 2) {
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
// =>
// (s1 x s2 x... x sK x M x P)
@ -394,7 +392,6 @@ void ONNXMatMulOp::inferShapes() {
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
}
}
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}