Added lowering of SignOp (#21)
* Support lowering of SignOp
* Fixed test code for signop of integer input
* Inserted Sigh and Reciprocal in SharingWork.md (Reciprocal is for past commit 7e3f96e
)
* Added test for Sign Op
* Fixed minus_one -> minusOne
* Fixed test for signop
This commit is contained in:
parent
87aa72764f
commit
477227a0ec
|
@ -23,9 +23,11 @@ ONNX operations for which some work is needed.
|
||||||
| Min | Tung | v | v | M |
|
| Min | Tung | v | v | M |
|
||||||
| Mul | Tung | v | v | M |
|
| Mul | Tung | v | v | M |
|
||||||
| Or | Tung | v | v | M |
|
| Or | Tung | v | v | M |
|
||||||
|
| Reciprocal | Imai | v | v | |
|
||||||
| Relu | Tung | v | v | |
|
| Relu | Tung | v | v | |
|
||||||
| Selu | Tung | v | v | |
|
| Selu | Tung | v | v | |
|
||||||
| Sigmoid | Tung | v | v | |
|
| Sigmoid | Tung | v | v | |
|
||||||
|
| Sign | Imai | v | v | |
|
||||||
| Sinh | Tung | v | v | |
|
| Sinh | Tung | v | v | |
|
||||||
| Softmax | Tung | v | v | |
|
| Softmax | Tung | v | v | |
|
||||||
| Sub | Tung | v | v | M |
|
| Sub | Tung | v | v | M |
|
||||||
|
|
|
@ -45,7 +45,7 @@ ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
|
||||||
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze']
|
'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign']
|
||||||
|
|
||||||
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
CanonicalList=['Add', 'Identity', 'ReduceL1', 'ReduceL2', 'ReduceLogSum',
|
||||||
'ReduceLogSumExp', 'ReduceSumSquare']
|
'ReduceLogSumExp', 'ReduceSumSquare']
|
||||||
|
|
|
@ -190,6 +190,14 @@ void ONNXSqrtOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Sign
|
||||||
|
/// Infer the output shape of the ONNXSignOp. This method is required by
|
||||||
|
/// the shape inference interface.
|
||||||
|
void ONNXSignOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
|
|
|
@ -3040,7 +3040,7 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXSignOp:ONNX_Op<"Sign",
|
def ONNXSignOp:ONNX_Op<"Sign",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Sign operation";
|
let summary = "ONNX Sign operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculate the sign of the given input tensor element-wise."
|
"Calculate the sign of the given input tensor element-wise."
|
||||||
|
@ -3576,4 +3576,3 @@ def ONNXXorOp:ONNX_Op<"Xor",
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_C);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_C);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -622,6 +622,64 @@ Value mapToLowerScalarOp<ONNXSoftsignOp>(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Scalar unary ops for lowering ONNXSignOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
template <>
|
||||||
|
Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
Value operand = operands[0];
|
||||||
|
Type element_type = operands.front().getType();
|
||||||
|
// TODO: unsigned int should be supported separately?
|
||||||
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
// %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
|
||||||
|
// ConstantOp 1,
|
||||||
|
// COnstantOp -1)
|
||||||
|
// ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
|
||||||
|
// ConstantOp 0,
|
||||||
|
// %Y)
|
||||||
|
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
|
||||||
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
|
||||||
|
auto minusOne =
|
||||||
|
rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
|
||||||
|
auto plusPredicate =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
|
||||||
|
auto plusSelect =
|
||||||
|
rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
|
||||||
|
auto zeroPredicate =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
|
||||||
|
auto result =
|
||||||
|
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
|
||||||
|
return result;
|
||||||
|
} else if (element_type.isa<FloatType>()) {
|
||||||
|
// %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
|
||||||
|
// ConstantOp 1,
|
||||||
|
// ConstantOp -1)
|
||||||
|
// ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
|
||||||
|
// ConstantOp 0,
|
||||||
|
// %Y)
|
||||||
|
auto zero =
|
||||||
|
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||||
|
auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
|
||||||
|
auto minusOne =
|
||||||
|
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
|
||||||
|
auto plusPredicate =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
|
||||||
|
auto plusSelect =
|
||||||
|
rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
|
||||||
|
auto zeroPredicate =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
|
||||||
|
auto result =
|
||||||
|
rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
emitError(loc, "unsupported element type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Scalar unary ops for lowering ONNXMaxOp
|
// Scalar unary ops for lowering ONNXMaxOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1697,6 +1755,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
|
||||||
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
|
||||||
|
|
|
@ -103,6 +103,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Reciprocal" &&
|
op->getName().getStringRef() != "onnx.Reciprocal" &&
|
||||||
op->getName().getStringRef() != "onnx.Softplus" &&
|
op->getName().getStringRef() != "onnx.Softplus" &&
|
||||||
op->getName().getStringRef() != "onnx.Softsign" &&
|
op->getName().getStringRef() != "onnx.Softsign" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Sign" &&
|
||||||
op->getName().getStringRef() != "onnx.Mul" &&
|
op->getName().getStringRef() != "onnx.Mul" &&
|
||||||
op->getName().getStringRef() != "onnx.Add" &&
|
op->getName().getStringRef() != "onnx.Add" &&
|
||||||
op->getName().getStringRef() != "onnx.Div" &&
|
op->getName().getStringRef() != "onnx.Div" &&
|
||||||
|
|
|
@ -201,6 +201,9 @@ test_to_enable = [
|
||||||
"test_transpose_all_permutations_3_cpu",
|
"test_transpose_all_permutations_3_cpu",
|
||||||
"test_transpose_all_permutations_4_cpu",
|
"test_transpose_all_permutations_4_cpu",
|
||||||
"test_transpose_all_permutations_5_cpu",
|
"test_transpose_all_permutations_5_cpu",
|
||||||
|
|
||||||
|
# Sign Op:
|
||||||
|
"test_sign_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract name of all test cases.
|
# Extract name of all test cases.
|
||||||
|
|
|
@ -738,3 +738,53 @@ func @test_identity(%arg0 : tensor<10x20x30x40xf32>) -> tensor<*xf32> {
|
||||||
// CHECK-LABEL: test_identity
|
// CHECK-LABEL: test_identity
|
||||||
// CHECK: return %arg0 : memref<10x20x30x40xf32>
|
// CHECK: return %arg0 : memref<10x20x30x40xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_sign_f(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Sign"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_sign_f
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
|
||||||
|
// CHECK: [[MINUS_ONE:%.+]] = constant {{-1.+}} : f32
|
||||||
|
// CHECK: [[GTZERO:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
|
||||||
|
// CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : f32
|
||||||
|
// CHECK: [[EQZERO:%.+]] = cmpf "oeq", [[LOAD]], [[ZERO]] : f32
|
||||||
|
// CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : f32
|
||||||
|
// CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_sign_i(%arg0 : tensor<?x10xi32>) -> tensor<*xi32> {
|
||||||
|
%0 = "onnx.Sign"(%arg0) : (tensor<?x10xi32>) -> tensor<*xi32>
|
||||||
|
"std.return"(%0) : (tensor<*xi32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_sign_i
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xi32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xi32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xi32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xi32>
|
||||||
|
// CHECK: [[ZERO:%.+]] = constant 0 : i32
|
||||||
|
// CHECK: [[ONE:%.+]] = constant 1 : i32
|
||||||
|
// CHECK: [[MINUS_ONE:%.+]] = constant -1 : i32
|
||||||
|
// CHECK: [[GTZERO:%.+]] = cmpi "sgt", [[LOAD]], [[ZERO]] : i32
|
||||||
|
// CHECK: [[SELECT_PLUS:%.+]] = select [[GTZERO]], [[ONE]], [[MINUS_ONE]] : i32
|
||||||
|
// CHECK: [[EQZERO:%.+]] = cmpi "eq", [[LOAD]], [[ZERO]] : i32
|
||||||
|
// CHECK: [[SIGN_RES:%.+]] = select [[EQZERO]], [[ZERO]], [[SELECT_PLUS]] : i32
|
||||||
|
// CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x10xi32>
|
||||||
|
// CHECK: return [[RES]] : memref<?x10xi32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue