diff --git a/doc/gen_doc.py b/doc/gen_doc.py index 61e06b4..2d13937 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -46,7 +46,7 @@ OpsWithShapeInference = [ 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', - 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv' + 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat' ] # Operations supporting canonicalization. diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index c52b6a8..cb8c8d3 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -1463,6 +1463,65 @@ bool ONNXConstantOp::inferShapes() { return true; } +//===----------------------------------------------------------------------===// +// Concat + +bool ONNXConcatOp::inferShapes() { + int inputNum = getNumOperands(); + for (int i = 0; i < inputNum; ++i) { + if (!getOperand(i).getType().cast<RankedTensorType>()) { + emitError("Input tensor(s) not ranked"); + return false; + } + } + // Checking value of axis parameter. + auto commonType = getOperand(0).getType().cast<RankedTensorType>(); + auto commonShape = commonType.getShape(); + auto commonRank = commonShape.size(); + auto axisIndex = axis().getSExtValue(); + if (!(axisIndex >= 0 && axisIndex < commonRank)) { + emitError("Concat axis value out of bound"); + return false; + } + // Initial cumulative size is that of the first operand. + int cummulativeAxisSize = commonShape[axisIndex]; + + // Compute the cumulative size with all of the other ones, and make sure that + // the other sizes are all alike. + for (int i = 1; i < inputNum; ++i) { + auto currShape = + getOperand(i).getType().cast<RankedTensorType>().getShape(); + if (currShape.size() != commonRank) { + emitError("Concat input must all have the same rank"); + return false; + } + for (int j = 0; j < commonRank; ++j) { + if (j == axisIndex) { + // Check that the value is positive.
+ if (currShape[j] <= 0) { + emitError("Concat axis being concatenated is expected to be known at " + "compile time for now"); + return false; + } + } else if (currShape[j] != commonShape[j]) { + emitError("Concat input dimensions must be all identical, except for " + "dimension on the axis of the concatenation"); + return false; + } + } + cummulativeAxisSize += currShape[axisIndex]; + } + + // Set output size and type + SmallVector<int64_t, 4> outputDims; + for (int j = 0; j < commonRank; ++j) + outputDims.emplace_back( + j == axisIndex ? cummulativeAxisSize : commonShape[j]); + getResult().setType( + RankedTensorType::get(outputDims, commonType.getElementType())); + return true; +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index cbf2615..20cbfaf 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -314,7 +314,7 @@ def ONNXCompressOp:ONNX_Op<"Compress", } def ONNXConcatOp:ONNX_Op<"Concat", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { let summary = "ONNX Concat operation"; let description = [{ "Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index b4385c2..2188628 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -123,6 +123,7 @@ public: op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && op->getName().getStringRef() != "onnx.Abs" && op->getName().getStringRef() != "onnx.Constant" && + op->getName().getStringRef() != "onnx.Concat" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 02cf415..d44895c 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -487,3 +487,25 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32> // CHECK: return [[RES]] : tensor<80x5x2xf32> } + +//===----------------------------------------------------------------------===// +/// Test shape inference for the Concat op.
+//===----------------------------------------------------------------------===// + +func @test_concat_1(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<5x5x3x32xf32>, %arg2 : tensor<5x5x5x32xf32>) -> tensor<*xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 2 } : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_concat_1 + // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 2 : i64} : (tensor<5x5x1x32xf32>, tensor<5x5x3x32xf32>, tensor<5x5x5x32xf32>) -> tensor<5x5x9x32xf32> + // CHECK: return [[RES]] : tensor<5x5x9x32xf32> +} + +func @test_concat_2(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg2 : tensor<5x5x32xf32>) -> tensor<*xf32> { + %1 = "onnx.Concat"(%arg0, %arg1, %arg2) { axis = 1 } : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_concat_2 + // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32> + // CHECK: return [[RES]] : tensor<5x9x32xf32> +}