From a3995b61e765ccd04467157becdf39f3e6bfe5db Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Fri, 10 Jan 2020 12:27:34 -0500 Subject: [PATCH] Add support for shape broadcast. --- src/dialect/onnx/onnx_ops.cpp | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index a449e05..9166a59 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -343,8 +343,8 @@ void ONNXMatMulOp::inferShapes() { // Check legality of matrix multiplication. unsigned leftDims = lhsShape.size(); - if (lhsShape[leftDims - 2] != -1 && rhsShape[0] != -1 && - lhsShape[leftDims - 2] != rhsShape[0]) + if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 && + lhsShape[leftDims - 1] != rhsShape[0]) emitError("Attempt to multiply incompatible matrices."); for (int i = 0; i < leftDims - 1; ++i) @@ -365,7 +365,32 @@ void ONNXMatMulOp::inferShapes() { dims.emplace_back(rhsShape[i]); dims.emplace_back(lhsShape[0]); dims.emplace_back(rhsShape[rightDims - 1]); + } else if (lhsShape.size() > 2 && rhsShape.size() > 2) { + // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P) + // => + // (u1 x u2 x... x uK x M x P) + + // Check legality of matrix multiplication. + unsigned leftDims = lhsShape.size(); + unsigned rightDims = rhsShape.size(); + if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 && + lhsShape[leftDims - 1] != rhsShape[rightDims - 2]) + emitError("Attempt to multiply incompatible matrices."); + + // Check and perform broadcasting for the shapes. + SmallVector lhsBcastShape; + for (int i = 0; i < leftDims - 2; ++i) + lhsBcastShape.emplace_back(lhsShape[i]); + SmallVector rhsBcastShape; + for (int i = 0; i < rightDims - 2; ++i) + rhsBcastShape.emplace_back(rhsShape[i]); + if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) + emitError("Broadcasted dimensions are incompatible."); + + dims.emplace_back(lhsShape[leftDims - 2]); + dims.emplace_back(rhsShape[rightDims - 1]); } else { + // This case covers all remaining combinations of 1 and 2-D matrices. dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); }