From d65a6e72ddfc2c2b231b6127f7a40552ab855166 Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Thu, 14 May 2020 14:00:15 +0900
Subject: [PATCH] Specialize the op lowering logic for element-wise operations (#118)

* Specialize the op lowering logic for elementwise operations

* Fix clang-format error.

* Update tests for LSTM since LSTM uses element-wise ops

Co-authored-by: Tian Jin
---
 .../ONNXToKrnl/Math/Elementwise.cpp   | 103 +++++-----
 .../ONNXToKrnl/ONNXToKrnlCommon.cpp   |   9 +
 .../ONNXToKrnl/ONNXToKrnlCommon.hpp   |   3 +
 test/mlir/onnx/onnx_lowering.mlir     | 176 ++++++++++--------
 4 files changed, 162 insertions(+), 129 deletions(-)

diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
index 7800349..076bc25 100644
--- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -505,6 +505,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
     auto loc = op->getLoc();
+    auto X = operands[0];
 
     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -521,33 +522,35 @@
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, {operands[0]});
+      alloc =
+          insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, {X});
 
-    std::vector<Value> originalLoops;
-    KrnlOptimizeLoopsOp optimizedLoopsOp;
-    KrnlIterateOp iterateOp;
-    emitKrnlLoopsAndIterationForOperand(
-        rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
-    // Return from KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops
-    // unchaged.
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // 2. Insert instructions inside the KernelIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
     SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      loopIVs.push_back(arg);
+    if (!hasAllScalarValues(operands)) {
+      std::vector<Value> originalLoops;
+      KrnlOptimizeLoopsOp optimizedLoopsOp;
+      KrnlIterateOp iterateOp;
+      emitKrnlLoopsAndIterationForOperand(
+          rewriter, loc, X, originalLoops, optimizedLoopsOp, iterateOp);
+      Block &optimizationBlock = optimizedLoopsOp.region().front();
+      Block &iterationBlock = iterateOp.bodyRegion().front();
 
-    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+      // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      // Return from KrnlOptimizeLoopsOp body.
+      // When no optimizations are present we just return the loops
+      // unchaged.
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+
+      // 2. Insert instructions inside the KernelIterateOp body.
+      rewriter.setInsertionPointToStart(&iterationBlock);
+
+      // Handle the operation:
+      for (auto arg : iterationBlock.getArguments())
+        loopIVs.push_back(arg);
+    }
+
+    auto loadedVal = rewriter.create<LoadOp>(loc, X, loopIVs);
     auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
         rewriter, loc, op, memRefType.getElementType(), {loadedVal});
     // Store result in the resulting array.
@@ -589,33 +592,35 @@
       alloc = insertAllocAndDealloc(
           memRefType, loc, rewriter, insertDealloc, operands);
 
-    // Get run-time dimension information for unknown dimensions used for
-    // broadcasting.
-    std::map<int, std::map<int, Value>> broadcastedDimInfo =
-        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
-
-    std::vector<Value> originalLoops;
-    KrnlOptimizeLoopsOp optimizedLoopsOp;
-    KrnlIterateOp iterateOp;
-    emitKrnlLoopsAndIterationForOperand(
-        rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
-    Block &iterationBlock = iterateOp.bodyRegion().front();
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
-    // Return from KrnlOptimizeLoopsOp body.
-    // When no optimizations are present we just return the loops unchaged.
-    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-
-    // 2. Insert instructions inside the KernelIterateOp body.
-    rewriter.setInsertionPointToStart(&iterationBlock);
-
-    // Handle the operation:
     SmallVector<Value, 4> loopIVs;
-    for (auto arg : iterationBlock.getArguments())
-      loopIVs.push_back(arg);
+    std::map<int, std::map<int, Value>> broadcastedDimInfo;
+    if (!hasAllScalarValues(operands)) {
+      // Get run-time dimension information for unknown dimensions used for
+      // broadcasting.
+      broadcastedDimInfo =
+          getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
+      std::vector<Value> originalLoops;
+      KrnlOptimizeLoopsOp optimizedLoopsOp;
+      KrnlIterateOp iterateOp;
+      emitKrnlLoopsAndIterationForOperand(
+          rewriter, loc, alloc, originalLoops, optimizedLoopsOp, iterateOp);
+      Block &optimizationBlock = optimizedLoopsOp.region().front();
+      Block &iterationBlock = iterateOp.bodyRegion().front();
+
+      // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      // Return from KrnlOptimizeLoopsOp body.
+      // When no optimizations are present we just return the loops unchaged.
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+
+      // 2. Insert instructions inside the KernelIterateOp body.
+      rewriter.setInsertionPointToStart(&iterationBlock);
+
+      // Handle the operation:
+      for (auto arg : iterationBlock.getArguments())
+        loopIVs.push_back(arg);
+    }
 
     // Fold over operands for each of their scalar values
     Value accumulated, next;
     auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
index c2f5ef3..6487247 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -20,6 +20,15 @@ bool hasAllConstantDimensions(MemRefType type) {
   return true;
 }
 
+/// Check is all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values) {
+  for (Value value : values) {
+    if (value.getType().cast<ShapedType>().getRank() != 0)
+      return false;
+  }
+  return true;
+}
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type) {
   MemRefType memRefType;
diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
index 74da9f2..95907f8 100644
--- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
+++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -35,6 +35,9 @@ using namespace mlir;
 /// Check is all dimensions are known at compile time.
 bool hasAllConstantDimensions(MemRefType type);
 
+/// Check is all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values);
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type);
 
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 89cb551..21993a5 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -18,6 +18,52 @@ func @test_no_argument_2() -> tensor<*xf32> {
 
 // -----
 
+func @test_elementwise_op_with_scalar_values_1(%arg0 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Exp"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_1
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_2(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_2
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADD]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_3(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_3
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[LOAD3:%.+]] = load %arg2[] : memref<f32>
+  // CHECK: [[ADD2:%.+]] = addf [[ADD1]], [[LOAD3]] : f32
+  // CHECK: store [[ADD2]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
 func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -1854,49 +1900,49 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWi_GEMM]][] : memref<f32>
 
   // CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ri_GEMM]][] : memref<f32>
 
   // CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWo_GEMM]][] : memref<f32>
 
   // CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ro_GEMM]][] : memref<f32>
 
   // CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWf_GEMM]][] : memref<f32>
 
   // CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rf_GEMM]][] : memref<f32>
 
   // CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWc_GEMM]][] : memref<f32>
 
   // CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rc_GEMM]][] : memref<f32>
   // CHECK: }
 
   // CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
@@ -1905,20 +1951,14 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
 
   // CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
-  // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
-  // CHECK: {{.*}} = constant 0.000000e+00 : f32
-  // CHECK: {{.*}} = constant 1.000000e+00 : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
-  // CHECK: store {{.*}}, [[It]][] : memref<f32>
-  // CHECK: }
+  // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = constant 1.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[It]][] : memref<f32>
 
   // CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>
   // CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
@@ -1927,20 +1967,14 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
 
   // CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
-  // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
-  // CHECK: {{.*}} = constant 0.000000e+00 : f32
-  // CHECK: {{.*}} = constant 1.000000e+00 : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
-  // CHECK: store {{.*}}, [[Ft]][] : memref<f32>
-  // CHECK: }
+  // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = constant 1.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[Ft]][] : memref<f32>
 
   // CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>
   // CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
@@ -1949,21 +1983,15 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
 
   // CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
   // CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
-  // CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
-  // CHECK: {{.*}} = constant 0.000000e+00 : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
-  // CHECK: store {{.*}}, [[ct]][] : memref<f32>
-  // CHECK: }
+  // CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[ct]][] : memref<f32>
 
   // CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>
   // CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
@@ -1977,39 +2005,27 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
-  // CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
-  // CHECK: {{.*}} = constant 0.000000e+00 : f32
-  // CHECK: {{.*}} = constant 1.000000e+00 : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
-  // CHECK: store {{.*}}, [[Ot]][] : memref<f32>
-  // CHECK: }
+  // CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = constant 1.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[Ot]][] : memref<f32>
 
   // CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>
 
   // CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
-  // CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
-  // CHECK: {{.*}} = constant 0.000000e+00 : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = exp {{.*}} : f32
-  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
-  // CHECK: store {{.*}}, [[hCt]][] : memref<f32>
-  // CHECK: }
+  // CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[hCt]][] : memref<f32>
 
   // CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>
   // CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32