Fix conditions.
This commit is contained in:
		
							parent
							
								
									a3995b61e7
								
							
						
					
					
						commit
						96551ef71e
					
				|  | @ -334,9 +334,7 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|         lhsShape[0] != rhsShape[0]) |         lhsShape[0] != rhsShape[0]) | ||||||
|       emitError("Attempt to multiply incompatible matrices."); |       emitError("Attempt to multiply incompatible matrices."); | ||||||
|     dims.emplace_back(1); |     dims.emplace_back(1); | ||||||
|   } else { |   } else if (lhsShape.size() > 2 && rhsShape.size() == 2) { | ||||||
|     // Special cases for when at least one matrix has more than two dimensions.
 |  | ||||||
|     if (lhsShape.size() > 2 && rhsShape.size() == 2) { |  | ||||||
|     // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 |     // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 | ||||||
|     // =>
 |     // =>
 | ||||||
|     // (s1 x s2 x... x sK x M x P)
 |     // (s1 x s2 x... x sK x M x P)
 | ||||||
|  | @ -394,7 +392,6 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|     dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); |     dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); | ||||||
|     dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); |     dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); | ||||||
|   } |   } | ||||||
|   } |  | ||||||
| 
 | 
 | ||||||
|   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); |   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue