Change variable names to use rank. Add aditional check for scalars.
This commit is contained in:
parent
3f5c543782
commit
b450a763d1
|
@ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() {
|
|||
SmallVector<int64_t, 2> dims;
|
||||
auto lhsShape = lhsTy.getShape();
|
||||
auto rhsShape = rhsTy.getShape();
|
||||
if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
||||
|
||||
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
|
||||
// Multiplication by scalars is not allowed.
|
||||
emitError("Multiplication by scalar arguments not allowed.");
|
||||
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
||||
// Special case when both arrays are 1-dimensional and according to
|
||||
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
||||
// need to be removed after the multiplication but cannot be removed if all
|
||||
|
@ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() {
|
|||
// (s1 x s2 x... x sK x M x P)
|
||||
|
||||
// Check legality of matrix multiplication.
|
||||
unsigned leftDims = lhsShape.size();
|
||||
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
|
||||
lhsShape[leftDims - 1] != rhsShape[0])
|
||||
unsigned lhsRank = lhsShape.size();
|
||||
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
||||
lhsShape[lhsRank - 1] != rhsShape[0])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < leftDims - 1; ++i)
|
||||
for (int i = 0; i < lhsRank - 1; ++i)
|
||||
dims.emplace_back(lhsShape[i]);
|
||||
dims.emplace_back(rhsShape[1]);
|
||||
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
||||
|
@ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() {
|
|||
// (s1 x s2 x... x sK x M x P)
|
||||
|
||||
// Check legality of matrix multiplication.
|
||||
unsigned rightDims = rhsShape.size();
|
||||
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
||||
lhsShape[1] != rhsShape[rightDims - 2])
|
||||
unsigned rhsRank = rhsShape.size();
|
||||
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||
lhsShape[1] != rhsShape[rhsRank - 2])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
for (int i = 0; i < rightDims - 2; ++i)
|
||||
for (int i = 0; i < rhsRank - 2; ++i)
|
||||
dims.emplace_back(rhsShape[i]);
|
||||
dims.emplace_back(lhsShape[0]);
|
||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
|
||||
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
|
||||
// =>
|
||||
// (u1 x u2 x... x uK x M x P)
|
||||
|
||||
// Check legality of matrix multiplication.
|
||||
unsigned leftDims = lhsShape.size();
|
||||
unsigned rightDims = rhsShape.size();
|
||||
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
||||
lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
|
||||
unsigned lhsRank = lhsShape.size();
|
||||
unsigned rhsRank = rhsShape.size();
|
||||
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
||||
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
// Check and perform broadcasting for the shapes.
|
||||
SmallVector<int64_t, 2> lhsBcastShape;
|
||||
for (int i = 0; i < leftDims - 2; ++i)
|
||||
for (int i = 0; i < lhsRank - 2; ++i)
|
||||
lhsBcastShape.emplace_back(lhsShape[i]);
|
||||
SmallVector<int64_t, 2> rhsBcastShape;
|
||||
for (int i = 0; i < rightDims - 2; ++i)
|
||||
for (int i = 0; i < rhsRank - 2; ++i)
|
||||
rhsBcastShape.emplace_back(rhsShape[i]);
|
||||
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
||||
emitError("Broadcasted dimensions are incompatible.");
|
||||
|
||||
dims.emplace_back(lhsShape[leftDims - 2]);
|
||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||
dims.emplace_back(lhsShape[lhsRank - 2]);
|
||||
dims.emplace_back(rhsShape[rhsRank - 1]);
|
||||
} else {
|
||||
// This case covers all remaining combinations of 1 and 2-D matrices.
|
||||
int64_t lhsDim = lhsShape[0];
|
||||
|
|
Loading…
Reference in New Issue