Re-add tanh lowering (#75)

* Re-add tanh lowering

* Make the emission deterministic
This commit is contained in:
Tung D. Le 2020-04-09 15:22:36 +09:00 committed by GitHub
parent c9199c9061
commit f4fefcf713
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 49 additions and 9 deletions

View File

@ -66,12 +66,6 @@ struct ScalarOp<ONNXSumOp> {
using IOp = AddIOp; using IOp = AddIOp;
}; };
// ScalarOp specialization for ONNXTanhOp: names the scalar op types selected
// through the FOp / IOp aliases.
template <>
struct ScalarOp<ONNXTanhOp> {
using FOp = TanhOp;
using IOp = TanhOp; // placeholder only; not used
};
template <> template <>
struct ScalarOp<ONNXCosOp> { struct ScalarOp<ONNXCosOp> {
using FOp = CosOp; using FOp = CosOp;
@ -138,6 +132,30 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
return result; return result;
} }
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXTanhOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
    ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) {
  // Lowers a scalar ONNXTanhOp to exponentials:
  //   tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
  // where -x is emitted as SubFOp(0, x).
  //
  // The ops are created in a fixed sequence (constant, subf, exp, exp, subf,
  // addf, divf). Keep this order unchanged so the emitted IR stays
  // deterministic.
  //
  // NOTE(review): exp(x) / exp(-x) can presumably overflow for large |x|,
  // which would turn the quotient into inf/inf -- confirm whether callers
  // need a numerically safer form.
  auto location = op->getLoc();
  Value input = operands[0];
  auto scalarType = result_types[0];

  // -x, expressed as 0 - x.
  auto zeroConst = emitConstantOp(rewriter, location, scalarType, 0);
  auto negatedInput = rewriter.create<SubFOp>(location, zeroConst, input);

  // exp(x) first, then exp(-x).
  auto expPos = rewriter.create<ExpOp>(location, input);
  auto expNeg = rewriter.create<ExpOp>(location, negatedInput);

  // Numerator, then denominator, then the final quotient.
  auto numerator = rewriter.create<SubFOp>(location, expPos, expNeg);
  auto denominator = rewriter.create<AddFOp>(location, expPos, expNeg);
  return rewriter.create<DivFOp>(location, numerator, denominator);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp // Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -82,6 +82,10 @@ test_to_enable = [
"test_cosh_cpu", "test_cosh_cpu",
"test_cosh_example_cpu", "test_cosh_example_cpu",
# Tanh:
"test_tanh_cpu",
"test_tanh_example_cpu",
# Div Op: # Div Op:
"test_div_cpu", "test_div_cpu",
"test_div_bcast_cpu", "test_div_bcast_cpu",

View File

@ -159,7 +159,13 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32> // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32> // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32 // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
// CHECK: [[TANH:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32> // CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }

View File

@ -315,7 +315,13 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32> // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32> // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32 // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
// CHECK: [[TANH:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32> // CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
/// Second Tanh /// Second Tanh
@ -328,7 +334,13 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32> // CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32> // CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[TANH_RES:%.+]] = tanh [[LOAD]] : f32 // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
// CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32> // CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
/// Dealloc of first result. /// Dealloc of first result.