Change variable names to use rank. Add aditional check for scalars.
This commit is contained in:
		
							parent
							
								
									3f5c543782
								
							
						
					
					
						commit
						b450a763d1
					
				|  | @ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|   SmallVector<int64_t, 2> dims; |   SmallVector<int64_t, 2> dims; | ||||||
|   auto lhsShape = lhsTy.getShape(); |   auto lhsShape = lhsTy.getShape(); | ||||||
|   auto rhsShape = rhsTy.getShape(); |   auto rhsShape = rhsTy.getShape(); | ||||||
|   if (lhsShape.size() == 1 && rhsShape.size() == 1) { | 
 | ||||||
|  |   if (lhsShape.size() < 1 && rhsShape.size() < 1) { | ||||||
|  |     // Multiplication by scalars is not allowed.
 | ||||||
|  |     emitError("Multiplication by scalar arguments not allowed."); | ||||||
|  |   } else if (lhsShape.size() == 1 && rhsShape.size() == 1) { | ||||||
|     // Special case when both arrays are 1-dimensional and according to
 |     // Special case when both arrays are 1-dimensional and according to
 | ||||||
|     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 |     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 | ||||||
|     // need to be removed after the multiplication but cannot be removed if all
 |     // need to be removed after the multiplication but cannot be removed if all
 | ||||||
|  | @ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|     // (s1 x s2 x... x sK x M x P)
 |     // (s1 x s2 x... x sK x M x P)
 | ||||||
| 
 | 
 | ||||||
|     // Check legality of matrix multiplication.
 |     // Check legality of matrix multiplication.
 | ||||||
|     unsigned leftDims = lhsShape.size(); |     unsigned lhsRank = lhsShape.size(); | ||||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && |     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && | ||||||
|         lhsShape[leftDims - 1] != rhsShape[0]) |         lhsShape[lhsRank - 1] != rhsShape[0]) | ||||||
|       emitError("Attempt to multiply incompatible matrices."); |       emitError("Attempt to multiply incompatible matrices."); | ||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < leftDims - 1; ++i) |     for (int i = 0; i < lhsRank - 1; ++i) | ||||||
|       dims.emplace_back(lhsShape[i]); |       dims.emplace_back(lhsShape[i]); | ||||||
|     dims.emplace_back(rhsShape[1]); |     dims.emplace_back(rhsShape[1]); | ||||||
|   } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { |   } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { | ||||||
|  | @ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|     // (s1 x s2 x... x sK x M x P)
 |     // (s1 x s2 x... x sK x M x P)
 | ||||||
| 
 | 
 | ||||||
|     // Check legality of matrix multiplication.
 |     // Check legality of matrix multiplication.
 | ||||||
|     unsigned rightDims = rhsShape.size(); |     unsigned rhsRank = rhsShape.size(); | ||||||
|     if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && |     if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && | ||||||
|         lhsShape[1] != rhsShape[rightDims - 2]) |         lhsShape[1] != rhsShape[rhsRank - 2]) | ||||||
|       emitError("Attempt to multiply incompatible matrices."); |       emitError("Attempt to multiply incompatible matrices."); | ||||||
| 
 | 
 | ||||||
|     for (int i = 0; i < rightDims - 2; ++i) |     for (int i = 0; i < rhsRank - 2; ++i) | ||||||
|       dims.emplace_back(rhsShape[i]); |       dims.emplace_back(rhsShape[i]); | ||||||
|     dims.emplace_back(lhsShape[0]); |     dims.emplace_back(lhsShape[0]); | ||||||
|     dims.emplace_back(rhsShape[rightDims - 1]); |     dims.emplace_back(rhsShape[rhsRank - 1]); | ||||||
|   } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { |   } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { | ||||||
|     // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
 |     // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
 | ||||||
|     // =>
 |     // =>
 | ||||||
|     // (u1 x u2 x... x uK x M x P)
 |     // (u1 x u2 x... x uK x M x P)
 | ||||||
| 
 | 
 | ||||||
|     // Check legality of matrix multiplication.
 |     // Check legality of matrix multiplication.
 | ||||||
|     unsigned leftDims = lhsShape.size(); |     unsigned lhsRank = lhsShape.size(); | ||||||
|     unsigned rightDims = rhsShape.size(); |     unsigned rhsRank = rhsShape.size(); | ||||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && |     if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && | ||||||
|         lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) |         lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) | ||||||
|       emitError("Attempt to multiply incompatible matrices."); |       emitError("Attempt to multiply incompatible matrices."); | ||||||
| 
 | 
 | ||||||
|     // Check and perform broadcasting for the shapes.
 |     // Check and perform broadcasting for the shapes.
 | ||||||
|     SmallVector<int64_t, 2> lhsBcastShape; |     SmallVector<int64_t, 2> lhsBcastShape; | ||||||
|     for (int i = 0; i < leftDims - 2; ++i) |     for (int i = 0; i < lhsRank - 2; ++i) | ||||||
|       lhsBcastShape.emplace_back(lhsShape[i]); |       lhsBcastShape.emplace_back(lhsShape[i]); | ||||||
|     SmallVector<int64_t, 2> rhsBcastShape; |     SmallVector<int64_t, 2> rhsBcastShape; | ||||||
|     for (int i = 0; i < rightDims - 2; ++i) |     for (int i = 0; i < rhsRank - 2; ++i) | ||||||
|       rhsBcastShape.emplace_back(rhsShape[i]); |       rhsBcastShape.emplace_back(rhsShape[i]); | ||||||
|     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) |     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) | ||||||
|       emitError("Broadcasted dimensions are incompatible."); |       emitError("Broadcasted dimensions are incompatible."); | ||||||
| 
 | 
 | ||||||
|     dims.emplace_back(lhsShape[leftDims - 2]); |     dims.emplace_back(lhsShape[lhsRank - 2]); | ||||||
|     dims.emplace_back(rhsShape[rightDims - 1]); |     dims.emplace_back(rhsShape[rhsRank - 1]); | ||||||
|   } else { |   } else { | ||||||
|     // This case covers all remaining combinations of 1 and 2-D matrices.
 |     // This case covers all remaining combinations of 1 and 2-D matrices.
 | ||||||
|     int64_t lhsDim = lhsShape[0]; |     int64_t lhsDim = lhsShape[0]; | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue