From 477227a0ecd5ab8d092495ce761144c70fdcbeec Mon Sep 17 00:00:00 2001
From: Haruki Imai
Date: Tue, 4 Feb 2020 23:27:17 +0900
Subject: [PATCH] Added lowering of SignOp (#21)

* Support lowering of SignOp

* Fixed test code for SignOp with integer input

* Inserted Sign and Reciprocal in SharingWork.md (Reciprocal is for past commit 7e3f96e)

* Added test for SignOp

* Renamed minus_one -> minusOne

* Fixed test for SignOp
---
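Note (illustration only, not part of the applied patch): the lowering added
below maps onnx.Sign onto two compare/select pairs so that sign(0) is 0, as
the ONNX specification requires. A minimal NumPy sketch of the same decision
chain, assuming a float or signed-integer array x; the helper name
sign_reference is hypothetical:

    import numpy as np

    def sign_reference(x):
        # %Y = select(x > 0, 1, -1)
        y = np.where(x > 0, 1, -1)
        # ONNXSignOp(%X) = select(x == 0, 0, %Y), so sign(0) == 0
        return np.where(x == 0, 0, y).astype(x.dtype)

    assert (sign_reference(np.array([-2.0, 0.0, 3.5])) == [-1.0, 0.0, 1.0]).all()
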
 SharingWork.md                      |  2 +
 doc/gen_doc.py                      |  2 +-
 src/dialect/onnx/onnx_ops.cpp       |  8 ++++
 src/dialect/onnx/onnxop.inc         |  5 +--
 src/pass/lower_frontend_to_krnl.cpp | 60 ++++++++++++++++++++++++++++++
 src/pass/shape_inference_pass.cpp   |  1 +
 test/backend/test.py                |  3 ++
 test/mlir/onnx/onnx_lowering.mlir   | 50 +++++++++++++++++++++++++
 8 files changed, 127 insertions(+), 4 deletions(-)

diff --git a/SharingWork.md b/SharingWork.md
index 625db96..f542bf4 100644
--- a/SharingWork.md
+++ b/SharingWork.md
@@ -23,9 +23,11 @@ ONNX operations for which some work is needed.
 | Min | Tung | v | v | M |
 | Mul | Tung | v | v | M |
 | Or | Tung | v | v | M |
+| Reciprocal | Imai | v | v | |
 | Relu | Tung | v | v | |
 | Selu | Tung | v | v | |
 | Sigmoid | Tung | v | v | |
+| Sign | Imai | v | v | |
 | Sinh | Tung | v | v | |
 | Softmax | Tung | v | v | |
 | Sub | Tung | v | v | M |
diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index 428c360..e0c9bcc 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -45,7 +45,7 @@
 ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Elu', 'Selu',
     'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log',
     'Transpose', 'Softmax',
-    'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
+    'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
 CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
     'ReduceLogSumExp', 'ReduceSumSquare']
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 6a70b01..8d9b52b 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -190,6 +190,14 @@ void ONNXSqrtOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Sign
+/// Infer the output shape of the ONNXSignOp. This method is required by
+/// the shape inference interface.
+void ONNXSignOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index 02ce5e7..360ea1f 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -1620,7 +1620,7 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool",
   " ```"
   " pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i]"
   " ```"
-  " The output of each pooling window is maximum number of elements exclude pad."
+  " The output of each pooling window is maximum number of elements exclude pad. "
   " "
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
@@ -3040,7 +3040,7 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid",
 }
 
 def ONNXSignOp:ONNX_Op<"Sign",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sign operation";
   let description = [{
     "Calculate the sign of the given input tensor element-wise."
@@ -3576,4 +3576,3 @@ def ONNXXorOp:ONNX_Op<"Xor",
     AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_C);
 }
-
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index f05b6fb..7f13a0f 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -622,6 +622,65 @@ Value mapToLowerScalarOp(
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSignOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
+                                     ArrayRef<Value> operands,
+                                     ConversionPatternRewriter &rewriter) {
+
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  Type element_type = operands.front().getType();
+  // TODO: should unsigned int be supported separately?
+  if (element_type.isa<IntegerType>()) {
+    // %Y = SelectOp(CmpIOp(GT, %X, ConstantOp 0),
+    //               ConstantOp 1,
+    //               ConstantOp -1)
+    // ONNXSignOp(%X) = SelectOp(CmpIOp(EQ, %X, ConstantOp 0),
+    //                           ConstantOp 0,
+    //                           %Y)
+    auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
+    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
+    auto minusOne =
+        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
+    auto plusPredicate =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
+    auto plusSelect =
+        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
+    auto zeroPredicate =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
+    auto result =
+        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
+    return result;
+  } else if (element_type.isa<FloatType>()) {
+    // %Y = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
+    //               ConstantOp 1,
+    //               ConstantOp -1)
+    // ONNXSignOp(%X) = SelectOp(CmpFOp(OEQ, %X, ConstantOp 0),
+    //                           ConstantOp 0,
+    //                           %Y)
+    auto zero =
+        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
+    auto minusOne =
+        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
+    auto plusPredicate =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
+    auto plusSelect =
+        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
+    auto zeroPredicate =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
+    auto result =
+        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
+    return result;
+  } else {
+    emitError(loc, "unsupported element type");
+    return nullptr;
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXMaxOp
 //===----------------------------------------------------------------------===//
@@ -1697,6 +1756,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
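
Note on the TODO above (illustration only, not part of the applied patch):
the integer branch uses the signed predicates sgt/eq, which would misread
unsigned values with the high bit set as negative. If unsigned tensors were
handled separately, the -1 branch would disappear and the chain would
collapse to a single equality test, roughly as below; the helper name
sign_unsigned is hypothetical:

    def sign_unsigned(x: int) -> int:
        # Unsigned inputs have no -1 branch: sign(x) is 1 for x > 0, else 0.
        return 0 if x == 0 else 1
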
"onnx.Div" && diff --git a/test/backend/test.py b/test/backend/test.py index 86e2492..2387e15 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -201,6 +201,9 @@ test_to_enable = [ "test_transpose_all_permutations_3_cpu", "test_transpose_all_permutations_4_cpu", "test_transpose_all_permutations_5_cpu", + + # Sign Op: + "test_sign_cpu", ] # Extract name of all test cases. diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 0883aa8..c58dcd1 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -738,3 +738,53 @@ func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_identity // CHECK: return %arg0 : memref<10x20x30x40xf32> } + +func @test_sign_f(%arg0 : tensor) -> tensor<*xf32> { + %0 = "onnx.Sign"(%arg0) : (tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_sign_f + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32 + // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32 + // CHECK: [[MINUS_ONE:%.+]] = constant {{-1.+}} : f32 + // CHECK: [[GTZERO:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : f32 + // CHECK: [[EQZERO:%.+]] = cmpf "oeq", [[LOAD]], [[ZERO]] : f32 + // CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : f32 + // CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +} + +func @test_sign_i(%arg0 : tensor) -> tensor<*xi32> { + %0 = "onnx.Sign"(%arg0) : (tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_sign_i + // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref + // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref + // CHECK: [[ZERO:%.+]] = constant 0 : i32 + // CHECK: [[ONE:%.+]] = constant 1 : i32 + // CHECK: [[MINUS_ONE:%.+]] = constant -1 : i32 + // CHECK: [[GTZERO:%.+]] = cmpi "sgt", [[LOAD]], [[ZERO]] : i32 + // CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : i32 + // CHECK: [[EQZERO:%.+]] = cmpi "eq", [[LOAD]], [[ZERO]] : i32 + // CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : i32 + // CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref + // CHECK: return [[RES]] : memref +}