From 3d4ad520115a0038f60ff80e7d290146c604948d Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Wed, 8 Jan 2020 12:11:21 +0900
Subject: [PATCH] Rewrite tanh using TanhOp, add log, cos

---
 src/dialect/onnx/gen_doc.py         |  2 +-
 src/dialect/onnx/onnx_ops.cpp       | 12 ++++++++
 src/dialect/onnx/onnxop.inc         |  4 +--
 src/pass/lower_frontend_to_krnl.cpp | 42 ++++++++++++-------------
 src/pass/shape_inference_pass.cpp   |  2 ++
 test/mlir/onnx/onnx_lowering.mlir   | 48 ++++++++++++++++++++++++-----
 6 files changed, 77 insertions(+), 33 deletions(-)

diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py
index 709055e..da8ec59 100644
--- a/src/dialect/onnx/gen_doc.py
+++ b/src/dialect/onnx/gen_doc.py
@@ -267,7 +267,7 @@ def gen_schema(schema) :
                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
                    'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-                   'Identity']
+                   'Identity', 'Cos', 'Log']
     CanonicalList=['Add', 'Identity']
     line_indent = '  '

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 4cb920c..902cbf6 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -91,6 +91,18 @@ void ONNXCoshOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Cos
+/// Infer the output shape of the ONNXCosOp. This method is required by the
+/// shape inference interface.
+void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
+//===----------------------------------------------------------------------===//
+// Log
+/// Infer the output shape of the ONNXLogOp. This method is required by the
+/// shape inference interface.
+void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
 //===----------------------------------------------------------------------===//
 // HardSigmoid
 /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index b0434ba..e3e6a94 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -361,7 +361,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
 }
 
 def ONNXCosOp:ONNX_Op<"Cos",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Cos operation";
   let description = [{
     "Calculates the cosine of the given input tensor, element-wise."
@@ -1216,7 +1216,7 @@ def ONNXLessOp:ONNX_Op<"Less",
 }
 
 def ONNXLogOp:ONNX_Op<"Log",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Log operation";
   let description = [{
     "Calculates the natural log of the given input tensor, element-wise."
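Context for the trait change above: DeclareOpInterfaceMethods<ShapeInferenceOpInterface> makes TableGen declare inferShapes() on each generated op class, and the one-line bodies added to onnx_ops.cpp define it; the shape-inference pass can then dispatch through the interface instead of naming each op. Below is a minimal sketch of that dispatch, assuming the interface's generated C++ class is named ShapeInference and that the header path matches onnx-mlir's layout at the time (neither name is taken from this patch):

#include "mlir/IR/Function.h"
#include "src/pass/shape_inference_interface.hpp" // assumed header declaring ShapeInference

// Run shape inference on every op in a function that implements the
// interface; ONNXCosOp and ONNXLogOp now qualify thanks to the trait above.
static void inferShapesInFunc(mlir::FuncOp f) {
  f.walk([](mlir::Operation *op) {
    if (auto shapeOp = llvm::dyn_cast<ShapeInference>(op))
      shapeOp.inferShapes(); // e.g. calls ONNXCosOp::inferShapes() from this patch
  });
}
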
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index d38954b..29697c9 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -286,6 +286,24 @@ struct ScalarOp<ONNXSumOp> {
   using IOp = AddIOp;
 };
 
+template <>
+struct ScalarOp<ONNXTanhOp> {
+  using FOp = TanhOp;
+  using IOp = TanhOp; // not used
+};
+
+template <>
+struct ScalarOp<ONNXCosOp> {
+  using FOp = CosOp;
+  using IOp = CosOp; // not used
+};
+
+template <>
+struct ScalarOp<ONNXLogOp> {
+  using FOp = LogOp;
+  using IOp = LogOp; // not used
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>
@@ -314,28 +332,6 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
   }
 }
 
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXTanhOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
-  //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-
-  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
-  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto negExp = rewriter.create<ExpOp>(loc, neg);
-  auto diff = rewriter.create<SubFOp>(loc, exp, negExp);
-  auto sum = rewriter.create<AddFOp>(loc, exp, negExp);
-  auto result = rewriter.create<DivFOp>(loc, diff, sum);
-
-  return result;
-}
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//
@@ -982,6 +978,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index d80e042..f54feb4 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -92,6 +92,8 @@ public:
         op->getName().getStringRef() != "onnx.Tanh" &&
         op->getName().getStringRef() != "onnx.Sinh" &&
         op->getName().getStringRef() != "onnx.Cosh" &&
+        op->getName().getStringRef() != "onnx.Cos" &&
+        op->getName().getStringRef() != "onnx.Log" &&
         op->getName().getStringRef() != "onnx.Sigmoid" &&
         op->getName().getStringRef() != "onnx.HardSigmoid" &&
         op->getName().getStringRef() != "onnx.Elu" &&
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 92d4a0f..123e6a1 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -159,14 +159,8 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
   // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
-  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
-  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
-  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
-  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
-  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
-  // CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
+  // CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
@@ -220,6 +214,44 @@ func @test_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_cos(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Cos"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_cos
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[COS:%.+]] = cos [[LOAD]] : f32
+  // CHECK: store [[COS]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_log(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Log"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_log
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[LOG:%.+]] = log [[LOAD]] : f32
+  // CHECK: store [[LOG]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
 func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
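Why the hand-written tanh lowering could be deleted: once ScalarOp<ONNXTanhOp> maps to the Standard-dialect TanhOp (and likewise Cos and Log to CosOp and LogOp), the generic unary lowering emits that op directly inside the Krnl loop nest, so the exp-based expansion above became dead code. A condensed sketch of that generic path follows; the signature is taken from the hunk header above, while the body is reconstructed for illustration, not verbatim from the file:

// Generic scalar lowering: pick the float or integer Standard op recorded in
// ScalarOp<UnaryOp> and emit it. ScalarFOp<ONNXTanhOp> resolves to TanhOp,
// ScalarFOp<ONNXCosOp> to CosOp, and ScalarFOp<ONNXLogOp> to LogOp.
template <typename UnaryOp>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  Type elementType = operands.front()->getType();
  if (elementType.isa<FloatType>())
    return rewriter.create<ScalarFOp<UnaryOp>>(op->getLoc(), result_types,
                                               operands, mlir::None);
  return rewriter.create<ScalarIOp<UnaryOp>>(op->getLoc(), result_types,
                                             operands, mlir::None);
}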