Add support for broadcasting left matrix.
This commit is contained in:
		
							parent
							
								
									d176b84506
								
							
						
					
					
						commit
						38bffee619
					
				|  | @ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() { | |||
|   } else { | ||||
|     // Special cases for when at least one matrix has more than two dimensions.
 | ||||
|     if (lhsShape.size() > 2 && rhsShape.size() == 2) { | ||||
|       // (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
 | ||||
|       // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 | ||||
|       // =>
 | ||||
|       // (s1 x s2 x... x sKx x M x P)
 | ||||
|       // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|       // Check legality of matrix multiplication.
 | ||||
|       unsigned leftDims = lhsShape.size(); | ||||
|  | @ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() { | |||
|       for (int i = 0; i < leftDims - 1; ++i) | ||||
|         dims.emplace_back(lhsShape[i]); | ||||
|       dims.emplace_back(rhsShape[1]); | ||||
|     } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { | ||||
|       // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
 | ||||
|       // =>
 | ||||
|       // (s1 x s2 x... x sK x M x P)
 | ||||
| 
 | ||||
|       // Check legality of matrix multiplication.
 | ||||
|       unsigned rightDims = rhsShape.size(); | ||||
|       if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && | ||||
|           lhsShape[1] != rhsShape[rightDims - 2]) | ||||
|         emitError("Attempt to multiply incompatible matrices."); | ||||
| 
 | ||||
|       for (int i = 0; i < rightDims - 2; ++i) | ||||
|         dims.emplace_back(rhsShape[i]); | ||||
|       dims.emplace_back(lhsShape[0]); | ||||
|       dims.emplace_back(rhsShape[rightDims - 1]); | ||||
|     } else { | ||||
|       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); | ||||
|       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue