Rewrite tanh using TanhOp, add log, cos
This commit is contained in:
parent
322002f509
commit
3d4ad52011
|
@ -267,7 +267,7 @@ def gen_schema(schema) :
|
|||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity']
|
||||
'Identity', 'Cos', 'Log']
|
||||
CanonicalList=['Add', 'Identity']
|
||||
line_indent = ' '
|
||||
|
||||
|
|
|
@ -91,6 +91,18 @@ void ONNXCoshOp::inferShapes() {
|
|||
getResult()->setType(getOperand()->getType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cos
|
||||
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Log
|
||||
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HardSigmoid
|
||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||
|
|
|
@ -361,7 +361,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
|
|||
}
|
||||
|
||||
def ONNXCosOp:ONNX_Op<"Cos",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Cos operation";
|
||||
let description = [{
|
||||
"Calculates the cosine of the given input tensor, element-wise."
|
||||
|
@ -1216,7 +1216,7 @@ def ONNXLessOp:ONNX_Op<"Less",
|
|||
}
|
||||
|
||||
def ONNXLogOp:ONNX_Op<"Log",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX Log operation";
|
||||
let description = [{
|
||||
"Calculates the natural log of the given input tensor, element-wise."
|
||||
|
|
|
@ -286,6 +286,24 @@ struct ScalarOp<ONNXSumOp> {
|
|||
using IOp = AddIOp;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScalarOp<ONNXTanhOp> {
|
||||
using FOp = TanhOp;
|
||||
using IOp = TanhOp; // not use
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScalarOp<ONNXCosOp> {
|
||||
using FOp = CosOp;
|
||||
using IOp = CosOp; // not use
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ScalarOp<ONNXLogOp> {
|
||||
using FOp = LogOp;
|
||||
using IOp = LogOp; // not use
|
||||
};
|
||||
|
||||
template <typename ElementwiseNaryOp>
|
||||
using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
|
||||
template <typename ElementwiseNaryOp>
|
||||
|
@ -314,28 +332,6 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scalar unary ops for lowering ONNXTanhOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <>
|
||||
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||
auto loc = op->getLoc();
|
||||
Value operand = operands[0];
|
||||
|
||||
auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
|
||||
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||
auto diff = rewriter.create<SubFOp>(loc, exp, negExp);
|
||||
auto sum = rewriter.create<AddFOp>(loc, exp, negExp);
|
||||
auto result = rewriter.create<DivFOp>(loc, diff, sum);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Scalar unary ops for lowering ONNXSinhOp
|
||||
|
@ -982,6 +978,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
|
||||
ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
|
||||
|
|
|
@ -92,6 +92,8 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Tanh" &&
|
||||
op->getName().getStringRef() != "onnx.Sinh" &&
|
||||
op->getName().getStringRef() != "onnx.Cosh" &&
|
||||
op->getName().getStringRef() != "onnx.Cos" &&
|
||||
op->getName().getStringRef() != "onnx.Log" &&
|
||||
op->getName().getStringRef() != "onnx.Sigmoid" &&
|
||||
op->getName().getStringRef() != "onnx.HardSigmoid" &&
|
||||
op->getName().getStringRef() != "onnx.Elu" &&
|
||||
|
|
|
@ -159,14 +159,8 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
|||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
|
||||
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
|
||||
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
|
||||
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
|
||||
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
|
||||
// CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||
// CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
|
||||
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
|
@ -220,6 +214,44 @@ func @test_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
|||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
func @test_cos(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Cos"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_cos
|
||||
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: [[COS:%.+]] = cos [[LOAD]] : f32
|
||||
// CHECK: store [[COS]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
func @test_log(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Log"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_log
|
||||
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: [[LOG:%.+]] = log [[LOAD]] : f32
|
||||
// CHECK: store [[LOG]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||
}
|
||||
|
||||
func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue