From 1c3176bf9fef3863f7b865269f0a48eba9b65146 Mon Sep 17 00:00:00 2001
From: TUNG LEDUC
Date: Fri, 6 Dec 2019 10:08:09 +0900
Subject: [PATCH] [MLIR] Lower ONNX element-wise unary ops: Exp, Tanh, Sinh,
 Cosh, Sigmoid (#389)

* Lower ExpOp
* Lower ONNXTanhOp
* Lower Exp, Tanh, Sinh, and Cosh
* Lower ONNX Sigmoid op
* Merge
* Specialize template lowerScalarOp
* Unify ONNXEWUnaryOpLowering and ONNXEWBinaryOpLowering
* Support multiple types
* Reformat the code
* Add test cases
* Reformat the code
* Change names
* Apply clang-format
* Update variable names
---
 src/compiler/dialect/onnx/gen_doc.py          |   4 +-
 src/compiler/dialect/onnx/onnx_ops.cpp        |  40 +++
 src/compiler/dialect/onnx/onnxop.inc          |  10 +-
 src/compiler/pass/lower_frontend_to_krnl.cpp  | 222 +++++++++++++---
 src/compiler/pass/shape_inference_pass.cpp    |  13 +-
 test/mlir/onnx/onnx_lowering.mlir             | 118 +++++++++
 .../mlir/onnx/onnx_lowering_with_dealloc.mlir | 241 ++++++++++++++++++
 7 files changed, 609 insertions(+), 39 deletions(-)

diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py
index fe428fb..b63af6e 100644
--- a/src/compiler/dialect/onnx/gen_doc.py
+++ b/src/compiler/dialect/onnx/gen_doc.py
@@ -263,7 +263,9 @@ def collect_types(schema, input) :
     return allowedTypeStr

 def gen_schema(schema) :
-    ShapeInferenceList=['Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'MatMul', 'Gemm']
+    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid',
+                        'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
+                        'MatMul', 'Gemm']
     CanonicalList=['Add', 'Identity']

     line_indent = '  '

diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index bd42ef0..a6299f3 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -38,6 +38,46 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
 //===----------------------------------------------------------------------===//
 // ONNX Operations

+//===----------------------------------------------------------------------===//
+// Exp
+/// Infer the output shape of the ONNXExpOp. This method is required by the
+/// shape inference interface.
+void ONNXExpOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Tanh
+/// Infer the output shape of the ONNXTanhOp. This method is required by the
+/// shape inference interface.
+void ONNXTanhOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Sinh
+/// Infer the output shape of the ONNXSinhOp. This method is required by the
+/// shape inference interface.
+void ONNXSinhOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Cosh
+/// Infer the output shape of the ONNXCoshOp. This method is required by the
+/// shape inference interface.
+void ONNXCoshOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Sigmoid
+/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
+/// shape inference interface.
+void ONNXSigmoidOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
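All five inferShapes() bodies above are identical: an element-wise unary op's
result carries exactly its operand's type, dynamic dimensions included. Purely
as a hypothetical sketch (not part of this patch; it assumes only what
onnx_ops.cpp already includes), the five methods could forward to one shared
helper:

    // Hypothetical helper, not in this patch: same-type shape inference
    // shared by the element-wise unary ops (Exp, Tanh, Sinh, Cosh, Sigmoid).
    static void inferUnarySameShape(mlir::Operation* op) {
      // Element-wise unary ops: result type == operand type.
      op->getResult(0)->setType(op->getOperand(0)->getType());
    }

    // Each interface method would then become a one-liner, e.g.:
    void ONNXExpOp::inferShapes() { inferUnarySameShape(getOperation()); }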
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index 261c52e..d44f312 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -371,7 +371,7 @@ def ONNXCosOp:ONNX_Op<"Cos",
 }

 def ONNXCoshOp:ONNX_Op<"Cosh",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Cosh operation";
   let description = [{
     "Calculates the hyperbolic cosine of the given input tensor element-wise."
@@ -567,7 +567,7 @@ def ONNXErfOp:ONNX_Op<"Erf",
 }

 def ONNXExpOp:ONNX_Op<"Exp",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Exp operation";
   let description = [{
     "Calculates the exponential of the given input tensor, element-wise."
@@ -2731,7 +2731,7 @@ def ONNXShrinkOp:ONNX_Op<"Shrink",
 }

 def ONNXSigmoidOp:ONNX_Op<"Sigmoid",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sigmoid operation";
   let description = [{
     "Sigmoid takes one input data (Tensor<T>) and produces one output data"
@@ -2764,7 +2764,7 @@ def ONNXSinOp:ONNX_Op<"Sin",
 }

 def ONNXSinhOp:ONNX_Op<"Sinh",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sinh operation";
   let description = [{
     "Calculates the hyperbolic sine of the given input tensor element-wise."
@@ -2994,7 +2994,7 @@ def ONNXTanOp:ONNX_Op<"Tan",
 }

 def ONNXTanhOp:ONNX_Op<"Tanh",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Tanh operation";
   let description = [{
     "Calculates the hyperbolic tangent of the given input tensor element-wise."

diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index f2f7d39..cc319f5 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -44,9 +44,8 @@ static MemRefType convertTensorToMemRef(TensorType type) {
 }

 /// Insert an allocation and deallocation for the given MemRefType.
-static Value* insertAllocAndDealloc(
-    MemRefType type, Location loc, PatternRewriter& rewriter,
-    bool insertDealloc, Value *oldMemRef = nullptr) {
+static Value* insertAllocAndDealloc(MemRefType type, Location loc,
+    PatternRewriter& rewriter, bool insertDealloc, Value* oldMemRef = nullptr) {
   // Put together alloc operands for any dynamic dimensions of the memref.
   AllocOp alloc;
   if (oldMemRef) {
@@ -77,7 +76,7 @@ static Value* insertAllocAndDealloc(
 // Determine if current function returns the result value of the
 // current op being lowered. If it does then dealloc should not be
 // inserted.
-static bool checkInsertDealloc(Operation *currentOp) {
+static bool checkInsertDealloc(Operation* currentOp) {
   auto parentBlock = currentOp->getBlock();

   bool insertDealloc = true;
@@ -87,7 +86,7 @@ static bool checkInsertDealloc(Operation *currentOp) {
     // If there is at least one result to investigate.
     if (currentOp->getNumResults() > 0) {
       auto result = currentOp->getResult(0);
-      for(auto operand : op.getOperands())
+      for (auto operand : op.getOperands())
         if (operand == result)
           insertDealloc = false;
     }
@@ -98,14 +97,166 @@ static bool checkInsertDealloc(Operation *currentOp) {

 namespace {

-//===----------------------------------------------------------------------===//
-// Element-wise binary ops lowering to Krnl dialect.
-//===----------------------------------------------------------------------===//
-template <typename BinaryOp, typename LoweredBinaryOp>
-struct ONNXEWBinaryOpLowering : public ConversionPattern {
-  ONNXEWBinaryOpLowering(MLIRContext* ctx)
-      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+template <typename Op>
+struct ScalarOp;
+
+template <>
+struct ScalarOp<ONNXAddOp> {
+  using FOp = AddFOp;
+  using IOp = AddIOp;
+};
+
+template <>
+struct ScalarOp<ONNXMulOp> {
+  using FOp = MulFOp;
+  using IOp = MulIOp;
+};
+
+template <>
+struct ScalarOp<ONNXDivOp> {
+  using FOp = DivFOp;
+  using IOp = DivISOp;
+};
+
+template <>
+struct ScalarOp<ONNXSubOp> {
+  using FOp = SubFOp;
+  using IOp = SubIOp;
+};
+
+template <>
+struct ScalarOp<ONNXAndOp> {
+  using FOp = AndOp;  // not used
+  using IOp = AndOp;
+};
+
+template <>
+struct ScalarOp<ONNXOrOp> {
+  using FOp = OrOp;  // not used
+  using IOp = OrOp;
+};
+
+template <>
+struct ScalarOp<ONNXXorOp> {
+  using FOp = XOrOp;  // not used
+  using IOp = XOrOp;
+};
+
+template <>
+struct ScalarOp<ONNXExpOp> {
+  using FOp = ExpOp;
+  using IOp = ExpOp;  // not used
+};
+
+template <typename Op>
+using ScalarFOp = typename ScalarOp<Op>::FOp;
+template <typename Op>
+using ScalarIOp = typename ScalarOp<Op>::IOp;
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+template <typename UnaryOp>
+Value* mapToLowerScalarOp(Location loc, ArrayRef<Type> result_types,
+    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+  /* Lower UnaryOp to Ops in the Standard dialect.
+   */
+  Type element_type = operands.front()->getType();
+  if (element_type.isa<IntegerType>()) {
+    return rewriter.create<ScalarIOp<UnaryOp>>(
+        loc, result_types, operands, mlir::None);
+  } else if (element_type.isa<FloatType>()) {
+    return rewriter.create<ScalarFOp<UnaryOp>>(
+        loc, result_types, operands, mlir::None);
+  } else {
+    return nullptr;
+  }
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXTanhOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
+  //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
+  Value* operand = operands[0];
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto negExp = rewriter.create<ExpOp>(loc, neg);
+  auto result =
+      rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
+          rewriter.create<AddFOp>(loc, exp, negExp));
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSinhOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
+  //                         ConstantOp 2)
+  Value* operand = operands[0];
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
+  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto negExp = rewriter.create<ExpOp>(loc, neg);
+  auto result = rewriter.create<DivFOp>(
+      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXCoshOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
+  //                         ConstantOp 2)
+  Value* operand = operands[0];
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
+  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto negExp = rewriter.create<ExpOp>(loc, neg);
+  auto result = rewriter.create<DivFOp>(
+      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSigmoidOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
+  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
+  Value* operand = operands[0];
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
+  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
+  auto negExp = rewriter.create<ExpOp>(loc, neg);
+  auto result = rewriter.create<DivFOp>(
+      loc, one, rewriter.create<AddFOp>(loc, one, negExp));
+  return result;
+}
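The four specializations above synthesize Tanh, Sinh, Cosh, and Sigmoid from
exp, subf, addf, and divf alone; negation is spelled subf(0, %X) rather than a
dedicated neg op. The following standalone program is an illustrative sanity
check of those identities only; it is not part of the patch and does not
depend on MLIR:

    #include <cassert>
    #include <cmath>

    int main() {
      for (float x : {-2.0f, -0.5f, 0.0f, 1.0f, 3.0f}) {
        float e = std::exp(x);          // ExpOp(%X)
        float ne = std::exp(0.0f - x);  // ExpOp(SubFOp(0, %X)), i.e. exp(-x)
        assert(std::abs((e - ne) / (e + ne) - std::tanh(x)) < 1e-4f);  // Tanh
        assert(std::abs((e - ne) / 2.0f - std::sinh(x)) < 1e-4f);      // Sinh
        assert(std::abs((e + ne) / 2.0f - std::cosh(x)) < 1e-4f);      // Cosh
        float sigmoid = 1.0f / (1.0f + ne);                            // Sigmoid
        // sigmoid(x) == (1 + tanh(x/2)) / 2
        assert(std::abs(sigmoid - 0.5f * (1.0f + std::tanh(0.5f * x))) < 1e-4f);
      }
      return 0;
    }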
+
+//===----------------------------------------------------------------------===//
+// Element-wise n-ary ops lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+template <typename ElementwiseNaryOp, unsigned numArgs>
+struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
+  ONNXElementwiseNaryOpLowering(MLIRContext* ctx)
+      : ConversionPattern(ElementwiseNaryOp::getOperationName(), 1, ctx) {}
   PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
       ConversionPatternRewriter& rewriter) const final {
     // TODO: Check that the types are valid.
@@ -123,12 +274,11 @@ struct ONNXEWBinaryOpLowering : public ConversionPattern {
     // dimensions with the result at this pre-optimization phase.
     // TODO: verify that dimensions match.
     // TODO: can the dimension of the result differ after optimizations?
-    Value *alloc;
+    Value* alloc;
     bool insertDealloc = checkInsertDealloc(op);

     if (hasAllConstantDimensions(memRefType))
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc);
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
       alloc = insertAllocAndDealloc(
           memRefType, loc, rewriter, insertDealloc, operands[0]);
@@ -190,11 +340,15 @@ struct ONNXEWBinaryOpLowering : public ConversionPattern {
     SmallVector<Value*, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
-    auto loadedFirstVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
-    auto loadedSecondVal = rewriter.create<LoadOp>(loc, operands[1], loopIVs);
-    auto loweredOpResult =
-        rewriter.create<LoweredBinaryOp>(loc, loadedFirstVal, loadedSecondVal);
+    SmallVector<Value*, 4> loadedVals;
+    for (unsigned i = 0; i < numArgs; i++) {
+      auto loadedVal = rewriter.create<LoadOp>(loc, operands[i], loopIVs);
+      loadedVals.push_back(loadedVal);
+    }
+
+    auto loweredOpResult = mapToLowerScalarOp<ElementwiseNaryOp>(
+        loc, memRefType.getElementType(), loadedVals, rewriter);

     // Store result in the resulting array.
     rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
@@ -205,6 +359,13 @@ struct ONNXEWBinaryOpLowering : public ConversionPattern {
     return matchSuccess();
   }
 };

+template <typename ElementwiseUnaryOp>
+using ONNXElementwiseUnaryOpLowering =
+    ONNXElementwiseNaryOpLowering<ElementwiseUnaryOp, 1>;
+template <typename ElementwiseBinaryOp>
+using ONNXElementwiseBinaryOpLowering =
+    ONNXElementwiseNaryOpLowering<ElementwiseBinaryOp, 2>;
+
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
 //===----------------------------------------------------------------------===//
@@ -285,15 +446,18 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       patterns, &getContext(), tensor_to_memref_converter);

   // Frontent operation lowering.
-  // TODO: Support 1-N mapping (e.g. different types of the lowered op)
-  patterns.insert<ONNXEWBinaryOpLowering<mlir::ONNXAddOp, AddFOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXMulOp, MulFOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXDivOp, DivFOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXSubOp, SubFOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXAndOp, AndOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXOrOp, OrOp>,
-      ONNXEWBinaryOpLowering<mlir::ONNXXorOp, XOrOp>>
-      (&getContext());
+  patterns.insert<ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXSubOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXAndOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXOrOp>,
+      ONNXElementwiseBinaryOpLowering<mlir::ONNXXorOp>>(&getContext());

   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -307,4 +471,4 @@ std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
 }

 static PassRegistration<FrontendToKrnlLoweringPass> pass(
-    "lower-frontend", "Lower frontend ops to Krnl dialect.");
+    "lower-frontend", "Lower frontend ops to Krnl dialect.");
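For reference, the rewritten pattern emits: result allocation, a krnl loop
nest over the output shape, one load per operand, the scalar lowering from
mapToLowerScalarOp, and a store. A plain C++ analogue of that loop nest for a
unary op over a ?x10 buffer, shown only to illustrate the intended semantics
(the function name here is invented; nothing below is in the patch):

    #include <cmath>
    #include <vector>

    // dim0 models the dynamic leading dimension; 10 is the static one used
    // in the tests below.
    std::vector<float> expLike(const std::vector<float>& input, size_t dim0) {
      const size_t dim1 = 10;
      std::vector<float> result(dim0 * dim1);  // alloc of the result memref
      for (size_t i = 0; i < dim0; ++i)        // krnl.iterate, loop #0
        for (size_t j = 0; j < dim1; ++j) {    // krnl.iterate, loop #1
          float x = input[i * dim1 + j];       // load %arg0[%arg1, %arg2]
          result[i * dim1 + j] = std::exp(x);  // scalar op + store
        }
      return result;
    }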
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index 0bbd9d6..8ca4de5 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -88,16 +88,21 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     // All operations which do not return a ranked tensor type have dynamic
     // shaped outputs. All those operation need to implement the inferShape()
     // method.
-    if (op->getName().getStringRef() != "onnx.Add" &&
+    if (op->getName().getStringRef() != "onnx.Exp" &&
+        op->getName().getStringRef() != "onnx.Tanh" &&
+        op->getName().getStringRef() != "onnx.Sinh" &&
+        op->getName().getStringRef() != "onnx.Cosh" &&
+        op->getName().getStringRef() != "onnx.Sigmoid" &&
         op->getName().getStringRef() != "onnx.Mul" &&
+        op->getName().getStringRef() != "onnx.Add" &&
         op->getName().getStringRef() != "onnx.Div" &&
         op->getName().getStringRef() != "onnx.Sub" &&
         op->getName().getStringRef() != "onnx.And" &&
         op->getName().getStringRef() != "onnx.Or" &&
         op->getName().getStringRef() != "onnx.Xor" &&
-      op->getName().getStringRef() != "onnx.MatMul" &&
-      op->getName().getStringRef() != "onnx.Gemm" &&
-      op->getName().getStringRef() != "onnx.FullGemm")
+        op->getName().getStringRef() != "onnx.MatMul" &&
+        op->getName().getStringRef() != "onnx.Gemm" &&
+        op->getName().getStringRef() != "onnx.FullGemm")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });

diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index d628a20..73a6896 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -139,3 +139,121 @@ func @test_xor(%arg0 : tensor<?x10xi1>, %arg1 : tensor<?x10xi1>) -> tensor<*xi1> {
   // CHECK: store [[XOR]], [[RES]][%arg2, %arg3] : memref<?x10xi1>
   // CHECK: return [[RES]] : memref<?x10xi1>
 }
+
+func @test_exp(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Exp"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_exp
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Tanh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_tanh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_sinh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sinh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sinh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[SINH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[SINH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Cosh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_cosh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[COSH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[COSH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sigmoid
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[ONE]], [[NEXP]] : f32
+  // CHECK: [[SIGMOID_RES:%.+]] = divf [[ONE]], [[DIVISOR]] : f32
+  // CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
\ No newline at end of file
diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
index 04dc616..c6bb8ef 100644
--- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
+++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
@@ -287,3 +287,244 @@ func @test_xor_xor(%arg0 : tensor<?x10xi1>, %arg1 : tensor<?x10xi1>) -> tensor<*xi1> {
   // CHECK: return [[RET_RES]] : memref<?x10xi1>
 }
+
+func @test_exp_exp(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Exp"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Exp"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_exp_exp
+  /// First Exp
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Exp
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Tanh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Tanh"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_tanh_tanh
+  /// First Tanh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Tanh
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_sinh_sinh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sinh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Sinh"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sinh_sinh
+  /// First Sinh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[SINH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[SINH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Sinh
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[SINH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[SINH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_cosh_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Cosh"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Cosh"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_cosh_cosh
+  /// First Cosh
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[COSH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[COSH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Cosh
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[TWO:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVIDEND:%.+]] = addf [[EXP]], [[NEXP]] : f32
+  // CHECK: [[COSH_RES:%.+]] = divf [[DIVIDEND]], [[TWO]] : f32
+  // CHECK: store [[COSH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
+
+func @test_sigmoid_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Sigmoid"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sigmoid_sigmoid
+  /// First Sigmoid
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[ONE]], [[NEXP]] : f32
+  // CHECK: [[SIGMOID_RES:%.+]] = divf [[ONE]], [[DIVISOR]] : f32
+  // CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Sigmoid
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
+  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
+  // CHECK: [[DIVISOR:%.+]] = addf [[ONE]], [[NEXP]] : f32
+  // CHECK: [[SIGMOID_RES:%.+]] = divf [[ONE]], [[DIVISOR]] : f32
+  // CHECK: store [[SIGMOID_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
\ No newline at end of file