diff --git a/src/compiler/dialect/onnx/gen_doc.py b/src/compiler/dialect/onnx/gen_doc.py
index 54998a4..8d3e728 100644
--- a/src/compiler/dialect/onnx/gen_doc.py
+++ b/src/compiler/dialect/onnx/gen_doc.py
@@ -266,7 +266,7 @@ def gen_schema(schema) :
     ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                    'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                    'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-                   'Elu', 'Selu', 'HardSigmoid', 'Reshape']
+                   'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal']
     CanonicalList=['Add', 'Identity']
 
     line_indent = '  '
diff --git a/src/compiler/dialect/onnx/onnx_ops.cpp b/src/compiler/dialect/onnx/onnx_ops.cpp
index 00d53da..1fb5fea 100644
--- a/src/compiler/dialect/onnx/onnx_ops.cpp
+++ b/src/compiler/dialect/onnx/onnx_ops.cpp
@@ -114,6 +114,14 @@ void ONNXSeluOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Reciprocal
+/// Infer the output shape of the ONNXReciprocalOp. This method is required by
+/// the shape inference interface.
+void ONNXReciprocalOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
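Note on the hunk above: `inferShapes()` is not called directly; the shape-inference pass (see `shape_inference_pass.cpp` further down) walks each function and dispatches through the op interface that `Reciprocal` now implements. A rough sketch of that dispatch, with simplified names (`ShapeInference` is assumed to be the C++ interface generated from `ShapeInferenceOpInterface`; this is not the pass's exact code):

```cpp
#include "mlir/IR/Function.h"
#include "llvm/Support/Casting.h"

// Sketch: walk a function and run shape inference on every op that
// implements the interface. ONNXReciprocalOp now qualifies, so its
// inferShapes() above is reached through this dispatch.
void runShapeInferenceOn(mlir::FuncOp f) {
  f.walk([](mlir::Operation *op) {
    if (auto shapeOp = llvm::dyn_cast<ShapeInference>(op))
      shapeOp.inferShapes(); // for onnx.Reciprocal: copies operand type
  });
}
```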
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index 1cdef19..1ac969e 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -1821,6 +1821,8 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv",
   "and computes the quantized output. Each scale and zero-point pair must have same shape."
   "It means they must be either scalars (per tensor) or 1-D tensors (per output channel)."
   "Each input or output and its related zero point must have same type."
+  "When bias is present it must be quantized using scale = input scale * weight scale and "
+  "zero point as 0."
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
@@ -2023,7 +2025,7 @@ def ONNXRangeOp:ONNX_Op<"Range",
 }
 
 def ONNXReciprocalOp:ONNX_Op<"Reciprocal",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Reciprocal operation";
   let description = [{
   "Reciprocal takes one input data (Tensor<T>) and produces one output data"
diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp
index 30f132d..593bcbb 100644
--- a/src/compiler/pass/lower_frontend_to_krnl.cpp
+++ b/src/compiler/pass/lower_frontend_to_krnl.cpp
@@ -425,6 +425,22 @@ Value* mapToLowerScalarOp(Operation* op,
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXReciprocalOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
+  auto loc = op->getLoc();
+  Value* operand = operands[0];
+
+  auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
+  auto result = rewriter.create<DivFOp>(loc, one, operand);
+
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXMaxOp
 //===----------------------------------------------------------------------===//
@@ -815,6 +831,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
      ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index f44bed9..acddc98 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -98,6 +98,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         op->getName().getStringRef() != "onnx.Relu" &&
         op->getName().getStringRef() != "onnx.LeakyRelu" &&
         op->getName().getStringRef() != "onnx.Selu" &&
+        op->getName().getStringRef() != "onnx.Reciprocal" &&
         op->getName().getStringRef() != "onnx.Mul" &&
         op->getName().getStringRef() != "onnx.Add" &&
         op->getName().getStringRef() != "onnx.Div" &&
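A caveat on the scalar lowering above: the one-constant is built with `getF32FloatAttr`, so the emitted `divf` assumes f32 elements, which is all the tests below exercise. If other float widths ever reach this path, a type-aware variant along the following lines could be used instead. This is a sketch under that assumption, not part of the patch; it assumes the scalar operand's type is a `FloatType`:

```cpp
// Sketch: derive the constant's type from the scalar operand itself
// instead of hardcoding f32 (assumes a float element type, f16/f32/f64).
auto one = rewriter.create<ConstantOp>(
    loc, rewriter.getFloatAttr(operand->getType(), 1.0));
auto result = rewriter.create<DivFOp>(loc, one, operand);
```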
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 17d3609..9cff02c 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -475,3 +475,23 @@ func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
+
+func @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reciprocal
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
+  // CHECK: store [[RECIPROCAL_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
index 749829e..cbd4d39 100644
--- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
+++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir
@@ -910,3 +910,44 @@ func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 
   // CHECK: return [[RET_RES]] : memref<?x10xf32>
 }
+
+func @test_reciprocal_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  %1 = "onnx.Reciprocal"(%0) : (tensor<*xf32>) -> tensor<*xf32>
+  "std.return"(%1) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reciprocal_reciprocal
+  /// First Reciprocal
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
+  // CHECK: store [[RECIPROCAL_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Second Reciprocal
+  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
+  // CHECK: store [[RECIPROCAL_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
+
+  /// Dealloc of first result.
+  // CHECK: dealloc [[RES]] : memref<?x10xf32>
+  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
+
+  // CHECK: return [[RET_RES]] : memref<?x10xf32>
+}
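For readers decoding the CHECK lines: for a `?x10` f32 input, the lowered krnl loops allocate the result from the input's dynamic dimension and compute `y[i][j] = 1.0f / x[i][j]`. A standalone C++ analogue of that loop nest, purely illustrative (the dynamic `?` dimension is fixed to 2 here; nothing below is generated by the compiler):

```cpp
#include <cstdio>
#include <vector>

int main() {
  const size_t rows = 2, cols = 10; // '?' x 10, with '?' fixed for the demo
  std::vector<float> x(rows * cols, 4.0f), y(rows * cols);
  for (size_t i = 0; i < rows; ++i)           // krnl.iterate over dim 0
    for (size_t j = 0; j < cols; ++j)         // krnl.iterate over dim 1
      y[i * cols + j] = 1.0f / x[i * cols + j]; // divf(1.0, load x[i, j])
  std::printf("y[0][0] = %f\n", y[0]);        // prints 0.250000
  return 0;
}
```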