Add support for broadcasting right matrix.
This commit is contained in:
		
							parent
							
								
									170296b7c6
								
							
						
					
					
						commit
						d176b84506
					
				|  | @ -329,15 +329,32 @@ void ONNXMatMulOp::inferShapes() { | ||||||
|     // Special case when both arrays are 1-dimensional and according to
 |     // Special case when both arrays are 1-dimensional and according to
 | ||||||
|     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 |     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
 | ||||||
|     // need to be removed after the multiplication but cannot be removed if all
 |     // need to be removed after the multiplication but cannot be removed if all
 | ||||||
|     // remaining sizes are 1.
 |     // sizes are 1.
 | ||||||
|     if (lhsShape[0] != -1 && rhsShape[0] != -1 && |     if (lhsShape[0] != -1 && rhsShape[0] != -1 && | ||||||
|         lhsShape[0] != rhsShape[0]) |         lhsShape[0] != rhsShape[0]) | ||||||
|       emitError("Attempt to multiply incompatible matrices."); |       emitError("Attempt to multiply incompatible matrices."); | ||||||
|     dims.emplace_back(1); |     dims.emplace_back(1); | ||||||
|  |   } else { | ||||||
|  |     // Special cases for when at least one matrix has more than two dimensions.
 | ||||||
|  |     if (lhsShape.size() > 2 && rhsShape.size() == 2) { | ||||||
|  |       // (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
 | ||||||
|  |       // =>
 | ||||||
|  |       // (s1 x s2 x... x sKx x M x P)
 | ||||||
|  | 
 | ||||||
|  |       // Check legality of matrix multiplication.
 | ||||||
|  |       unsigned leftDims = lhsShape.size(); | ||||||
|  |       if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 && | ||||||
|  |           lhsShape[leftDims - 2] != rhsShape[0]) | ||||||
|  |         emitError("Attempt to multiply incompatible matrices."); | ||||||
|  | 
 | ||||||
|  |       for (int i = 0; i < leftDims - 1; ++i) | ||||||
|  |         dims.emplace_back(lhsShape[i]); | ||||||
|  |       dims.emplace_back(rhsShape[1]); | ||||||
|     } else { |     } else { | ||||||
|       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); |       dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); | ||||||
|       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); |       dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); | ||||||
|     } |     } | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); |   getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue