Fix conditions.
This commit is contained in:
parent
a3995b61e7
commit
96551ef71e
|
@ -334,66 +334,63 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
lhsShape[0] != rhsShape[0])
|
lhsShape[0] != rhsShape[0])
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
dims.emplace_back(1);
|
dims.emplace_back(1);
|
||||||
|
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
||||||
|
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
||||||
|
// =>
|
||||||
|
// (s1 x s2 x... x sK x M x P)
|
||||||
|
|
||||||
|
// Check legality of matrix multiplication.
|
||||||
|
unsigned leftDims = lhsShape.size();
|
||||||
|
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
|
||||||
|
lhsShape[leftDims - 1] != rhsShape[0])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
for (int i = 0; i < leftDims - 1; ++i)
|
||||||
|
dims.emplace_back(lhsShape[i]);
|
||||||
|
dims.emplace_back(rhsShape[1]);
|
||||||
|
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
||||||
|
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
|
||||||
|
// =>
|
||||||
|
// (s1 x s2 x... x sK x M x P)
|
||||||
|
|
||||||
|
// Check legality of matrix multiplication.
|
||||||
|
unsigned rightDims = rhsShape.size();
|
||||||
|
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
||||||
|
lhsShape[1] != rhsShape[rightDims - 2])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
for (int i = 0; i < rightDims - 2; ++i)
|
||||||
|
dims.emplace_back(rhsShape[i]);
|
||||||
|
dims.emplace_back(lhsShape[0]);
|
||||||
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||||
|
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
|
||||||
|
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
|
||||||
|
// =>
|
||||||
|
// (u1 x u2 x... x uK x M x P)
|
||||||
|
|
||||||
|
// Check legality of matrix multiplication.
|
||||||
|
unsigned leftDims = lhsShape.size();
|
||||||
|
unsigned rightDims = rhsShape.size();
|
||||||
|
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
||||||
|
lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
// Check and perform broadcasting for the shapes.
|
||||||
|
SmallVector<int64_t, 2> lhsBcastShape;
|
||||||
|
for (int i = 0; i < leftDims - 2; ++i)
|
||||||
|
lhsBcastShape.emplace_back(lhsShape[i]);
|
||||||
|
SmallVector<int64_t, 2> rhsBcastShape;
|
||||||
|
for (int i = 0; i < rightDims - 2; ++i)
|
||||||
|
rhsBcastShape.emplace_back(rhsShape[i]);
|
||||||
|
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
||||||
|
emitError("Broadcasted dimensions are incompatible.");
|
||||||
|
|
||||||
|
dims.emplace_back(lhsShape[leftDims - 2]);
|
||||||
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||||
} else {
|
} else {
|
||||||
// Special cases for when at least one matrix has more than two dimensions.
|
// This case covers all remaining combinations of 1 and 2-D matrices.
|
||||||
if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
||||||
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
||||||
// =>
|
|
||||||
// (s1 x s2 x... x sK x M x P)
|
|
||||||
|
|
||||||
// Check legality of matrix multiplication.
|
|
||||||
unsigned leftDims = lhsShape.size();
|
|
||||||
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
|
|
||||||
lhsShape[leftDims - 1] != rhsShape[0])
|
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
|
||||||
|
|
||||||
for (int i = 0; i < leftDims - 1; ++i)
|
|
||||||
dims.emplace_back(lhsShape[i]);
|
|
||||||
dims.emplace_back(rhsShape[1]);
|
|
||||||
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
|
||||||
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
|
|
||||||
// =>
|
|
||||||
// (s1 x s2 x... x sK x M x P)
|
|
||||||
|
|
||||||
// Check legality of matrix multiplication.
|
|
||||||
unsigned rightDims = rhsShape.size();
|
|
||||||
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
|
||||||
lhsShape[1] != rhsShape[rightDims - 2])
|
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
|
||||||
|
|
||||||
for (int i = 0; i < rightDims - 2; ++i)
|
|
||||||
dims.emplace_back(rhsShape[i]);
|
|
||||||
dims.emplace_back(lhsShape[0]);
|
|
||||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
|
||||||
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
|
|
||||||
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
|
|
||||||
// =>
|
|
||||||
// (u1 x u2 x... x uK x M x P)
|
|
||||||
|
|
||||||
// Check legality of matrix multiplication.
|
|
||||||
unsigned leftDims = lhsShape.size();
|
|
||||||
unsigned rightDims = rhsShape.size();
|
|
||||||
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
|
||||||
lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
|
|
||||||
emitError("Attempt to multiply incompatible matrices.");
|
|
||||||
|
|
||||||
// Check and perform broadcasting for the shapes.
|
|
||||||
SmallVector<int64_t, 2> lhsBcastShape;
|
|
||||||
for (int i = 0; i < leftDims - 2; ++i)
|
|
||||||
lhsBcastShape.emplace_back(lhsShape[i]);
|
|
||||||
SmallVector<int64_t, 2> rhsBcastShape;
|
|
||||||
for (int i = 0; i < rightDims - 2; ++i)
|
|
||||||
rhsBcastShape.emplace_back(rhsShape[i]);
|
|
||||||
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
|
||||||
emitError("Broadcasted dimensions are incompatible.");
|
|
||||||
|
|
||||||
dims.emplace_back(lhsShape[leftDims - 2]);
|
|
||||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
|
||||||
} else {
|
|
||||||
// This case covers all remaining combinations of 1 and 2-D matrices.
|
|
||||||
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
|
||||||
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
|
Loading…
Reference in New Issue