Rewrite tanh using TanhOp, add log, cos
This commit is contained in:
parent
322002f509
commit
3d4ad52011
|
@ -267,7 +267,7 @@ def gen_schema(schema) :
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||||
'Identity']
|
'Identity', 'Cos', 'Log']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
|
|
@ -91,6 +91,18 @@ void ONNXCoshOp::inferShapes() {
|
||||||
getResult()->setType(getOperand()->getType());
|
getResult()->setType(getOperand()->getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Cos
|
||||||
|
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Log
|
||||||
|
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||||
|
/// shape inference interface.
|
||||||
|
void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// HardSigmoid
|
// HardSigmoid
|
||||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||||
|
|
|
@ -361,7 +361,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXCosOp:ONNX_Op<"Cos",
|
def ONNXCosOp:ONNX_Op<"Cos",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Cos operation";
|
let summary = "ONNX Cos operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculates the cosine of the given input tensor, element-wise."
|
"Calculates the cosine of the given input tensor, element-wise."
|
||||||
|
@ -1216,7 +1216,7 @@ def ONNXLessOp:ONNX_Op<"Less",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXLogOp:ONNX_Op<"Log",
|
def ONNXLogOp:ONNX_Op<"Log",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Log operation";
|
let summary = "ONNX Log operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Calculates the natural log of the given input tensor, element-wise."
|
"Calculates the natural log of the given input tensor, element-wise."
|
||||||
|
|
|
@ -286,6 +286,24 @@ struct ScalarOp<ONNXSumOp> {
|
||||||
using IOp = AddIOp;
|
using IOp = AddIOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ScalarOp<ONNXTanhOp> {
|
||||||
|
using FOp = TanhOp;
|
||||||
|
using IOp = TanhOp; // not use
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ScalarOp<ONNXCosOp> {
|
||||||
|
using FOp = CosOp;
|
||||||
|
using IOp = CosOp; // not use
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct ScalarOp<ONNXLogOp> {
|
||||||
|
using FOp = LogOp;
|
||||||
|
using IOp = LogOp; // not use
|
||||||
|
};
|
||||||
|
|
||||||
template <typename ElementwiseNaryOp>
|
template <typename ElementwiseNaryOp>
|
||||||
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
|
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
|
||||||
template <typename ElementwiseNaryOp>
|
template <typename ElementwiseNaryOp>
|
||||||
|
@ -314,28 +332,6 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Scalar unary ops for lowering ONNXTanhOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
template <>
|
|
||||||
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
|
||||||
ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) {
|
|
||||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
|
||||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
|
||||||
auto loc = op->getLoc();
|
|
||||||
Value operand = operands[0];
|
|
||||||
|
|
||||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
|
||||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
|
||||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
|
||||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
|
||||||
auto diff = rewriter.create<SubFOp>(loc, exp, negExp);
|
|
||||||
auto sum = rewriter.create<AddFOp>(loc, exp, negExp);
|
|
||||||
auto result = rewriter.create<DivFOp>(loc, diff, sum);
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Scalar unary ops for lowering ONNXSinhOp
|
// Scalar unary ops for lowering ONNXSinhOp
|
||||||
|
@ -982,6 +978,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
||||||
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
|
||||||
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
||||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
||||||
|
|
|
@ -92,6 +92,8 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Tanh" &&
|
op->getName().getStringRef() != "onnx.Tanh" &&
|
||||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
op->getName().getStringRef() != "onnx.Sinh" &&
|
||||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
op->getName().getStringRef() != "onnx.Cosh" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Cos" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Log" &&
|
||||||
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
||||||
op->getName().getStringRef() != "onnx.HardSigmoid" &&
|
op->getName().getStringRef() != "onnx.HardSigmoid" &&
|
||||||
op->getName().getStringRef() != "onnx.Elu" &&
|
op->getName().getStringRef() != "onnx.Elu" &&
|
||||||
|
|
|
@ -159,14 +159,8 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
|
||||||
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
|
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
|
|
||||||
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
|
|
||||||
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
|
|
||||||
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
|
|
||||||
// CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
|
||||||
// CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,6 +214,44 @@ func @test_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_cos(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Cos"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_cos
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[COS:%.+]] = cos [[LOAD]] : f32
|
||||||
|
// CHECK: store [[COS]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
func @test_log(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Log"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_log
|
||||||
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||||
|
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||||
|
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||||
|
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||||
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: [[LOG:%.+]] = log [[LOAD]] : f32
|
||||||
|
// CHECK: store [[LOG]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue