Transpose optimization (#280)
* Define krnl.permute op. * Support krnl.permute operation. * Properly remove loop references. * Re-push, Github was down. * Need to debug interpretOp error. * Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code. * Introduce permute, unroll operations. * More debug. * Remove std::set. * krnl.terminate fails to be converted. * Pass all tests, need to add legal ops as well as part of the conversion target. * Change test format to new permute spec. * Bug fix for nested iterate op lowering. * Simplify error reporting. * Fix compilation error. * Increase comments coverage. * Remove unnecessary imports. * Re-trigger Jenkins * Add permute/unroll tests. * Retrigger Jenkins * transpose fusion and removal * format * fix comments Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
8e3748d4c7
commit
8bfde7de4b
|
@ -5456,6 +5456,7 @@ def ONNXTopKOp:ONNX_Op<"TopK",
|
||||||
|
|
||||||
def ONNXTransposeOp:ONNX_Op<"Transpose",
|
def ONNXTransposeOp:ONNX_Op<"Transpose",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
|
let hasCanonicalizer = 1;
|
||||||
let summary = "ONNX Transpose operation";
|
let summary = "ONNX Transpose operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Transpose the input tensor similar to numpy.transpose. For example, when"
|
"Transpose the input tensor similar to numpy.transpose. For example, when"
|
||||||
|
|
|
@ -19,6 +19,43 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Support for transpose patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Compute the combined permute pattern from a pair of permute patterns.
|
||||||
|
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
|
||||||
|
ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) {
|
||||||
|
// Read first permute vectors.
|
||||||
|
SmallVector<int64_t, 4> initialPerm;
|
||||||
|
for (auto firstPermVal : firstPermAttr.getValue())
|
||||||
|
initialPerm.emplace_back(firstPermVal.cast<IntegerAttr>().getInt());
|
||||||
|
// Read second permute vector. Use it as an index in the first permute
|
||||||
|
// vector.
|
||||||
|
SmallVector<int64_t, 4> resPerm;
|
||||||
|
for (auto secondPermVal : secondPermAttr.getValue()) {
|
||||||
|
auto index = secondPermVal.cast<IntegerAttr>().getInt();
|
||||||
|
resPerm.emplace_back(initialPerm[index]);
|
||||||
|
}
|
||||||
|
// Convert to Array of Attributes.
|
||||||
|
ArrayRef<int64_t> resPermRefs(resPerm);
|
||||||
|
return rewriter.getI64ArrayAttr(resPermRefs);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test if the permute pattern correspond to an identity pattern.
|
||||||
|
/// Identity patterns are {0, 1, 2, ... , rank -1}.
|
||||||
|
bool IsIdentityPermuteVector(ArrayAttr &permAttr) {
|
||||||
|
int64_t currentIndex = 0;
|
||||||
|
for (auto permVal : permAttr.getValue())
|
||||||
|
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pattern definition.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||||
#include "src/Transform/ONNX/ONNXCombine.inc"
|
#include "src/Transform/ONNX/ONNXCombine.inc"
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
@ -51,3 +88,10 @@ void ONNXCastOp::getCanonicalizationPatterns(
|
||||||
OwningRewritePatternList &result, MLIRContext *context) {
|
OwningRewritePatternList &result, MLIRContext *context) {
|
||||||
result.insert<CastEliminationPattern>(context);
|
result.insert<CastEliminationPattern>(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// on the ONNXTransposeOp.
|
||||||
|
void ONNXTransposeOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &result, MLIRContext *context) {
|
||||||
|
result.insert<FuseTransposePattern>(context);
|
||||||
|
result.insert<RemoveIdentityTransposePattern>(context);
|
||||||
|
}
|
||||||
|
|
|
@ -66,5 +66,27 @@ def CastEliminationPattern : Pat<
|
||||||
(replaceWithValue $arg),
|
(replaceWithValue $arg),
|
||||||
[(HasSameElementType $arg, $type)]>;
|
[(HasSameElementType $arg, $type)]>;
|
||||||
|
|
||||||
|
// Combine transposes.
|
||||||
|
// Native code helper: expands to a call of the C++ function
// CombinedTransposePattern, passing the rewriter ($_builder) and the two
// permutation attributes bound to $0 and $1.
def CreateCombinedTransposedPattern :
|
||||||
|
NativeCodeCall<"CombinedTransposePattern($_builder, $0, $1)">;
|
||||||
|
|
||||||
|
// Constraint that holds when the C++ predicate IsIdentityPermuteVector
// returns true for the attribute the constraint is applied to ($_self).
def IsIdentityPermuteAttribute :
|
||||||
|
Constraint<CPred<"IsIdentityPermuteVector($_self)">,
|
||||||
|
"has identity permute vector">;
|
||||||
|
|
||||||
|
// Rewrite a transpose whose input is itself a transpose into a single
// transpose of the innermost value $v. The new permutation is computed by
// CreateCombinedTransposedPattern from the inner ($p1) and outer ($p2)
// permutations.
def FuseTransposePattern: Pat<
|
||||||
|
// Transpose of a transpose.
|
||||||
|
(ONNXTransposeOp (ONNXTransposeOp $v, $p1), $p2),
|
||||||
|
// Transpose with combined pattern.
|
||||||
|
(ONNXTransposeOp $v, (CreateCombinedTransposedPattern $p1, $p2))>;
|
||||||
|
|
||||||
|
// Replace a transpose by its input $v when its permutation $p satisfies
// IsIdentityPermuteAttribute.
def RemoveIdentityTransposePattern: Pat<
|
||||||
|
// Transpose with an identity pattern (e.g. {0, 1, 2, 3}).
|
||||||
|
(ONNXTransposeOp $v, $p),
|
||||||
|
// Remove the transpose.
|
||||||
|
(replaceWithValue $v),
|
||||||
|
// Check that we indeed have an identity transpose pattern.
|
||||||
|
[(IsIdentityPermuteAttribute:$p)]>;
|
||||||
|
|
||||||
|
|
||||||
#endif // ONNX_COMBINE
|
#endif // ONNX_COMBINE
|
||||||
|
|
|
@ -191,3 +191,34 @@ func @test_conv_batchnormtestmode_fusion(%arg0 : tensor<1x3x224x224xf32>, %arg1
|
||||||
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
|
// CHECK: return [[RES]] : tensor<1x64x112x112xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Check the removal of identity transposes.
|
||||||
|
// CHECK-LABEL: func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||||
|
func @test_transpose_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) {perm = [0, 1, 2, 3]} : (tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32>
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
|
||||||
|
"std.return"(%0) : (tensor<10x11x12x13xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Check the combining of transposes into a simple transpose.
|
||||||
|
// CHECK-LABEL: func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
|
||||||
|
func @test_transpose_fusion(%arg0: tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32>
|
||||||
|
%1 = "onnx.Transpose"(%0) {perm = [2, 3, 0, 1]} : (tensor<13x12x11x10xf32>) -> tensor<11x10x13x12xf32>
|
||||||
|
// CHECK-NEXT: %{{.*}} = "onnx.Transpose"(%arg0) {perm = [1, 0, 3, 2]} : (tensor<10x11x12x13xf32>) -> tensor<11x10x13x12xf32>
|
||||||
|
"std.return"(%1) : (tensor<11x10x13x12xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Check the combining of transposes into an identity transpose, which in turn is removed.
|
||||||
|
// CHECK-LABEL: func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||||
|
func @test_transpose_fusion_removal(%arg0: tensor<10x11x12x13xf32>) -> tensor<10x11x12x13xf32> {
|
||||||
|
%0 = "onnx.Transpose"(%arg0) {perm = [3, 2, 1, 0]} : (tensor<10x11x12x13xf32>) -> tensor<13x12x11x10xf32>
|
||||||
|
%1 = "onnx.Transpose"(%0) {perm = [3, 2, 1, 0]} : (tensor<13x12x11x10xf32>) -> tensor<10x11x12x13xf32>
|
||||||
|
// CHECK-NEXT: return %arg0 : tensor<10x11x12x13xf32>
|
||||||
|
"std.return"(%1) : (tensor<10x11x12x13xf32>) -> ()
|
||||||
|
}
|
||||||
|
|
|
@ -321,7 +321,7 @@ OpsWithShapeInference=[
|
||||||
]
|
]
|
||||||
|
|
||||||
# Operations supporting canonicalization.
|
# Operations supporting canonicalization.
|
||||||
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast']
|
OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv', 'Cast', 'Transpose']
|
||||||
|
|
||||||
# Operations who have operands that, if produced by constant operations, should
|
# Operations who have operands that, if produced by constant operations, should
|
||||||
# be promoted to become an attribute (via attribute promotion).
|
# be promoted to become an attribute (via attribute promotion).
|
||||||
|
|
Loading…
Reference in New Issue