Add support for broadcasting left matrix.
This commit is contained in:
		
							parent
							
								
									d176b84506
								
							
						
					
					
						commit
						38bffee619
					
				| 
						 | 
					@ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() {
 | 
				
			||||||
  } else {
 | 
					  } else {
 | 
				
			||||||
    // Special cases for when at least one matrix has more than two dimensions.
 | 
					    // Special cases for when at least one matrix has more than two dimensions.
 | 
				
			||||||
    if (lhsShape.size() > 2 && rhsShape.size() == 2) {
 | 
					    if (lhsShape.size() > 2 && rhsShape.size() == 2) {
 | 
				
			||||||
      // (s1 x s2 x... x sKx x M x N) MATMUL (N x P)
 | 
					      // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
 | 
				
			||||||
      // =>
 | 
					      // =>
 | 
				
			||||||
      // (s1 x s2 x... x sKx x M x P)
 | 
					      // (s1 x s2 x... x sK x M x P)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
      // Check legality of matrix multiplication.
 | 
					      // Check legality of matrix multiplication.
 | 
				
			||||||
      unsigned leftDims = lhsShape.size();
 | 
					      unsigned leftDims = lhsShape.size();
 | 
				
			||||||
| 
						 | 
					@ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() {
 | 
				
			||||||
      for (int i = 0; i < leftDims - 1; ++i)
 | 
					      for (int i = 0; i < leftDims - 1; ++i)
 | 
				
			||||||
        dims.emplace_back(lhsShape[i]);
 | 
					        dims.emplace_back(lhsShape[i]);
 | 
				
			||||||
      dims.emplace_back(rhsShape[1]);
 | 
					      dims.emplace_back(rhsShape[1]);
 | 
				
			||||||
 | 
					    } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
 | 
				
			||||||
 | 
					      // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
 | 
				
			||||||
 | 
					      // =>
 | 
				
			||||||
 | 
					      // (s1 x s2 x... x sK x M x P)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Check legality of matrix multiplication.
 | 
				
			||||||
 | 
					      unsigned rightDims = rhsShape.size();
 | 
				
			||||||
 | 
					      if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
 | 
				
			||||||
 | 
					          lhsShape[1] != rhsShape[rightDims - 2])
 | 
				
			||||||
 | 
					        emitError("Attempt to multiply incompatible matrices.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for (int i = 0; i < rightDims - 2; ++i)
 | 
				
			||||||
 | 
					        dims.emplace_back(rhsShape[i]);
 | 
				
			||||||
 | 
					      dims.emplace_back(lhsShape[0]);
 | 
				
			||||||
 | 
					      dims.emplace_back(rhsShape[rightDims - 1]);
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
      dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
 | 
					      dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
 | 
				
			||||||
      dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
 | 
					      dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue