Transpose optimization (#280)

* Define krnl.permute op.

* Support krnl.permute operation.

* Properly remove loop references.

* Re-push, Github was down.

* Need to debug interpretOp error.

* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.

* Introduce permute, unroll operations.

* More debug.

* Remove std::set.

* krnl.terminate fails to be converted.

* Pass all tests, need to add legal ops as well as part of the conversion target.

* Change test format to new permute spec.

* Bug fix for nested iterate op lowering.

* Simplify error reporting.

* Fix compilation error.

* Increase comments coverage.

* Remove unnecessary imports.

* Re-trigger Jenkins

* Add permute/unroll tests.

* Retrigger Jenkins

* transpose fusion and removal

* format

* fix comments

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Alexandre Eichenberger 2020-08-31 10:40:17 -04:00 committed by GitHub
parent 8e3748d4c7
commit 8bfde7de4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 1 deletions

View File

@ -5456,6 +5456,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
def ONNXTransposeOp:ONNX_Op<"Transpose", def ONNXTransposeOp:ONNX_Op<"Transpose",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let hasCanonicalizer = 1;
let summary = "ONNX Transpose operation"; let summary = "ONNX Transpose operation";
let description = [{ let description = [{
"Transpose the input tensor similar to numpy.transpose. For example, when" "Transpose the input tensor similar to numpy.transpose. For example, when"

View File

@ -19,6 +19,43 @@ using namespace mlir;
namespace { namespace {
//===----------------------------------------------------------------------===//
// Support for transpose patterns.
//===----------------------------------------------------------------------===//
/// Fuse the permutations of two back-to-back transposes into one.
/// `firstPermAttr` is the permutation of the transpose applied first and
/// `secondPermAttr` that of the transpose applied to its result. Entry i of
/// the returned attribute is firstPerm[secondPerm[i]].
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
    ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) {
  // Unpack the first permutation into plain integers so it can be indexed.
  SmallVector<int64_t, 4> innerPerm;
  for (auto innerAttr : firstPermAttr.getValue())
    innerPerm.push_back(innerAttr.cast<IntegerAttr>().getInt());
  // Every entry of the second permutation selects an entry of the first one.
  SmallVector<int64_t, 4> fusedPerm;
  for (auto outerAttr : secondPermAttr.getValue())
    fusedPerm.push_back(innerPerm[outerAttr.cast<IntegerAttr>().getInt()]);
  // Package the fused permutation as an array of 64-bit integer attributes.
  return rewriter.getI64ArrayAttr(fusedPerm);
}
/// Test if the permute pattern corresponds to an identity pattern.
/// Identity patterns are {0, 1, 2, ... , rank - 1}.
/// A missing (null) permute attribute is reported as non-identity: per the
/// ONNX specification an absent `perm` means "reverse the dimensions", and
/// the rank needed to decide whether that reversal is trivial is not
/// available here. NOTE(review): this assumes the matcher can hand over a null
/// attribute when the op carries no `perm` -- confirm against the op
/// definition.
bool IsIdentityPermuteVector(ArrayAttr &permAttr) {
  // Guard against an absent attribute instead of dereferencing a null handle.
  if (!permAttr)
    return false;
  // Entry i of an identity permutation must be exactly i.
  int64_t expectedIndex = 0;
  for (auto permVal : permAttr.getValue())
    if (permVal.cast<IntegerAttr>().getInt() != expectedIndex++)
      return false;
  return true;
}
//===----------------------------------------------------------------------===//
// Pattern definition.
//===----------------------------------------------------------------------===//
/// Include the patterns defined in the Declarative Rewrite framework. /// Include the patterns defined in the Declarative Rewrite framework.
#include "src/Transform/ONNX/ONNXCombine.inc" #include "src/Transform/ONNX/ONNXCombine.inc"
} // end anonymous namespace } // end anonymous namespace
@ -51,3 +88,10 @@ void ONNXCastOp::getCanonicalizationPatterns(
OwningRewritePatternList &result, MLIRContext *context) { OwningRewritePatternList &result, MLIRContext *context) {
result.insert<CastEliminationPattern>(context); result.insert<CastEliminationPattern>(context);
} }
/// Register the transpose rewrites (FuseTransposePattern and
/// RemoveIdentityTransposePattern) as canonicalization patterns
/// on the ONNXTransposeOp.
void ONNXTransposeOp::getCanonicalizationPatterns(
    OwningRewritePatternList &result, MLIRContext *context) {
  result.insert<FuseTransposePattern, RemoveIdentityTransposePattern>(context);
}

View File

@ -66,5 +66,27 @@ def CastEliminationPattern : Pat<
(replaceWithValue $arg), (replaceWithValue $arg),
[(HasSameElementType $arg, $type)]>; [(HasSameElementType $arg, $type)]>;
// Combine transposes.
// Native code call that invokes the C++ helper CombinedTransposePattern with
// the pattern builder and two bound permutation attributes; its result is used
// as the permutation attribute of the fused transpose.
def CreateCombinedTransposedPattern :
NativeCodeCall<"CombinedTransposePattern($_builder, $0, $1)">;
// Constraint that holds when the C++ predicate IsIdentityPermuteVector returns
// true for the constrained permutation attribute.
def IsIdentityPermuteAttribute :
Constraint<CPred<"IsIdentityPermuteVector($_self)">,
"has identity permute vector">;
// Replace a transpose whose input is itself a transpose by a single transpose
// of the original value $v. $p1 is the permutation of the inner (first)
// transpose and $p2 the permutation of the outer (second) one.
def FuseTransposePattern: Pat<
// Transpose of a transpose.
(ONNXTransposeOp (ONNXTransposeOp $v, $p1), $p2),
// Transpose with combined pattern.
(ONNXTransposeOp $v, (CreateCombinedTransposedPattern $p1, $p2))>;
// Replace a transpose whose permutation is the identity by its input value.
def RemoveIdentityTransposePattern: Pat<
// Transpose with an identity pattern (e.g. {0, 1, 2, 3}).
(ONNXTransposeOp $v, $p),
// Remove the transpose.
(replaceWithValue $v),
// Check that we indeed have an identity transpose pattern.
[(IsIdentityPermuteAttribute:$p)]>;
#endif // ONNX_COMBINE #endif // ONNX_COMBINE

View File

@ -191,3 +191,34 @@ func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1
// CHECK: return [[RES]] : tensor<1x64x112x112xf32> // CHECK: return [[RES]] : tensor<1x64x112x112xf32>
} }
// -----
// Check the removal of identity transposes: perm = [0, 1, 2, 3] leaves the
// tensor unchanged, so the function is expected to return its argument directly.
// CHECK-LABEL: func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
%0 = "onnx.Transpose"(%arg0) {perm = [0, 1, 2, 3]} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
// CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
"std.return"(%0) : (tensor<10x11x12x13xf32>) -> ()
}
// -----
// Check the combining of transposes into a simple transpose.
// Entry i of the fused permutation is first[second[i]]: with
// first = [3, 2, 1, 0] and second = [2, 3, 0, 1] this yields [1, 0, 3, 2].
// CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
%0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32>
%1 = "onnx.Transpose"(%0) {perm = [2, 3, 0, 1]} : (tensor<13x12x11x10xf32>) -> tensor<11x10x13x12xf32>
// CHECK-NEXT: %{{.*}} = "onnx.Transpose"(%arg0) {perm = [1, 0, 3, 2]} : (tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32>
"std.return"(%1) : (tensor<11x10x13x12xf32>) -> ()
}
// -----
// Check the combining of transposes into an identity transpose, which in turn is removed.
// Applying perm = [3, 2, 1, 0] twice fuses to [0, 1, 2, 3], so no transpose
// is expected to remain and the argument is returned directly.
// CHECK-LABEL: func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
%0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32>
%1 = "onnx.Transpose"(%0) {perm = [3, 2, 1, 0]} : (tensor<13x12x11x10xf32>) -> tensor<10x11x12x13xf32>
// CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
"std.return"(%1) : (tensor<10x11x12x13xf32>) -> ()
}

View File

@ -321,7 +321,7 @@ OpsWithShapeInference=[
] ]
# Operations supporting canonicalization. # Operations supporting canonicalization.
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast'] OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose']
# Operations who have operands that, if produced by constant operations, should # Operations who have operands that, if produced by constant operations, should
# be promoted to become an attribute (via attribute promotion). # be promoted to become an attribute (via attribute promotion).