Merge pull request #15 from tungld/tanh_cos_log

Rewrite tanh using TanhOp, and add support for log, cos
Tian Jin 2020-01-08 15:48:58 -05:00 committed by GitHub
commit caeba371fb
7 changed files with 80 additions and 50 deletions


@@ -267,7 +267,7 @@ def gen_schema(schema) :
     'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
     'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-    'Identity']
+    'Identity', 'Cos', 'Log']
 CanonicalList=['Add', 'Identity']
 line_indent = ' '


@@ -91,6 +91,18 @@ void ONNXCoshOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Cos
+/// Infer the output shape of the ONNXCosOp. This method is required by the
+/// shape inference interface.
+void ONNXCosOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
+//===----------------------------------------------------------------------===//
+// Log
+/// Infer the output shape of the ONNXLogOp. This method is required by the
+/// shape inference interface.
+void ONNXLogOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
 //===----------------------------------------------------------------------===//
 // HardSigmoid
 /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
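Both new methods follow the codebase's pattern for shape-preserving unary ops: inference simply copies the operand's tensor type (shape and element type) onto the result. As a minimal sketch of the convention, and assuming a hypothetical follow-up op named ONNXSinOp that is not part of this commit, any further element-wise unary op would be wired up the same way:

```cpp
//===----------------------------------------------------------------------===//
// Sin (hypothetical example, shown only to illustrate the pattern)
/// Shape inference for a shape-preserving unary op: propagate the operand's
/// type unchanged to the single result.
void ONNXSinOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
```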


@@ -361,7 +361,7 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose",
 }
 
 def ONNXCosOp:ONNX_Op<"Cos",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Cos operation";
   let description = [{
     "Calculates the cosine of the given input tensor, element-wise."

@@ -1216,7 +1216,7 @@ def ONNXLessOp:ONNX_Op<"Less",
 }
 
 def ONNXLogOp:ONNX_Op<"Log",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Log operation";
   let description = [{
     "Calculates the natural log of the given input tensor, element-wise."


@@ -286,6 +286,24 @@ struct ScalarOp<ONNXSumOp> {
   using IOp = AddIOp;
 };
 
+template <>
+struct ScalarOp<ONNXTanhOp> {
+  using FOp = TanhOp;
+  using IOp = TanhOp; // integer version not used
+};
+
+template <>
+struct ScalarOp<ONNXCosOp> {
+  using FOp = CosOp;
+  using IOp = CosOp; // integer version not used
+};
+
+template <>
+struct ScalarOp<ONNXLogOp> {
+  using FOp = LogOp;
+  using IOp = LogOp; // integer version not used
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>

@@ -314,30 +332,6 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
   }
 }
 
-//===----------------------------------------------------------------------===//
-// Scalar unary ops for lowering ONNXTanhOp
-//===----------------------------------------------------------------------===//
-template <>
-Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
-                                     ArrayRef<Value> operands,
-                                     ConversionPatternRewriter &rewriter) {
-  // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
-  //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
-  auto loc = op->getLoc();
-  Value operand = operands[0];
-  auto elementType = result_types[0];
-  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
-  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
-  auto exp = rewriter.create<ExpOp>(loc, operand);
-  auto negExp = rewriter.create<ExpOp>(loc, neg);
-  auto diff = rewriter.create<SubFOp>(loc, exp, negExp);
-  auto sum = rewriter.create<AddFOp>(loc, exp, negExp);
-  auto result = rewriter.create<DivFOp>(loc, diff, sum);
-  return result;
-}
-
 //===----------------------------------------------------------------------===//
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//

@@ -992,6 +986,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
+                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
                   ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
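With ScalarOp<ONNXTanhOp> now pointing at the standard dialect's TanhOp, the hand-rolled Tanh expansion deleted above becomes dead code: the generic mapToLowerScalarOp template, whose closing braces are visible in the second hunk, already emits the mapped op for float element types. The following sketch of that generic path is reconstructed from the surrounding context; the exact signature and error handling are assumptions, not a quote of the file:

```cpp
// Generic scalar lowering: create the float or integer standard op recorded
// in the ScalarOp<> trait above. For onnx.Tanh/Cos/Log this now emits a
// single tanh/cos/log op instead of an exp-based expansion.
template <typename ElementwiseNaryOp>
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
                         ArrayRef<Value> operands,
                         ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Type elementType = result_types.front();
  if (elementType.isa<FloatType>())
    return rewriter.create<ScalarFOp<ElementwiseNaryOp>>(loc, result_types,
                                                         operands, mlir::None);
  if (elementType.isa<IntegerType>())
    return rewriter.create<ScalarIOp<ElementwiseNaryOp>>(loc, result_types,
                                                         operands, mlir::None);
  emitError(loc, "unsupported element type");
  return nullptr;
}
```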


@@ -92,6 +92,8 @@ public:
            op->getName().getStringRef() != "onnx.Tanh" &&
            op->getName().getStringRef() != "onnx.Sinh" &&
            op->getName().getStringRef() != "onnx.Cosh" &&
+           op->getName().getStringRef() != "onnx.Cos" &&
+           op->getName().getStringRef() != "onnx.Log" &&
            op->getName().getStringRef() != "onnx.Sigmoid" &&
            op->getName().getStringRef() != "onnx.HardSigmoid" &&
            op->getName().getStringRef() != "onnx.Elu" &&


@@ -159,14 +159,8 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
   // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
-  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
-  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
-  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
-  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
-  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
-  // CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
+  // CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
 }

@@ -220,6 +214,44 @@ func @test_cosh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
 
+func @test_cos(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Cos"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_cos
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[COS:%.+]] = cos [[LOAD]] : f32
+  // CHECK: store [[COS]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_log(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Log"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_log
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[LOG:%.+]] = log [[LOAD]] : f32
+  // CHECK: store [[LOG]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
 func @test_sigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sigmoid"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()


@@ -315,14 +315,8 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
   // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
-  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
-  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
-  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
-  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
-  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
-  // CHECK: store [[TANH_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
+  // CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
 
   /// Second Tanh
   // CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>

@@ -334,13 +328,7 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
   // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
   // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
-  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
-  // CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
-  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
-  // CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
-  // CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
-  // CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
+  // CHECK: [[TANH_RES:%.+]] = tanh [[LOAD]] : f32
   // CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
 
   /// Dealloc of first result.