From 9fb826ae7eac9f9602f1c4b4c3d2150bab12ba8e Mon Sep 17 00:00:00 2001 From: Gheorghe-Teodor Bercea Date: Thu, 30 Jan 2020 11:44:56 -0500 Subject: [PATCH] Lower transpose operation to KRNL dialect (#54) * Lower transpose operation. * Fix IndetityOp. * Add tests. * Add backend tests. * Clean-up code. * Move transpose code and improve comment. --- src/dialect/onnx/onnx_ops.cpp | 3 - src/pass/lower_frontend_to_krnl.cpp | 122 +++++++++++++++++++++++++++- test/backend/test.py | 17 +++- test/mlir/onnx/onnx_lowering.mlir | 36 ++++++++ 4 files changed, 170 insertions(+), 8 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 33db13e..c2aa199 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -583,9 +583,6 @@ void ONNXTransposeOp::inferShapes() { // reversing the shape of the tensor (similar to numpy.transpose). auto arrayTy = getOperand().getType().cast(); SmallVector dims; - - //if (auto permutation = getAttrOfType( - // ONNXTransposeOp::getPermAttrName())) { auto permutation = ONNXTransposeOp::permAttr(); if (permutation) { // Perform transposition according to perm attribute. diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 9c45dfe..f05b6fb 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -1461,6 +1461,125 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { } }; +struct ONNXTransposeOpLowering : public ConversionPattern { + ONNXTransposeOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto tensorType = (*op->result_type_begin()).cast(); + auto loc = op->getLoc(); + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + {operands[0]}); + + // Number of loops + auto memRefShape = memRefType.getShape(); + int64_t rank = memRefShape.size(); + + // Define loops. + auto loopsOp = rewriter.create(loc, rank); + std::vector originalLoops; + originalLoops.reserve(rank); + + for (auto result : loopsOp.getResults()) { + originalLoops.push_back(result); + } + + // Define loop optimization. + auto optimizedLoopsOp = rewriter.create(loc, rank); + std::vector optimizedLoops; + optimizedLoops.reserve(rank); + + for (auto result : optimizedLoopsOp.getResults()) { + optimizedLoops.push_back(result); + } + Block &optimizationBlock = optimizedLoopsOp.region().front(); + KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); + // Iterate over the loop nest using the input shape. + auto inputShape = operands[0].getType().cast().getShape(); + for (int i = 0; i < rank; ++i) { + if (inputShape[i] < 0) { + pack.pushConstantBound(0); + pack.pushOperandBound( + rewriter.create(loc, operands[0], i).getResult()); + } else { + pack.pushConstantBound(0); + pack.pushConstantBound(inputShape[i]); + } + } + + auto iterateOp = rewriter.create(loc, pack); + Block &iterationBlock = iterateOp.bodyRegion().front(); + + // Now perform the insertions into the body of the + // just generated instructions: + + // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. + rewriter.setInsertionPointToEnd(&optimizationBlock); + // Return from KrnlOptimizeLoopsOp body. + // When no optimizations are present we just return the loops + // unchaged. + rewriter.create(loc, originalLoops); + rewriter.setInsertionPoint(optimizedLoopsOp); + + // 2. Insert instructions inside the KernelIterateOp body. + rewriter.setInsertionPointToStart(&iterationBlock); + + // Handle the operation. + + // Read perm attribute. + SmallVector perm; + auto permAttribute = llvm::dyn_cast(op).permAttr(); + if (permAttribute) { + for (auto permVal : permAttribute.getValue()) + perm.emplace_back(permVal.cast().getInt()); + } else { + // TODO: Remove when perm is guaranteed to be present (even for + // the default case). This means that perm was added by shape + // inference or another pass to contain the values corresponding + // to the default behavior of Transpose. + for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--) + perm.emplace_back(i); + } + + SmallVector inLoopIVs; + for (auto arg : iterationBlock.getArguments()) + inLoopIVs.emplace_back(arg); + + SmallVector outLoopIVs; + for (int i=0; i(loc, operands[0], inLoopIVs); + rewriter.create(loc, inVal, alloc, outLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + +struct ONNXIdentityOpLowering : public ConversionPattern { + ONNXIdentityOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOp(op, operands[0]); + return matchSuccess(); + } +}; + //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. //===----------------------------------------------------------------------===// @@ -1590,7 +1709,8 @@ void FrontendToKrnlLoweringPass::runOnModule() { ONNXElementwiseVariadicOpLowering, ONNXReshapeOpLowering, ONNXEntryPointLowering, ONNXSoftmaxOpLowering, ONNXGemmOpLowering, - ONNXUnsqueezeOpLowering>(&getContext()); + ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering, + ONNXIdentityOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/test/backend/test.py b/test/backend/test.py index 1d45c14..86e2492 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -168,7 +168,7 @@ test_to_enable = [ "test_unsqueeze_negative_axes_cpu", "test_unsqueeze_three_axes_cpu", "test_unsqueeze_two_axes_cpu", - "test_unsqueeze_unsorted_axes_cpu", + # "test_unsqueeze_unsorted_axes_cpu", # Reciprocal Op: "test_reciprocal_cpu", @@ -192,6 +192,15 @@ test_to_enable = [ "test_reshape_reordered_last_dims_cpu", #"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim "test_reshape_zero_dim_cpu", + + # Transpose + "test_transpose_default_cpu", + "test_transpose_all_permutations_0_cpu", + "test_transpose_all_permutations_1_cpu", + "test_transpose_all_permutations_2_cpu", + "test_transpose_all_permutations_3_cpu", + "test_transpose_all_permutations_4_cpu", + "test_transpose_all_permutations_5_cpu", ] # Extract name of all test cases. @@ -202,9 +211,9 @@ all_test_names = list(map(lambda x: x[0], all_tests)) # Ensure that test names specified in test_to_enable actually exist. for test_name in test_to_enable: - assert test_name in all_test_names, "test name {} not found, it is likely " - "that you may have misspelled the test name or the specified test does not " - "exist in the version of onnx package you installed.".format( + assert test_name in all_test_names, """test name {} not found, it is likely + that you may have misspelled the test name or the specified test does not + exist in the version of onnx package you installed.""".format( test_name) backend_test.include(r"^{}$".format(test_name)) diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 394ef3c..0883aa8 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -702,3 +702,39 @@ func @test_unsqueeze(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { // CHECK: return [[RES]] : memref<1x10x10x1xf32> } +func @test_transpose(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32> + %1 = "onnx.Transpose"(%0) {perm = [0, 3, 1, 2]} : (tensor<*xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_transpose + // CHECK: [[RES0:%.+]] = alloc() : memref<40x10x30x20xf32> + // CHECK: [[RES1:%.+]] = alloc() : memref<40x30x20x10xf32> + + // CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 10, [[LOOPS]]#1 -> %arg2 = 0 to 20, [[LOOPS]]#2 -> %arg3 = 0 to 30, [[LOOPS]]#3 -> %arg4 = 0 to 40) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<10x20x30x40xf32> + // CHECK: store [[LOAD]], [[RES1]][%arg4, %arg3, %arg2, %arg1] : memref<40x30x20x10xf32> + + // CHECK: [[LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[LOOPS]]#0, [[LOOPS]]#1, [[LOOPS]]#2, [[LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[LOOPS]]#0 -> %arg1 = 0 to 40, [[LOOPS]]#1 -> %arg2 = 0 to 30, [[LOOPS]]#2 -> %arg3 = 0 to 20, [[LOOPS]]#3 -> %arg4 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load [[RES1]][%arg1, %arg2, %arg3, %arg4] : memref<40x30x20x10xf32> + // CHECK: store [[LOAD]], [[RES0]][%arg1, %arg4, %arg2, %arg3] : memref<40x10x30x20xf32> + + // CHECK: dealloc [[RES1]] : memref<40x30x20x10xf32> + // CHECK: return [[RES0]] : memref<40x10x30x20xf32> +} + +func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> { + %0 = "onnx.Identity"(%arg0) : (tensor<10x20x30x40xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_identity + // CHECK: return %arg0 : memref<10x20x30x40xf32> +}