diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index f43e9d8..b39b5b6 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -5456,6 +5456,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", def ONNXTransposeOp:ONNX_Op<"Transpose", [NoSideEffect, DeclareOpInterfaceMethods]> { + let hasCanonicalizer = 1; let summary = "ONNX Transpose operation"; let description = [{ "Transpose the input tensor similar to numpy.transpose. For example, when" diff --git a/src/Transform/ONNX/Combine.cpp b/src/Transform/ONNX/Combine.cpp index 08f8abb..aa12fe8 100644 --- a/src/Transform/ONNX/Combine.cpp +++ b/src/Transform/ONNX/Combine.cpp @@ -19,6 +19,43 @@ using namespace mlir; namespace { +//===----------------------------------------------------------------------===// +// Support for transpose patterns. +//===----------------------------------------------------------------------===// + +/// Compute the combined permute pattern from a pair of permute patterns. +ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter, + ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) { + // Read first permute vectors. + SmallVector initialPerm; + for (auto firstPermVal : firstPermAttr.getValue()) + initialPerm.emplace_back(firstPermVal.cast().getInt()); + // Read second permute vector. Use it as an index in the first permute + // vector. + SmallVector resPerm; + for (auto secondPermVal : secondPermAttr.getValue()) { + auto index = secondPermVal.cast().getInt(); + resPerm.emplace_back(initialPerm[index]); + } + // Convert to Array of Attributes. + ArrayRef resPermRefs(resPerm); + return rewriter.getI64ArrayAttr(resPermRefs); +} + +/// Test if the permute pattern correspond to an identity pattern. +/// Identity patterns are {0, 1, 2, ... , rank -1}. +bool IsIdentityPermuteVector(ArrayAttr &permAttr) { + int64_t currentIndex = 0; + for (auto permVal : permAttr.getValue()) + if (permVal.cast().getInt() != currentIndex++) + return false; + return true; +} + +//===----------------------------------------------------------------------===// +// Pattern definition. +//===----------------------------------------------------------------------===// + /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXCombine.inc" } // end anonymous namespace @@ -51,3 +88,10 @@ void ONNXCastOp::getCanonicalizationPatterns( OwningRewritePatternList &result, MLIRContext *context) { result.insert(context); } + +/// on the ONNXTransposeOp. +void ONNXTransposeOp::getCanonicalizationPatterns( + OwningRewritePatternList &result, MLIRContext *context) { + result.insert(context); + result.insert(context); +} diff --git a/src/Transform/ONNX/Combine.td b/src/Transform/ONNX/Combine.td index 04067e0..ef9951f 100644 --- a/src/Transform/ONNX/Combine.td +++ b/src/Transform/ONNX/Combine.td @@ -66,5 +66,27 @@ def CastEliminationPattern : Pat< (replaceWithValue $arg), [(HasSameElementType $arg, $type)]>; +// Combine transposes. +def CreateCombinedTransposedPattern : + NativeCodeCall<"CombinedTransposePattern($_builder, $0, $1)">; + +def IsIdentityPermuteAttribute : + Constraint, + "has identity permute vector">; + +def FuseTransposePattern: Pat< + // Transpose of a transpose. + (ONNXTransposeOp (ONNXTransposeOp $v, $p1), $p2), + // Transpose with combined pattern. + (ONNXTransposeOp $v, (CreateCombinedTransposedPattern $p1, $p2))>; + +def RemoveIdentityTransposePattern: Pat< + // Transpose with an identity pattern (e.g. {0, 1, 2, 3}). + (ONNXTransposeOp $v, $p), + // Remove the transpose. + (replaceWithValue $v), + // Check that we have indeed a identity transpose pattern. + [(IsIdentityPermuteAttribute:$p)]>; + #endif // ONNX_COMBINE diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index e74a8be..ad689ea 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -191,3 +191,34 @@ func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 // CHECK: return [[RES]] : tensor<1x64x112x112xf32> } +// ----- + +// Check the removal of identity transposes. +// CHECK-LABEL: func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { +func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 1, 2, 3]} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> + // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> + "std.return"(%0) : (tensor<10x11x12x13xf32>) -> () +} + +// ----- + +// Check the combining of transposes into a simple transpose. +// CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> { +func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32> + %1 = "onnx.Transpose"(%0) {perm = [2, 3, 0, 1]} : (tensor<13x12x11x10xf32>) -> tensor<11x10x13x12xf32> + // CHECK-NEXT: %{{.*}} = "onnx.Transpose"(%arg0) {perm = [1, 0, 3, 2]} : (tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> + "std.return"(%1) : (tensor<11x10x13x12xf32>) -> () +} + +// ----- + +// Check the combining of transposes into an identity transpose, which in turns is removed. +// CHECK-LABEL: func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { +func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32> + %1 = "onnx.Transpose"(%0) {perm = [3, 2, 1, 0]} : (tensor<13x12x11x10xf32>) -> tensor<10x11x12x13xf32> + // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> + "std.return"(%1) : (tensor<10x11x12x13xf32>) -> () +} diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index bde18af..2528837 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -321,7 +321,7 @@ OpsWithShapeInference=[ ] # Operations supporting canonicalization. -OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast'] +OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose'] # Operations who have operands that, if produced by constant operations, should # be promoted to become an attribute (via attribute promotion).