Fix conditions.
This commit is contained in:
		
							parent
							
								
									a3995b61e7
								
							
						
					
					
						commit
						96551ef71e
					
				|  | @ -334,66 +334,63 @@ void ONNXMatMulOp::inferShapes() { | |||
|         lhsShape[0] != rhsShape[0]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
|     dims.emplace_back(1); | ||||
|   } else if (lhsShape.size() > 2 && rhsShape.size() == 2) { | ||||
|     // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 | ||||
|     // =>
 | ||||
|     // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned leftDims = lhsShape.size(); | ||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && | ||||
|         lhsShape[leftDims - 1] != rhsShape[0]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     for (int i = 0; i < leftDims - 1; ++i) | ||||
|       dims.emplace_back(lhsShape[i]); | ||||
|     dims.emplace_back(rhsShape[1]); | ||||
|   } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { | ||||
|     // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
 | ||||
|     // =>
 | ||||
|     // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned rightDims = rhsShape.size(); | ||||
|     if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|         lhsShape[1] != rhsShape[rightDims - 2]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     for (int i = 0; i < rightDims - 2; ++i) | ||||
|       dims.emplace_back(rhsShape[i]); | ||||
|     dims.emplace_back(lhsShape[0]); | ||||
|     dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|   } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { | ||||
|     // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
 | ||||
|     // =>
 | ||||
|     // (u1 x u2 x... x uK x M x P)
 | ||||
| 
 | ||||
|     // Check legality of matrix multiplication.
 | ||||
|     unsigned leftDims = lhsShape.size(); | ||||
|     unsigned rightDims = rhsShape.size(); | ||||
|     if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|         lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) | ||||
|       emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|     // Check and perform broadcasting for the shapes.
 | ||||
|     SmallVector<int64_t, 2> lhsBcastShape; | ||||
|     for (int i = 0; i < leftDims - 2; ++i) | ||||
|       lhsBcastShape.emplace_back(lhsShape[i]); | ||||
|     SmallVector<int64_t, 2> rhsBcastShape; | ||||
|     for (int i = 0; i < rightDims - 2; ++i) | ||||
|       rhsBcastShape.emplace_back(rhsShape[i]); | ||||
|     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) | ||||
|       emitError("Broadcasted dimensions are incompatible."); | ||||
| 
 | ||||
|     dims.emplace_back(lhsShape[leftDims - 2]); | ||||
|     dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|   } else { | ||||
|     // Special cases for when at least one matrix has more than two dimensions.
 | ||||
|     if (lhsShape.size() > 2 && rhsShape.size() == 2) { | ||||
|       // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 | ||||
|       // =>
 | ||||
|       // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|       // Check legality of matrix multiplication.
 | ||||
|       unsigned leftDims = lhsShape.size(); | ||||
|       if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && | ||||
|           lhsShape[leftDims - 1] != rhsShape[0]) | ||||
|         emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|       for (int i = 0; i < leftDims - 1; ++i) | ||||
|         dims.emplace_back(lhsShape[i]); | ||||
|       dims.emplace_back(rhsShape[1]); | ||||
|     } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { | ||||
|       // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
 | ||||
|       // =>
 | ||||
|       // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|       // Check legality of matrix multiplication.
 | ||||
|       unsigned rightDims = rhsShape.size(); | ||||
|       if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|           lhsShape[1] != rhsShape[rightDims - 2]) | ||||
|         emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|       for (int i = 0; i < rightDims - 2; ++i) | ||||
|         dims.emplace_back(rhsShape[i]); | ||||
|       dims.emplace_back(lhsShape[0]); | ||||
|       dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|     } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { | ||||
|       // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
 | ||||
|       // =>
 | ||||
|       // (u1 x u2 x... x uK x M x P)
 | ||||
| 
 | ||||
|       // Check legality of matrix multiplication.
 | ||||
|       unsigned leftDims = lhsShape.size(); | ||||
|       unsigned rightDims = rhsShape.size(); | ||||
|       if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|           lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) | ||||
|         emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|       // Check and perform broadcasting for the shapes.
 | ||||
|       SmallVector<int64_t, 2> lhsBcastShape; | ||||
|       for (int i = 0; i < leftDims - 2; ++i) | ||||
|         lhsBcastShape.emplace_back(lhsShape[i]); | ||||
|       SmallVector<int64_t, 2> rhsBcastShape; | ||||
|       for (int i = 0; i < rightDims - 2; ++i) | ||||
|         rhsBcastShape.emplace_back(rhsShape[i]); | ||||
|       if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) | ||||
|         emitError("Broadcasted dimensions are incompatible."); | ||||
| 
 | ||||
|       dims.emplace_back(lhsShape[leftDims - 2]); | ||||
|       dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|     } else { | ||||
|       // This case covers all remaining combinations of 1 and 2-D matrices.
 | ||||
|       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); | ||||
|       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); | ||||
|     } | ||||
|     // This case covers all remaining combinations of 1 and 2-D matrices.
 | ||||
|     dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); | ||||
|     dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); | ||||
|   } | ||||
| 
 | ||||
|   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue