diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index e95e401..56fbeeb 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -4729,6 +4729,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", def ONNXShapeOp:ONNX_Op<"Shape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let hasCanonicalizer = 1; let summary = "ONNX Shape operation"; let description = [{ "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor." @@ -4863,6 +4864,7 @@ def ONNXSinhOp:ONNX_Op<"Sinh", def ONNXSizeOp:ONNX_Op<"Size", [NoSideEffect]> { + let hasCanonicalizer = 1; let summary = "ONNX Size operation"; let description = [{ "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor." diff --git a/src/Transform/ONNX/Rewrite.cpp b/src/Transform/ONNX/Rewrite.cpp index a5fbd7e..641b705 100644 --- a/src/Transform/ONNX/Rewrite.cpp +++ b/src/Transform/ONNX/Rewrite.cpp @@ -27,6 +27,29 @@ DenseElementsAttr createDenseElementsAttrFromFloatAttr( return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); } +// Create a DenseElementsAttr based on the shape of type. +DenseElementsAttr createDenseElementsAttrFromShape( + PatternRewriter &rewriter, Value value) { + auto inType = value.getType().cast(); + auto shape = inType.getShape(); + SmallVector dims = {inType.getRank()}; + SmallVector values(shape.begin(), shape.end()); + auto tensorType = + mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64)); + return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); +} + +// Create a DenseElementsAttr based on the size of type. +DenseElementsAttr createDenseElementsAttrFromSize( + PatternRewriter &rewriter, Value value) { + auto inType = value.getType().cast(); + SmallVector dims(1, 1); + SmallVector values = {inType.getNumElements()}; + auto tensorType = + mlir::RankedTensorType::get(dims, rewriter.getIntegerType(64)); + return mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)); +} + // If 'lhs' is not NoneType, return 'lhs - rhs'. // Otherwise, return '-rhs'. Value subtractOrNeg( @@ -128,3 +151,15 @@ void ONNXBatchNormalizationTestModeOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } + +/// on the ONNXShapeOp. +void ONNXShapeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +/// on the ONNXSizeOp. +void ONNXSizeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} diff --git a/src/Transform/ONNX/Rewrite.td b/src/Transform/ONNX/Rewrite.td index 6d6fb10..e3d9886 100644 --- a/src/Transform/ONNX/Rewrite.td +++ b/src/Transform/ONNX/Rewrite.td @@ -28,6 +28,14 @@ include "src/Dialect/ONNX/ONNXOps.td" def createDenseElementsAttrFromFloatAttr : NativeCodeCall< "createDenseElementsAttrFromFloatAttr($_builder, $0.getType().cast().getElementType(), $1)">; +// Create a DenseElementsAttr from the shape of the type of a value. +def createDenseElementsAttrFromShape : NativeCodeCall< + "createDenseElementsAttrFromShape($_builder, $0)">; + +// Create a DenseElementsAttr from the size of the type of a value. +def createDenseElementsAttrFromSize : NativeCodeCall< + "createDenseElementsAttrFromSize($_builder, $0)">; + // If '$1' is not NoneType, do subtraction '$1 - $2'. // Otherwise, take the negative of '$2'. def subtractOrNeg: NativeCodeCall< @@ -172,4 +180,25 @@ def FuseBatchNormTestModeConvPattern: Pat< $auto_pad, $dilation, $group, $kernel_shape, $pads, $strides) >; +def IsStaticShapeTensor: + Constraint< + CPred< + "$_self.getType().cast<::mlir::ShapedType>().hasStaticShape()">, + "hasStaticShape">; + +def ShapeToConstantPattern: Pat< + (ONNXShapeOp $A), + (ONNXConstantOp + (GetNullAttr), + (createDenseElementsAttrFromShape $A)), + [(IsStaticShapeTensor:$A)] +>; + +def SizeToConstantPattern: Pat< + (ONNXSizeOp $A), + (ONNXConstantOp + (GetNullAttr), + (createDenseElementsAttrFromSize $A)), + [(IsStaticShapeTensor:$A)] +>; #endif // ONNX_REWRITE diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index ad689ea..aa0298c 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -222,3 +222,49 @@ func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10 // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> "std.return"(%1) : (tensor<10x11x12x13xf32>) -> () } + +// ----- + +func @test_shape1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> { + %0 = "onnx.Shape"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64> + return %0 : tensor<*xi64> + + // CHECK-LABEL: @test_shape1 + // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<[2, 4, 8, 16]> : tensor<4xi64>} : () -> tensor<*xi64> + // CHECK-NEXT: %0 : tensor<*xi64> +} + +// ----- + +func @test_shape2(%arg0 : tensor) -> tensor<*xi64> { + %0 = "onnx.Shape"(%arg0) : (tensor) -> tensor<*xi64> + return %0 : tensor<*xi64> + + // CHECK-LABEL: @test_shape2 + // CHECK-NEXT: %0 = "onnx.Shape"(%arg0) : (tensor) -> tensor<*xi64> + // CHECK-NEXT: return %0 : tensor<*xi64> +} + + +// ----- + +func @test_size1(%arg0 : tensor<2x4x8x16xf32>) -> tensor<*xi64> { + %0 = "onnx.Size"(%arg0) : (tensor<2x4x8x16xf32>) -> tensor<*xi64> + return %0 : tensor<*xi64> + + // CHECK-LABEL: @test_size1 + // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<1024> : tensor<1xi64>} : () -> tensor<*xi64> + // CHECK-NEXT: %0 : tensor<*xi64> +} + +// ----- + +func @test_size2(%arg0 : tensor<*xf32>) -> tensor<*xi64> { + %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64> + return %0 : tensor<*xi64> + + // CHECK-LABEL: @test_size2 + // CHECK-NEXT: %0 = "onnx.Size"(%arg0) : (tensor<*xf32>) -> tensor<*xi64> + // CHECK-NEXT: return %0 : tensor<*xi64> +} + diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 250cdd6..474866a 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -321,7 +321,7 @@ OpsWithShapeInference=[ ] # Operations supporting canonicalization. -OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout'] +OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose', 'Dropout', 'Shape', 'Size'] # Operations who have operands that, if produced by constant operations, should # be promoted to become an attribute (via attribute promotion).