[MLIR] Add support for Reciprocal (#397)
* Added support for Reciprocal * Fixed format
This commit is contained in:
		
							parent
							
								
									3e7b8465e9
								
							
						
					
					
						commit
						7e3f96e642
					
				| 
						 | 
				
			
			@ -266,7 +266,7 @@ def gen_schema(schema) :
 | 
			
		|||
    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
 | 
			
		||||
                        'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
 | 
			
		||||
                        'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
 | 
			
		||||
                        'Elu', 'Selu', 'HardSigmoid', 'Reshape']
 | 
			
		||||
                        'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal']
 | 
			
		||||
    CanonicalList=['Add', 'Identity']
 | 
			
		||||
    line_indent = '  '
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -114,6 +114,14 @@ void ONNXSeluOp::inferShapes() {
 | 
			
		|||
  getResult()->setType(getOperand()->getType());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Reciprocal
 | 
			
		||||
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
 | 
			
		||||
/// the shape inference interface.
 | 
			
		||||
void ONNXReciprocalOp::inferShapes() {
 | 
			
		||||
  getResult()->setType(getOperand()->getType());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Add
 | 
			
		||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1821,6 +1821,8 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv",
 | 
			
		|||
    "and computes the quantized output. Each scale and zero-point pair must have same shape."
 | 
			
		||||
    "It means they must be either scalars (per tensor) or 1-D tensors (per output channel)."
 | 
			
		||||
    "Each input or output and its related zero point must have same type."
 | 
			
		||||
    "When bias is present it must be quantized using scale = input scale * weight scale and "
 | 
			
		||||
    "zero point as 0."
 | 
			
		||||
  }];
 | 
			
		||||
  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
 | 
			
		||||
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 | 
			
		||||
| 
						 | 
				
			
			@ -2023,7 +2025,7 @@ def ONNXRangeOp:ONNX_Op<"Range",
 | 
			
		|||
}
 | 
			
		||||
 | 
			
		||||
def ONNXReciprocalOp:ONNX_Op<"Reciprocal", 
 | 
			
		||||
    [NoSideEffect]> {
 | 
			
		||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
			
		||||
  let summary = "ONNX Reciprocal operation";
 | 
			
		||||
  let description = [{
 | 
			
		||||
    "Reciprocal takes one input data (Tensor<T>) and produces one output data"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -425,6 +425,22 @@ Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
 | 
			
		|||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Scalar unary ops for lowering ONNXReciprocalOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
template <>
 | 
			
		||||
Value* mapToLowerScalarOp<ONNXReciprocalOp>(Operation* op, ArrayRef<Type> result_types,
 | 
			
		||||
    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
 | 
			
		||||
  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 | 
			
		||||
  auto loc = op->getLoc();
 | 
			
		||||
  Value* operand = operands[0];
 | 
			
		||||
 | 
			
		||||
  auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
 | 
			
		||||
  auto result = rewriter.create<DivFOp>(loc, one, operand);
 | 
			
		||||
 | 
			
		||||
  return result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
// Scalar unary ops for lowering ONNXMaxOp
 | 
			
		||||
//===----------------------------------------------------------------------===//
 | 
			
		||||
| 
						 | 
				
			
			@ -815,6 +831,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
 | 
			
		|||
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
 | 
			
		||||
      ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
 | 
			
		||||
      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
 | 
			
		||||
      ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
 | 
			
		||||
      ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
 | 
			
		||||
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
 | 
			
		||||
      ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -98,6 +98,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
 | 
			
		|||
        op->getName().getStringRef() != "onnx.Relu" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.LeakyRelu" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Selu" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Reciprocal" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Mul" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Add" &&
 | 
			
		||||
        op->getName().getStringRef() != "onnx.Div" &&
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -475,3 +475,23 @@ func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
			
		|||
  // CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
  // CHECK: return [[RES]] : memref<?x10xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
 | 
			
		||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
			
		||||
 | 
			
		||||
  // CHECK-LABEL: test_reciprocal
 | 
			
		||||
  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
 | 
			
		||||
  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  {
 | 
			
		||||
  // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
 | 
			
		||||
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
 | 
			
		||||
  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
 | 
			
		||||
  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
 | 
			
		||||
  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
 | 
			
		||||
  // CHECK: store [[RECIPROCAL_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
  // CHECK: return [[RES]] : memref<?x10xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -910,3 +910,44 @@ func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
			
		|||
 | 
			
		||||
  // CHECK: return [[RET_RES]] : memref<?x10xf32>
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func @test_reciprocal_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
			
		||||
  %0 = "onnx.Reciprocal"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
 | 
			
		||||
  %1 = "onnx.Reciprocal"(%0) : (tensor<*xf32>) -> tensor<*xf32>
 | 
			
		||||
  "std.return"(%1) : (tensor<*xf32>) -> ()
 | 
			
		||||
 | 
			
		||||
  // CHECK-LABEL: test_reciprocal_reciprocal
 | 
			
		||||
  /// First Reciprocal
 | 
			
		||||
  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
 | 
			
		||||
  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  {
 | 
			
		||||
  // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
 | 
			
		||||
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
 | 
			
		||||
  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
 | 
			
		||||
  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
 | 
			
		||||
  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
 | 
			
		||||
  // CHECK: store [[RECIPROCAL_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
 | 
			
		||||
  /// Second Reciprocal
 | 
			
		||||
  // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
 | 
			
		||||
  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops  {
 | 
			
		||||
  // CHECK:   krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
 | 
			
		||||
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
 | 
			
		||||
  // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
 | 
			
		||||
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
 | 
			
		||||
  // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
 | 
			
		||||
  // CHECK: [[RECIPROCAL_RES:%.+]] = divf [[ONE]], [[LOAD]] : f32
 | 
			
		||||
  // CHECK: store [[RECIPROCAL_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
 | 
			
		||||
 | 
			
		||||
  /// Dealloc of first result.
 | 
			
		||||
  // CHECK: dealloc [[RES]] : memref<?x10xf32>
 | 
			
		||||
  // CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>
 | 
			
		||||
 | 
			
		||||
  // CHECK: return [[RET_RES]] : memref<?x10xf32>
 | 
			
		||||
}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue