diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 4b68fe7..2380f5a 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -318,12 +318,28 @@ void ONNXMatMulOp::inferShapes() { if (!getOperand(0).getType().isa() || !getOperand(1).getType().isa()) return; - auto lhsTy = getOperand(0).getType().cast(); - auto rhsTy = getOperand(1).getType().cast(); + + auto lhsTy = getOperand(0)->getType().cast(); + auto rhsTy = getOperand(1)->getType().cast(); + SmallVector dims; - dims.emplace_back(lhsTy.getShape()[0]); - dims.emplace_back(rhsTy.getShape()[1]); - getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); + auto lhsShape = lhsTy.getShape(); + auto rhsShape = rhsTy.getShape(); + if (lhsShape.size() == 1 && rhsShape.size() == 1) { + // Special case when both arrays are 1-dimensional and according to + // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes + // need to be removed after the multiplication but cannot be removed if all + // remaining sizes are 1. + if (lhsShape[0] != -1 && rhsShape[0] != -1 && + lhsShape[0] != rhsShape[0]) + emitError("Attempt to multiply incompatible matrices."); + dims.emplace_back(1); + } else { + dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]); + dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]); + } + + getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType())); } // TODO: