Added lowering of SignOp (#21)

* Support lowering of SignOp

* Fixed test code for signop of integer input

* Inserted Sigh and Reciprocal in SharingWork.md (Reciprocal is for past commit 7e3f96e)

* Added test for Sign Op

* Fixed minus_one -> minusOne

* Fixed test for signop
This commit is contained in:
Haruki Imai 2020-02-04 23:27:17 +09:00 committed by GitHub
parent 87aa72764f
commit 477227a0ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 126 additions and 4 deletions

View File

@ -23,9 +23,11 @@ ONNX operations for which some work is needed.
| Min | Tung | v | v | M | | Min | Tung | v | v | M |
| Mul | Tung | v | v | M | | Mul | Tung | v | v | M |
| Or | Tung | v | v | M | | Or | Tung | v | v | M |
| Reciprocal | Imai | v | v | |
| Relu | Tung | v | v | | | Relu | Tung | v | v | |
| Selu | Tung | v | v | | | Selu | Tung | v | v | |
| Sigmoid | Tung | v | v | | | Sigmoid | Tung | v | v | |
| Sign | Imai | v | v | |
| Sinh | Tung | v | v | | | Sinh | Tung | v | v | |
| Softmax | Tung | v | v | | | Softmax | Tung | v | v | |
| Sub | Tung | v | v | M | | Sub | Tung | v | v | M |

View File

@ -45,7 +45,7 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze'] 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
'ReduceLogSumExp', 'ReduceSumSquare'] 'ReduceLogSumExp', 'ReduceSumSquare']

View File

@ -190,6 +190,14 @@ void ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the

View File

@ -3040,7 +3040,7 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid",
} }
def ONNXSignOp:ONNX_Op<"Sign", def ONNXSignOp:ONNX_Op<"Sign",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Sign operation"; let summary = "ONNX Sign operation";
let description = [{ let description = [{
"Calculate the sign of the given input tensor element-wise." "Calculate the sign of the given input tensor element-wise."
@ -3576,4 +3576,3 @@ def ONNXXorOp:ONNX_Op<"Xor",
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_C); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_C);
} }

View File

@ -622,6 +622,64 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
return result; return result;
} }
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) {
auto loc = op->getLoc();
Value operand = operands[0];
Type element_type = operands.front().getType();
// TODO: unsigned int should be supported separately?
if (element_type.isa<IntegerType>()) {
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
// ConstantOp 1,
// COnstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
auto minusOne =
rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
auto plusPredicate =
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
auto plusSelect =
rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
auto zeroPredicate =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
auto result =
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
return result;
} else if (element_type.isa<FloatType>()) {
// %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
// ConstantOp 1,
// ConstantOp -1)
// ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
// ConstantOp 0,
// %Y)
auto zero =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
auto minusOne =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
auto plusPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
auto plusSelect =
rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
auto zeroPredicate =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
auto result =
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
return result;
} else {
emitError(loc, "unsupported element type");
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXMaxOp // Scalar unary ops for lowering ONNXMaxOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1697,6 +1755,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>, ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,

View File

@ -103,6 +103,7 @@ public:
op->getName().getStringRef() != "onnx.Reciprocal" && op->getName().getStringRef() != "onnx.Reciprocal" &&
op->getName().getStringRef() != "onnx.Softplus" && op->getName().getStringRef() != "onnx.Softplus" &&
op->getName().getStringRef() != "onnx.Softsign" && op->getName().getStringRef() != "onnx.Softsign" &&
op->getName().getStringRef() != "onnx.Sign" &&
op->getName().getStringRef() != "onnx.Mul" && op->getName().getStringRef() != "onnx.Mul" &&
op->getName().getStringRef() != "onnx.Add" && op->getName().getStringRef() != "onnx.Add" &&
op->getName().getStringRef() != "onnx.Div" && op->getName().getStringRef() != "onnx.Div" &&

View File

@ -201,6 +201,9 @@ test_to_enable = [
"test_transpose_all_permutations_3_cpu", "test_transpose_all_permutations_3_cpu",
"test_transpose_all_permutations_4_cpu", "test_transpose_all_permutations_4_cpu",
"test_transpose_all_permutations_5_cpu", "test_transpose_all_permutations_5_cpu",
# Sign Op:
"test_sign_cpu",
] ]
# Extract name of all test cases. # Extract name of all test cases.

View File

@ -738,3 +738,53 @@ func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
// CHECK-LABEL: test_identity // CHECK-LABEL: test_identity
// CHECK: return %arg0 : memref<10x20x30x40xf32> // CHECK: return %arg0 : memref<10x20x30x40xf32>
} }
func @test_sign_f(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sign"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_sign_f
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[MINUS_ONE:%.+]] = constant {{-1.+}} : f32
// CHECK: [[GTZERO:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : f32
// CHECK: [[EQZERO:%.+]] = cmpf "oeq", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : f32
// CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32>
}
func @test_sign_i(%arg0 : tensor<?x10xi32>) -> tensor<*xi32> {
%0 = "onnx.Sign"(%arg0) : (tensor<?x10xi32>) -> tensor<*xi32>
"std.return"(%0) : (tensor<*xi32>) -> ()
// CHECK-LABEL: test_sign_i
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xi32>
// CHECK: [[ZERO:%.+]] = constant 0 : i32
// CHECK: [[ONE:%.+]] = constant 1 : i32
// CHECK: [[MINUS_ONE:%.+]] = constant -1 : i32
// CHECK: [[GTZERO:%.+]] = cmpi "sgt", [[LOAD]], [[ZERO]] : i32
// CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : i32
// CHECK: [[EQZERO:%.+]] = cmpi "eq", [[LOAD]], [[ZERO]] : i32
// CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : i32
// CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x10xi32>
// CHECK: return [[RES]] : memref<?x10xi32>
}