Change variable names to use rank. Add aditional check for scalars.

This commit is contained in:
Doru Bercea 2020-01-27 12:08:23 -05:00
parent 3f5c543782
commit b450a763d1
1 changed files with 22 additions and 18 deletions

View File

@ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() {
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
auto rhsShape = rhsTy.getShape(); auto rhsShape = rhsTy.getShape();
if (lhsShape.size() == 1 && rhsShape.size() == 1) {
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed.");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to // Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all // need to be removed after the multiplication but cannot be removed if all
@ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() {
// (s1 x s2 x... x sK x M x P) // (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size(); unsigned lhsRank = lhsShape.size();
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[leftDims - 1] != rhsShape[0]) lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < leftDims - 1; ++i) for (int i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]); dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) { } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
@ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() {
// (s1 x s2 x... x sK x M x P) // (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
unsigned rightDims = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rightDims - 2]) lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < rightDims - 2; ++i) for (int i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]); dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]); dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rightDims - 1]); dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) { } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
// => // =>
// (u1 x u2 x... x uK x M x P) // (u1 x u2 x... x uK x M x P)
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
unsigned leftDims = lhsShape.size(); unsigned lhsRank = lhsShape.size();
unsigned rightDims = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices.");
// Check and perform broadcasting for the shapes. // Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape; SmallVector<int64_t, 2> lhsBcastShape;
for (int i = 0; i < leftDims - 2; ++i) for (int i = 0; i < lhsRank - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]); lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape; SmallVector<int64_t, 2> rhsBcastShape;
for (int i = 0; i < rightDims - 2; ++i) for (int i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]); rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible."); emitError("Broadcasted dimensions are incompatible.");
dims.emplace_back(lhsShape[leftDims - 2]); dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rightDims - 1]); dims.emplace_back(rhsShape[rhsRank - 1]);
} else { } else {
// This case covers all remaining combinations of 1 and 2-D matrices. // This case covers all remaining combinations of 1 and 2-D matrices.
int64_t lhsDim = lhsShape[0]; int64_t lhsDim = lhsShape[0];