Change variable names to use rank. Add aditional check for scalars.
This commit is contained in:
		
							parent
							
								
									3f5c543782
								
							
						
					
					
						commit
						b450a763d1
					
				|  | @ -350,7 +350,11 @@ void ONNXMatMulOp::inferShapes() { | |||
|   SmallVector<int64_t, 2> dims; | ||||
|   auto lhsShape = lhsTy.getShape(); | ||||
|   auto rhsShape = rhsTy.getShape(); | ||||
|   if (lhsShape.size() == 1 && rhsShape.size() == 1) { | ||||
| 
 | ||||
|   if (lhsShape.size() < 1 && rhsShape.size() < 1) { | ||||
|     // Multiplication by scalars is not allowed.
 | ||||
|     emitError("Multiplication by scalar arguments not allowed."); | ||||
|   } else if (lhsShape.size() == 1 && rhsShape.size() == 1) { | ||||
|     // Special case when both arrays are 1-dimensional and according to
 | ||||
|     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 | ||||
|     // need to be removed after the multiplication but cannot be removed if all
 | ||||
|  | @ -365,12 +369,12 @@ void ONNXMatMulOp::inferShapes() { | |||
|     // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned leftDims = lhsShape.size(); | ||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && | ||||
|         lhsShape[leftDims - 1] != rhsShape[0]) | ||||
|     unsigned lhsRank = lhsShape.size(); | ||||
|     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && | ||||
|         lhsShape[lhsRank - 1] != rhsShape[0]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     for (int i = 0; i < leftDims - 1; ++i) | ||||
|     for (int i = 0; i < lhsRank - 1; ++i) | ||||
|       dims.emplace_back(lhsShape[i]); | ||||
|     dims.emplace_back(rhsShape[1]); | ||||
|   } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { | ||||
|  | @ -379,39 +383,39 @@ void ONNXMatMulOp::inferShapes() { | |||
|     // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned rightDims = rhsShape.size(); | ||||
|     if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|         lhsShape[1] != rhsShape[rightDims - 2]) | ||||
|     unsigned rhsRank = rhsShape.size(); | ||||
|     if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && | ||||
|         lhsShape[1] != rhsShape[rhsRank - 2]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     for (int i = 0; i < rightDims - 2; ++i) | ||||
|     for (int i = 0; i < rhsRank - 2; ++i) | ||||
|       dims.emplace_back(rhsShape[i]); | ||||
|     dims.emplace_back(lhsShape[0]); | ||||
|     dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|     dims.emplace_back(rhsShape[rhsRank - 1]); | ||||
|   } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { | ||||
|     // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
 | ||||
|     // =>
 | ||||
|     // (u1 x u2 x... x uK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned leftDims = lhsShape.size(); | ||||
|     unsigned rightDims = rhsShape.size(); | ||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|         lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) | ||||
|     unsigned lhsRank = lhsShape.size(); | ||||
|     unsigned rhsRank = rhsShape.size(); | ||||
|     if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && | ||||
|         lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     // Check and perform broadcasting for the shapes.
 | ||||
|     SmallVector<int64_t, 2> lhsBcastShape; | ||||
|     for (int i = 0; i < leftDims - 2; ++i) | ||||
|     for (int i = 0; i < lhsRank - 2; ++i) | ||||
|       lhsBcastShape.emplace_back(lhsShape[i]); | ||||
|     SmallVector<int64_t, 2> rhsBcastShape; | ||||
|     for (int i = 0; i < rightDims - 2; ++i) | ||||
|     for (int i = 0; i < rhsRank - 2; ++i) | ||||
|       rhsBcastShape.emplace_back(rhsShape[i]); | ||||
|     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) | ||||
|       emitError("Broadcasted dimensions are incompatible."); | ||||
| 
 | ||||
|     dims.emplace_back(lhsShape[leftDims - 2]); | ||||
|     dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|     dims.emplace_back(lhsShape[lhsRank - 2]); | ||||
|     dims.emplace_back(rhsShape[rhsRank - 1]); | ||||
|   } else { | ||||
|     // This case covers all remaining combinations of 1 and 2-D matrices.
 | ||||
|     int64_t lhsDim = lhsShape[0]; | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue