Support Softplus and Softsign operations (#17)
* Support Softplus and Softsign operations.
* Add default shape inference for the Transpose operation.
* Resolve conflicts with the master branch.
* Add tests for Softplus and Softsign in test/backend/test.py.
* Re-enable Reciprocal tests.

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 0ee7380edd
commit 383a5c31ac
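For reference, both new ops follow the standard ONNX definitions, Softplus(x) = ln(1 + e^x) and Softsign(x) = x / (1 + |x|), applied element-wise. A minimal NumPy sketch of those semantics (illustrative only; the helper names are not part of this patch), which is what the backend tests enabled below are expected to check against:

    import numpy as np

    def softplus(x):
        # Softplus(x) = ln(1 + e^x), element-wise.
        return np.log(1.0 + np.exp(x))

    def softsign(x):
        # Softsign(x) = x / (1 + |x|), element-wise.
        return x / (1.0 + np.abs(x))

    x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)
    print(softplus(x))  # approx. [0.1269, 0.6931, 3.0486]
    print(softsign(x))  # [-0.6667, 0.0, 0.75]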
@@ -267,7 +267,8 @@ def gen_schema(schema) :
     'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
     'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
+    'Identity', 'Cos', 'Log', 'Transpose', 'Softmax',
+    'Softplus', 'Softsign']
 CanonicalList=['Add', 'Identity']
 line_indent = '  '
@@ -166,6 +166,22 @@ void ONNXSoftmaxOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Softplus
+/// Infer the output shape of the ONNXSoftplusOp. This method is required by
+/// the shape inference interface.
+void ONNXSoftplusOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Softsign
+/// Infer the output shape of the ONNXSoftsignOp. This method is required by
+/// the shape inference interface.
+void ONNXSoftsignOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
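Both inferShapes() bodies above apply the identity-shape rule shared by the other element-wise unary ops in this file: the result type is a copy of the operand type. A tiny sketch of that rule, treating types as opaque strings (illustrative only):

    def infer_elementwise_unary_type(operand_type: str) -> str:
        # Element-wise unary ops such as Softplus and Softsign preserve the
        # operand's shape and element type, so the result type is unchanged.
        return operand_type

    assert infer_elementwise_unary_type("tensor<?x10xf32>") == "tensor<?x10xf32>"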
@@ -2863,7 +2863,7 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax",
 }
 
 def ONNXSoftplusOp:ONNX_Op<"Softplus",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Softplus operation";
   let description = [{
   "Softplus takes one input data (Tensor<T>) and produces one output data"

@@ -2875,7 +2875,7 @@ def ONNXSoftplusOp:ONNX_Op<"Softplus",
 }
 
 def ONNXSoftsignOp:ONNX_Op<"Softsign",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Softsign operation";
   let description = [{
   "Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise."
@@ -570,6 +570,46 @@ Value mapToLowerScalarOp<ONNXReciprocalOp>(
   return result;
 }
 
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSoftplusOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSoftplusOp>(
+    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  auto elementType = result_types[0];
+
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  auto add = rewriter.create<AddFOp>(loc, exp, one);
+  auto result = rewriter.create<LogOp>(loc, add);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSoftsignOp
+//===----------------------------------------------------------------------===//
+template <>
+Value mapToLowerScalarOp<ONNXSoftsignOp>(
+    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) {
+  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
+  auto loc = op->getLoc();
+  Value operand = operands[0];
+  auto elementType = result_types[0];
+
+  auto abs = rewriter.create<AbsFOp>(loc, operand);
+  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
+  auto add = rewriter.create<AddFOp>(loc, abs, one);
+  auto result = rewriter.create<DivFOp>(loc, operand, add);
+
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXMaxOp
 //===----------------------------------------------------------------------===//
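The scalar lowering above decomposes each op into existing standard-dialect ops: exp, addf, log for Softplus, and absf, addf, divf for Softsign. A quick NumPy sanity sketch (not part of the patch) showing that these sequences reproduce the reference definitions:

    import numpy as np

    def lowered_softplus(x):
        # Mirrors the emitted sequence: ExpOp -> AddFOp(.., 1.0) -> LogOp.
        return np.log(np.exp(x) + 1.0)

    def lowered_softsign(x):
        # Mirrors the emitted sequence: AbsFOp -> AddFOp(.., 1.0) -> DivFOp(x, ..).
        return x / (np.abs(x) + 1.0)

    x = np.linspace(-5.0, 5.0, 11).astype(np.float32)
    assert np.allclose(lowered_softplus(x), np.log1p(np.exp(x)))
    assert np.allclose(lowered_softsign(x), x / (1.0 + np.abs(x)))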
@@ -1214,6 +1254,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
@@ -101,6 +101,8 @@ public:
             op->getName().getStringRef() != "onnx.LeakyRelu" &&
             op->getName().getStringRef() != "onnx.Selu" &&
             op->getName().getStringRef() != "onnx.Reciprocal" &&
+            op->getName().getStringRef() != "onnx.Softplus" &&
+            op->getName().getStringRef() != "onnx.Softsign" &&
             op->getName().getStringRef() != "onnx.Mul" &&
             op->getName().getStringRef() != "onnx.Add" &&
             op->getName().getStringRef() != "onnx.Div" &&
@@ -146,6 +146,14 @@ test_to_enable = [
     # Reciprocal Op:
     "test_reciprocal_cpu",
     "test_reciprocal_example_cpu",
+
+    # SoftplusOp:
+    "test_softplus_cpu",
+    "test_softplus_example_cpu",
+
+    # SoftsignOp:
+    "test_softsign_cpu",
+    "test_softsign_example_cpu",
 ]
 
 # Extract name of all test cases.
@@ -508,6 +508,50 @@ func @test_reciprocal(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_softplus(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softplus"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_softplus
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[ADD:%.+]] = addf [[EXP]], [[ONE]] : f32
+  // CHECK: [[SOFTPLUS_RES:%.+]] = log [[ADD]] : f32
+  // CHECK: store [[SOFTPLUS_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_softsign(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softsign"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_softsign
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ABS:%.+]] = absf [[LOAD]] : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[ADD:%.+]] = addf [[ABS]], [[ONE]] : f32
+  // CHECK: [[SOFTSIGN_RES:%.+]] = divf [[LOAD]], [[ADD]] : f32
+  // CHECK: store [[SOFTSIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
 func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()