From 38bffee619874714123f8375c62eb1fc8db72a52 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Fri, 10 Jan 2020 11:34:26 -0500 Subject: [PATCH] Add support for broadcasting left matrix. --- src/dialect/onnx/onnx_ops.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index f19242b..a449e05 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -337,9 +337,9 @@ void ONNXMatMulOp::inferShapes() { } else { // Special cases for when at least one matrix has more than two dimensions. if (lhsShape.size() > 2 && rhsShape.size() == 2) { - // (s1 x s2 x... x sKx x M x N) MATMUL (N x P) + // (s1 x s2 x... x sK x M x N) MATMUL (N x P) // => - // (s1 x s2 x... x sKx x M x P) + // (s1 x s2 x... x sK x M x P) // Check legality of matrix multiplication. unsigned leftDims = lhsShape.size(); @@ -350,6 +350,21 @@ void ONNXMatMulOp::inferShapes() { for (int i = 0; i < leftDims - 1; ++i) dims.emplace_back(lhsShape[i]); dims.emplace_back(rhsShape[1]); + } else if (lhsShape.size() == 2 && rhsShape.size() > 2) { + // (M x N) MATMUL (s1 x s2 x... x sK x N x P) + // => + // (s1 x s2 x... x sK x M x P) + + // Check legality of matrix multiplication. + unsigned rightDims = rhsShape.size(); + if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 && + lhsShape[1] != rhsShape[rightDims - 2]) + emitError("Attempt to multiply incompatible matrices."); + + for (int i = 0; i < rightDims - 2; ++i) + dims.emplace_back(rhsShape[i]); + dims.emplace_back(lhsShape[0]); + dims.emplace_back(rhsShape[rightDims - 1]); } else { dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);