diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index fe06204..1bb3a2c 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -1537,6 +1537,80 @@ bool ONNXConcatOp::inferShapes() {
   return true;
 }
 
+//===----------------------------------------------------------------------===//
+// Split
+
+bool ONNXSplitOp::inferShapes() {
+  if (!getOperand().getType().cast<RankedTensorType>()) {
+    emitError("Input tensor not ranked");
+    return false;
+  }
+
+  int numOfResults = getNumResults();
+  auto inputType = getOperand().getType().cast<RankedTensorType>();
+  auto inputShape = inputType.getShape();
+  int64_t inputRank = inputShape.size();
+
+  // Checking value of axis parameter.
+  auto axisIndex = axis().getSExtValue();
+  if (axisIndex < -inputRank || axisIndex >= inputRank) {
+    emitError("Split axis value out of bound");
+    return false;
+  }
+  // Negative axis means values are counted from the opposite side.
+  if (axisIndex < 0) {
+    axisIndex = inputRank + axisIndex;
+    auto builder = mlir::Builder(getContext());
+    axisAttr(builder.getI64IntegerAttr(axisIndex));
+  }
+
+  // Checking value of split parameter.
+  auto splitAttribute = split();
+  SmallVector<int64_t, 4> splitLengths;
+  if (splitAttribute.hasValue()) {
+    if (ArrayAttrSize(splitAttribute) != numOfResults) {
+      emitError("Split size not equal to the number of results");
+      return false;
+    }
+    for (int i = 0; i < numOfResults; ++i)
+      splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
+
+  } else {
+    if (inputShape[axisIndex] <= 0) {
+      emitError("The dimension at the split axis is expected to be known at "
+                "compile time");
+      return false;
+    }
+    if (inputShape[axisIndex] % numOfResults != 0) {
+      emitError("The dimension at the split axis is expected to be divisible "
+                "by the number of results");
+      return false;
+    }
+    // If split parameter is not specified, the dimension is split to
+    // equal-sized parts.
+    for (int i = 0; i < numOfResults; ++i)
+      splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
+    // Build attribute and store attribute.
+    auto builder = mlir::Builder(getContext());
+    splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
+  }
+
+  // Build result types.
+  for (int i = 0; i < numOfResults; ++i) {
+    SmallVector<int64_t, 4> resultShape;
+    for (int64_t j = 0; j < inputRank; ++j) {
+      if (j == axisIndex) {
+        resultShape.emplace_back(splitLengths[i]);
+      } else {
+        resultShape.emplace_back(inputShape[j]);
+      }
+    }
+    getResults()[i].setType(
+        RankedTensorType::get(resultShape, inputType.getElementType()));
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 8533863..3b2dc86 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -3241,7 +3241,7 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth",
 }
 
 def ONNXSplitOp:ONNX_Op<"Split",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Split operation";
   let description = [{
     "Split a tensor into a list of tensors, along the specified"
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp
index 36655cd..97f9794 100644
--- a/src/Transform/ONNX/ShapeInferencePass.cpp
+++ b/src/Transform/ONNX/ShapeInferencePass.cpp
@@ -124,6 +124,7 @@ public:
         op->getName().getStringRef() != "onnx.Abs" &&
         op->getName().getStringRef() != "onnx.Constant" &&
         op->getName().getStringRef() != "onnx.Concat" &&
+        op->getName().getStringRef() != "onnx.Split" &&
        op->getName().getStringRef() != "onnx.Neg" &&
         op->getName().getStringRef() != "onnx.Unsqueeze")
       return false;
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 33a716f..8bee657 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -610,3 +610,36 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
   // CHECK: [[RES:%.+]] = "onnx.Concat"(%arg0, %arg1, %arg2) {axis = 1 : i64} : (tensor<5x1x32xf32>, tensor<5x3x32xf32>, tensor<5x5x32xf32>) -> tensor<5x9x32xf32>
   // CHECK: return [[RES]] : tensor<5x9x32xf32>
 }
+
+// -----
+
+func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_split_1
+  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
+  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
+}
+
+// -----
+
+func @test_split_2(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
+  %0, %1 = "onnx.Split"(%arg0) { axis = -2 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_split_2
+  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [16, 16]} : (tensor<16x32x64xf32>) -> (tensor<16x16x64xf32>, tensor<16x16x64xf32>)
+  // CHECK: return [[RES]]#0 : tensor<16x16x64xf32>
+}
+
+// -----
+
+func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
+  %0, %1 = "onnx.Split"(%arg0) { axis = 1, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_split_3
+  // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>)
+  // CHECK: return [[RES]]#0 : tensor<16x2x64xf32>
+}
diff --git a/utils/gen_doc.py b/utils/gen_doc.py
index 3af41eb..1f7fe65 100644
--- a/utils/gen_doc.py
+++ b/utils/gen_doc.py
@@ -63,7 +63,7 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
 ]
 
 # Operations supporting canonicalization.