Add support for shape broadcast.

This commit is contained in:
Doru Bercea 2020-01-10 12:27:34 -05:00
parent 38bffee619
commit a3995b61e7
1 changed files with 27 additions and 2 deletions

View File

@ -343,8 +343,8 @@ void ONNXMatMulOp::inferShapes() {
// Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size();
if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 &&
lhsShape[leftDims - 2] != rhsShape[0])
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[leftDims - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < leftDims - 1; ++i)
@ -365,7 +365,32 @@ void ONNXMatMulOp::inferShapes() {
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rightDims - 1]);
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
// =>
// (u1 x u2 x... x uK x M x P)
// Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size();
unsigned rightDims = rhsShape.size();
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
emitError("Attempt to multiply incompatible matrices.");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
for (int i = 0; i < leftDims - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape;
for (int i = 0; i < rightDims - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible.");
dims.emplace_back(lhsShape[leftDims - 2]);
dims.emplace_back(rhsShape[rightDims - 1]);
} else {
// This case covers all remaining combinations of 1 and 2-D matrices.
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
}