Add support for broadcasting left matrix.

This commit is contained in:
Doru Bercea 2020-01-10 11:34:26 -05:00
parent d176b84506
commit 38bffee619
1 changed files with 17 additions and 2 deletions

View File

@ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() {
} else { } else {
// Special cases for when at least one matrix has more than two dimensions. // Special cases for when at least one matrix has more than two dimensions.
if (lhsShape.size() > 2 && rhsShape.size() == 2) { if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sKx x M x N) MATMUL (N x P) // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
// => // =>
// (s1 x s2 x... x sKx x M x P) // (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size(); unsigned leftDims = lhsShape.size();
@ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() {
for (int i = 0; i < leftDims - 1; ++i) for (int i = 0; i < leftDims - 1; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]); dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication.
unsigned rightDims = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
lhsShape[1] != rhsShape[rightDims - 2])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < rightDims - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rightDims - 1]);
} else { } else {
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);