From 5357fc1421333391522fe694612bacd3e00da953 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 27 Feb 2020 02:03:24 +0900 Subject: [PATCH] Use SqrtOp in Standard dialect (#108) Co-authored-by: Gheorghe-Teodor Bercea --- .../onnx_to_krnl/math/elementwise.cpp | 4 +- .../onnx_to_krnl/nn/normalization.cpp | 3 +- src/dialect/krnl/krnl_ops.td | 14 ---- src/transform/lower_krnl.cpp | 1 - src/transform/lower_to_llvm.cpp | 65 +------------------ test/mlir/krnl/sqrt.mlir | 24 ------- test/mlir/onnx/onnx_lowering.mlir | 6 +- 7 files changed, 8 insertions(+), 109 deletions(-) delete mode 100644 test/mlir/krnl/sqrt.mlir diff --git a/src/conversion/onnx_to_krnl/math/elementwise.cpp b/src/conversion/onnx_to_krnl/math/elementwise.cpp index b397281..55d4cda 100644 --- a/src/conversion/onnx_to_krnl/math/elementwise.cpp +++ b/src/conversion/onnx_to_krnl/math/elementwise.cpp @@ -86,8 +86,8 @@ struct ScalarOp { template <> struct ScalarOp { - using FOp = KrnlSqrtOp; - using IOp = KrnlSqrtOp; // not use + using FOp = SqrtOp; + using IOp = SqrtOp; // not use }; //===----------------------------------------------------------------------===// diff --git a/src/conversion/onnx_to_krnl/nn/normalization.cpp b/src/conversion/onnx_to_krnl/nn/normalization.cpp index d151f0a..24cfe41 100644 --- a/src/conversion/onnx_to_krnl/nn/normalization.cpp +++ b/src/conversion/onnx_to_krnl/nn/normalization.cpp @@ -123,8 +123,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { auto dividend = rewriter.create(loc, xVal, meanVal); auto adjustedVarianceVal = rewriter.create(loc, varianceVal, epsilon); - auto divisor = rewriter.create(loc, memRefType.getElementType(), - adjustedVarianceVal); + auto divisor = rewriter.create(loc, adjustedVarianceVal); auto normVal = rewriter.create(loc, dividend, divisor); // scale and shift auto scaleNormVal = rewriter.create(loc, scaleVal, normVal); diff --git a/src/dialect/krnl/krnl_ops.td b/src/dialect/krnl/krnl_ops.td index 1649146..bb894f7 100644 --- a/src/dialect/krnl/krnl_ops.td +++ b/src/dialect/krnl/krnl_ops.td @@ -190,17 +190,3 @@ def KrnlMemcpyOp : Op { let parser = ?; let printer = ?; } - -def KrnlSqrtOp : Op { - let summary = "Krnl sqrt operation"; - let description = [{ - "The `sqrt` computes the square root value. It takes one operand and returns - one result with the same type." - }]; - - let arguments = (ins FloatLike:$operand); - let results = (outs FloatLike); - - let parser = ?; - let printer = ?; -} diff --git a/src/transform/lower_krnl.cpp b/src/transform/lower_krnl.cpp index 193e1bf..73d0e42 100644 --- a/src/transform/lower_krnl.cpp +++ b/src/transform/lower_krnl.cpp @@ -144,7 +144,6 @@ void KrnlToAffineLoweringPass::runOnFunction() { target.addIllegalDialect(); target.addLegalOp(); target.addLegalOp(); - target.addLegalOp(); OwningRewritePatternList patterns; patterns.insert operands, - ConversionPatternRewriter &rewriter) const override { - OperandAdaptor adaptor(operands); - LLVM::LLVMType operandType = - adaptor.operand().getType().dyn_cast_or_null(); - - if (!operandType) - return matchFailure(); - - std::string functionName; - if (operandType.isFloatTy()) - functionName = "llvm.sqrt.f32"; - else if (operandType.isDoubleTy()) - functionName = "llvm.sqrt.f64"; - else - assert(false && "Unsupported operand type."); - - // Get a symbol reference to the sqrt function, inserting it if necessary. - ModuleOp parentModule = op->getParentOfType(); - auto sqrtRef = - getOrInsertSqrt(rewriter, parentModule, functionName, operandType); - - // Sqrt call - rewriter.replaceOpWithNewOp(op, operandType, sqrtRef, - adaptor.operand()); - - return matchSuccess(); - } - -private: - /// Return a symbol reference to the sqrt function, inserting it into the - /// module if necessary. - static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter, - ModuleOp module, std::string fnName, - LLVM::LLVMType operandType) { - auto *context = module.getContext(); - if (module.lookupSymbol(fnName)) - return SymbolRefAttr::get(fnName, context); - // Create a function declaration for sqrt, the signature is: - // * `float (float)` - auto llvmFnType = - LLVM::LLVMType::getFunctionTy(operandType, operandType, false); - - // Insert the sqrt function into the body of the parent module. - PatternRewriter::InsertionGuard insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), fnName, llvmFnType); - return SymbolRefAttr::get(fnName, context); - } -}; } // end namespace //===----------------------------------------------------------------------===// @@ -572,8 +511,8 @@ void KrnlToLLVMLoweringPass::runOnModule() { /*emitCWrapper=*/true); // Lower from the `krnl` dialect i.e. the Reshape operation. - patterns.insert(&getContext()); + patterns.insert( + &getContext()); // We want to completely lower to LLVM, so we use a `FullConversion`. This // ensures that only legal operations will remain after the conversion. diff --git a/test/mlir/krnl/sqrt.mlir b/test/mlir/krnl/sqrt.mlir deleted file mode 100644 index bc91bc4..0000000 --- a/test/mlir/krnl/sqrt.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: onnf-opt --shape-inference --lower-all-llvm %s -split-input-file | FileCheck %s -module { - func @test_sqrt_32(%arg0 : f32) -> f32 { - %0 = "krnl.sqrt"(%arg0) : (f32) -> f32 - "std.return"(%0) : (f32) -> () - - // CHECK: llvm.func @llvm.sqrt.f32(!llvm.float) -> !llvm.float - // CHECK-NEXT: llvm.func @test_sqrt_32(%arg0: !llvm.float) -> !llvm.float { - // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f32(%arg0) : (!llvm.float) -> !llvm.float - // CHECK-NEXT: llvm.return [[RES]] : !llvm.float - } -} - -module{ - func @test_sqrt_64(%arg0 : f64) -> f64 { - %0 = "krnl.sqrt"(%arg0) : (f64) -> f64 - "std.return"(%0) : (f64) -> () - - // CHECK: llvm.func @llvm.sqrt.f64(!llvm.double) -> !llvm.double - // CHECK-NEXT: llvm.func @test_sqrt_64(%arg0: !llvm.double) -> !llvm.double { - // CHECK-NEXT: [[RES:%.+]] = llvm.call @llvm.sqrt.f64(%arg0) : (!llvm.double) -> !llvm.double - // CHECK-NEXT: llvm.return [[RES]] : !llvm.double - } -} diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index c35536d..e0dfa19 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -820,7 +820,7 @@ func @test_sqrt(%arg0 : tensor) -> tensor<*xf32> { // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref - // CHECK: [[SQRT:%.+]] = "krnl.sqrt"([[LOAD]]) : (f32) -> f32 + // CHECK: [[SQRT:%.+]] = sqrt [[LOAD]] : f32 // CHECK: store [[SQRT]], [[RES]][%arg1, %arg2] : memref // CHECK: return [[RES]] : memref } @@ -1305,7 +1305,7 @@ func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32 // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32> // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32 // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32 - // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32 + // CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32 // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32 // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32 // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32 @@ -1335,7 +1335,7 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32> // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32 // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32 - // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32 + // CHECK: [[DIVISOR:%.+]] = sqrt [[ADJUSTED_VARIANCE]] : f32 // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32 // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32 // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32