Add support for broadcasting right matrix.

This commit is contained in:
Doru Bercea 2020-01-10 11:22:41 -05:00
parent 170296b7c6
commit d176b84506
1 changed files with 20 additions and 3 deletions

View File

@ -329,15 +329,32 @@ void ONNXMatMulOp::inferShapes() {
// Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// remaining sizes are 1.
// sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else {
// Special cases for when at least one matrix has more than two dimensions.
if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
// =>
// (s1 x s2 x... x sKx x M x P)
// Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size();
if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 &&
lhsShape[leftDims - 2] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < leftDims - 1; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]);
} else {
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
}
}
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}