diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py
index b63af6e..3e22199 100644
--- a/src/compiler/dialect/onnx/gen_doc.py
+++ b/src/compiler/dialect/onnx/gen_doc.py
@@ -263,7 +263,7 @@ def collect_types(schema, input) :
     return allowedTypeStr
 
 def gen_schema(schema) :
-    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid',
+    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
         'MatMul', 'Gemm']
     CanonicalList=['Add', 'Identity']
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index a6299f3..0650a24 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -78,6 +78,14 @@ void ONNXSigmoidOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Relu
+/// Infer the output shape of the ONNXReluOp. This method is required by the
+/// shape inference interface.
+void ONNXReluOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index d44f312..6aa42ae 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -2185,7 +2185,7 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare",
 }
 
 def ONNXReluOp:ONNX_Op<"Relu",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Relu operation";
   let description = [{
     "Relu takes one input data (Tensor<T>) and produces one output data"
diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index 676ac20..6e28d19 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -251,6 +251,23 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXReluOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXReluOp>(Location loc, ArrayRef<Type> result_types,
+    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
+  //                           ConstantOp 0,
+  //                           %X)
+  Value* operand = operands[0];
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto lessThanZero =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
+  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Element-wise n-ary ops lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
@@ -452,6 +469,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
                   ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
                   ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
                   ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index 8ca4de5..138a793 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -89,20 +89,21 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     // shaped outputs. All those operation need to implement the inferShape()
     // method.
     if (op->getName().getStringRef() != "onnx.Exp" &&
-      op->getName().getStringRef() != "onnx.Tanh" &&
-      op->getName().getStringRef() != "onnx.Sinh" &&
-      op->getName().getStringRef() != "onnx.Cosh" &&
-      op->getName().getStringRef() != "onnx.Sigmoid" &&
-      op->getName().getStringRef() != "onnx.Mul" &&
-      op->getName().getStringRef() != "onnx.Add" &&
-      op->getName().getStringRef() != "onnx.Div" &&
-      op->getName().getStringRef() != "onnx.Sub" &&
-      op->getName().getStringRef() != "onnx.And" &&
-      op->getName().getStringRef() != "onnx.Or" &&
-      op->getName().getStringRef() != "onnx.Xor" &&
-      op->getName().getStringRef() != "onnx.MatMul" &&
-      op->getName().getStringRef() != "onnx.Gemm" &&
-      op->getName().getStringRef() != "onnx.FullGemm")
+        op->getName().getStringRef() != "onnx.Tanh" &&
+        op->getName().getStringRef() != "onnx.Sinh" &&
+        op->getName().getStringRef() != "onnx.Cosh" &&
+        op->getName().getStringRef() != "onnx.Sigmoid" &&
+        op->getName().getStringRef() != "onnx.Relu" &&
+        op->getName().getStringRef() != "onnx.Mul" &&
+        op->getName().getStringRef() != "onnx.Add" &&
+        op->getName().getStringRef() != "onnx.Div" &&
+        op->getName().getStringRef() != "onnx.Sub" &&
+        op->getName().getStringRef() != "onnx.And" &&
+        op->getName().getStringRef() != "onnx.Or" &&
+        op->getName().getStringRef() != "onnx.Xor" &&
+        op->getName().getStringRef() != "onnx.MatMul" &&
+        op->getName().getStringRef() != "onnx.Gemm" &&
+        op->getName().getStringRef() != "onnx.FullGemm")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>(); });
@@ -118,4 +119,4 @@ std::unique_ptr<mlir::Pass> mlir::createShapeInferencePass() {
 }
 
 static PassRegistration<ShapeInferencePass> pass(
-    "shape-inference", "Shape inference for frontend dialects.");
+    "shape-inference", "Shape inference for frontend dialects.");
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 73a6896..7f1df67 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -256,4 +256,25 @@ func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: [[SIGMOID_RES:%.+]] = divf [[ONE]], [[DIVISOR]] : f32
   // CHECK: store [[SIGMOID_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
-}
\ No newline at end of file
+}
+
+func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_relu
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
+  // CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
index c6bb8ef..6310516 100644
--- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
+++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
@@ -527,4 +527,47 @@ func @test_sigmoid_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
 
   // CHECK: return [[RET_RES]] : memref<?x10xf32>
-}
\ No newline at end of file
+}
+
+func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Relu"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Relu"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_relu_relu
+  /// First Relu
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
+  // CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Relu
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[LTZERO:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[RELU_RES:%.+]] = select [[LTZERO]], [[ZERO]], [[LOAD]] : f32
+  // CHECK: store [[RELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}