Add support for broadcasting left matrix.
This commit is contained in:
parent
d176b84506
commit
38bffee619
|
@ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() {
|
|||
} else {
|
||||
// Special cases for when at least one matrix has more than two dimensions.
|
||||
if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||
// (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
|
||||
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
||||
// =>
|
||||
// (s1 x s2 x... x sKx x M x P)
|
||||
// (s1 x s2 x... x sK x M x P)
|
||||
|
||||
// Check legality of matrix multiplication.
|
||||
unsigned leftDims = lhsShape.size();
|
||||
|
@ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() {
|
|||
for (int i = 0; i < leftDims - 1; ++i)
|
||||
dims.emplace_back(lhsShape[i]);
|
||||
dims.emplace_back(rhsShape[1]);
|
||||
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
||||
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
|
||||
// =>
|
||||
// (s1 x s2 x... x sK x M x P)
|
||||
|
||||
// Check legality of matrix multiplication.
|
||||
unsigned rightDims = rhsShape.size();
|
||||
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
||||
lhsShape[1] != rhsShape[rightDims - 2])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < rightDims - 2; ++i)
|
||||
dims.emplace_back(rhsShape[i]);
|
||||
dims.emplace_back(lhsShape[0]);
|
||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||
} else {
|
||||
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
||||
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
||||
|
|
Loading…
Reference in New Issue