Fix conditions.
This commit is contained in:
parent
a3995b61e7
commit
96551ef71e
|
@ -334,9 +334,7 @@ void ONNXMatMulOp::inferShapes() {
|
|||
lhsShape[0] != rhsShape[0])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
dims.emplace_back(1);
|
||||
} else {
|
||||
// Special cases for when at least one matrix has more than two dimensions.
|
||||
if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
||||
// =>
|
||||
// (s1 x s2 x... x sK x M x P)
|
||||
|
@ -394,7 +392,6 @@ void ONNXMatMulOp::inferShapes() {
|
|||
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
||||
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
||||
}
|
||||
}
|
||||
|
||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue