Re-add tanh lowering (#75)
* Re-add tanh lowering * Make the emission deterministic
This commit is contained in:
parent
c9199c9061
commit
f4fefcf713
|
@ -66,12 +66,6 @@ struct ScalarOp<ONNXSumOp> {
|
||||||
using IOp = AddIOp;
|
using IOp = AddIOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct ScalarOp<ONNXTanhOp> {
|
|
||||||
using FOp = TanhOp;
|
|
||||||
using IOp = TanhOp; // not use
|
|
||||||
};
|
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct ScalarOp<ONNXCosOp> {
|
struct ScalarOp<ONNXCosOp> {
|
||||||
using FOp = CosOp;
|
using FOp = CosOp;
|
||||||
|
@ -138,6 +132,30 @@ Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Scalar unary ops for lowering ONNXTanhOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
template <>
|
||||||
|
Value mapToLowerScalarOp<ONNXTanhOp>(Operation *op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
// ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
|
||||||
|
// AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
Value operand = operands[0];
|
||||||
|
auto elementType = result_types[0];
|
||||||
|
|
||||||
|
auto zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
|
auto neg = rewriter.create<SubFOp>(loc, zero, operand);
|
||||||
|
auto exp = rewriter.create<ExpOp>(loc, operand);
|
||||||
|
auto negExp = rewriter.create<ExpOp>(loc, neg);
|
||||||
|
auto dividend = rewriter.create<SubFOp>(loc, exp, negExp);
|
||||||
|
auto divisor = rewriter.create<AddFOp>(loc, exp, negExp);
|
||||||
|
auto result = rewriter.create<DivFOp>(loc, dividend, divisor);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Scalar unary ops for lowering ONNXSigmoidOp
|
// Scalar unary ops for lowering ONNXSigmoidOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -82,6 +82,10 @@ test_to_enable = [
|
||||||
"test_cosh_cpu",
|
"test_cosh_cpu",
|
||||||
"test_cosh_example_cpu",
|
"test_cosh_example_cpu",
|
||||||
|
|
||||||
|
# Tanh:
|
||||||
|
"test_tanh_cpu",
|
||||||
|
"test_tanh_example_cpu",
|
||||||
|
|
||||||
# Div Op:
|
# Div Op:
|
||||||
"test_div_cpu",
|
"test_div_cpu",
|
||||||
"test_div_bcast_cpu",
|
"test_div_bcast_cpu",
|
||||||
|
|
|
@ -159,7 +159,13 @@ func @test_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
|
||||||
|
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
|
||||||
|
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[TANH:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||||
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
|
@ -315,7 +315,13 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[TANH:%.+]] = tanh [[LOAD]] : f32
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
|
||||||
|
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
|
||||||
|
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[TANH:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||||
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: store [[TANH]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
|
||||||
/// Second Tanh
|
/// Second Tanh
|
||||||
|
@ -328,7 +334,13 @@ func @test_tanh_tanh(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
|
||||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
|
||||||
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
// CHECK: [[TANH_RES:%.+]] = tanh [[LOAD]] : f32
|
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
|
||||||
|
// CHECK: [[NLOAD:%.+]] = subf [[ZERO]], [[LOAD]] : f32
|
||||||
|
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
|
||||||
|
// CHECK: [[NEXP:%.+]] = exp [[NLOAD]] : f32
|
||||||
|
// CHECK: [[DIVIDEND:%.+]] = subf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[DIVISOR:%.+]] = addf [[EXP]], [[NEXP]] : f32
|
||||||
|
// CHECK: [[TANH_RES:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||||
// CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
|
// CHECK: store [[TANH_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>
|
||||||
|
|
||||||
/// Dealloc of first result.
|
/// Dealloc of first result.
|
||||||
|
|
Loading…
Reference in New Issue