From 8bfde7de4bcddc7923d1d922a973cd69c8e4d127 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 31 Aug 2020 10:40:17 -0400 Subject: [PATCH] Transpose optimization (#280) * Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * transpose fusion and removal * format * fix comments Co-authored-by: Tian Jin --- src/Dialect/ONNX/ONNXOps.td.inc | 1 + src/Transform/ONNX/Combine.cpp | 44 +++++++++++++++++++++++ src/Transform/ONNX/Combine.td | 22 ++++++++++++ test/mlir/onnx/onnx_canonicalization.mlir | 31 ++++++++++++++++ utils/gen_onnx_mlir.py | 2 +- 5 files changed, 99 insertions(+), 1 deletion(-) diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index f43e9d8..b39b5b6 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -5456,6 +5456,7 @@ def ONNXTopKOp:ONNX_Op<"TopK", def ONNXTransposeOp:ONNX_Op<"Transpose", [NoSideEffect, DeclareOpInterfaceMethods]> { + let hasCanonicalizer = 1; let summary = "ONNX Transpose operation"; let description = [{ "Transpose the input tensor similar to numpy.transpose. For example, when" diff --git a/src/Transform/ONNX/Combine.cpp b/src/Transform/ONNX/Combine.cpp index 08f8abb..aa12fe8 100644 --- a/src/Transform/ONNX/Combine.cpp +++ b/src/Transform/ONNX/Combine.cpp @@ -19,6 +19,43 @@ using namespace mlir; namespace { +//===----------------------------------------------------------------------===// +// Support for transpose patterns. +//===----------------------------------------------------------------------===// + +/// Compute the combined permute pattern from a pair of permute patterns. +ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter, + ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) { + // Read first permute vectors. + SmallVector initialPerm; + for (auto firstPermVal : firstPermAttr.getValue()) + initialPerm.emplace_back(firstPermVal.cast().getInt()); + // Read second permute vector. Use it as an index in the first permute + // vector. + SmallVector resPerm; + for (auto secondPermVal : secondPermAttr.getValue()) { + auto index = secondPermVal.cast().getInt(); + resPerm.emplace_back(initialPerm[index]); + } + // Convert to Array of Attributes. + ArrayRef resPermRefs(resPerm); + return rewriter.getI64ArrayAttr(resPermRefs); +} + +/// Test if the permute pattern correspond to an identity pattern. +/// Identity patterns are {0, 1, 2, ... , rank -1}. +bool IsIdentityPermuteVector(ArrayAttr &permAttr) { + int64_t currentIndex = 0; + for (auto permVal : permAttr.getValue()) + if (permVal.cast().getInt() != currentIndex++) + return false; + return true; +} + +//===----------------------------------------------------------------------===// +// Pattern definition. +//===----------------------------------------------------------------------===// + /// Include the patterns defined in the Declarative Rewrite framework. #include "src/Transform/ONNX/ONNXCombine.inc" } // end anonymous namespace @@ -51,3 +88,10 @@ void ONNXCastOp::getCanonicalizationPatterns( OwningRewritePatternList &result, MLIRContext *context) { result.insert(context); } + +/// on the ONNXTransposeOp. +void ONNXTransposeOp::getCanonicalizationPatterns( + OwningRewritePatternList &result, MLIRContext *context) { + result.insert(context); + result.insert(context); +} diff --git a/src/Transform/ONNX/Combine.td b/src/Transform/ONNX/Combine.td index 04067e0..ef9951f 100644 --- a/src/Transform/ONNX/Combine.td +++ b/src/Transform/ONNX/Combine.td @@ -66,5 +66,27 @@ def CastEliminationPattern : Pat< (replaceWithValue $arg), [(HasSameElementType $arg, $type)]>; +// Combine transposes. +def CreateCombinedTransposedPattern : + NativeCodeCall<"CombinedTransposePattern($_builder, $0, $1)">; + +def IsIdentityPermuteAttribute : + Constraint, + "has identity permute vector">; + +def FuseTransposePattern: Pat< + // Transpose of a transpose. + (ONNXTransposeOp (ONNXTransposeOp $v, $p1), $p2), + // Transpose with combined pattern. + (ONNXTransposeOp $v, (CreateCombinedTransposedPattern $p1, $p2))>; + +def RemoveIdentityTransposePattern: Pat< + // Transpose with an identity pattern (e.g. {0, 1, 2, 3}). + (ONNXTransposeOp $v, $p), + // Remove the transpose. + (replaceWithValue $v), + // Check that we have indeed a identity transpose pattern. + [(IsIdentityPermuteAttribute:$p)]>; + #endif // ONNX_COMBINE diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir index e74a8be..ad689ea 100644 --- a/test/mlir/onnx/onnx_canonicalization.mlir +++ b/test/mlir/onnx/onnx_canonicalization.mlir @@ -191,3 +191,34 @@ func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1 // CHECK: return [[RES]] : tensor<1x64x112x112xf32> } +// ----- + +// Check the removal of identity transposes. +// CHECK-LABEL: func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { +func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 1, 2, 3]} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> + // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> + "std.return"(%0) : (tensor<10x11x12x13xf32>) -> () +} + +// ----- + +// Check the combining of transposes into a simple transpose. +// CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> { +func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32> + %1 = "onnx.Transpose"(%0) {perm = [2, 3, 0, 1]} : (tensor<13x12x11x10xf32>) -> tensor<11x10x13x12xf32> + // CHECK-NEXT: %{{.*}} = "onnx.Transpose"(%arg0) {perm = [1, 0, 3, 2]} : (tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> + "std.return"(%1) : (tensor<11x10x13x12xf32>) -> () +} + +// ----- + +// Check the combining of transposes into an identity transpose, which in turns is removed. +// CHECK-LABEL: func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { +func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32> + %1 = "onnx.Transpose"(%0) {perm = [3, 2, 1, 0]} : (tensor<13x12x11x10xf32>) -> tensor<10x11x12x13xf32> + // CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32> + "std.return"(%1) : (tensor<10x11x12x13xf32>) -> () +} diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index bde18af..2528837 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -321,7 +321,7 @@ OpsWithShapeInference=[ ] # Operations supporting canonicalization. -OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast'] +OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose'] # Operations who have operands that, if produced by constant operations, should # be promoted to become an attribute (via attribute promotion).