Specialize the op lowering logic for element-wise operations (#118)

* Specialize the op lowering logic for element-wise operations

* Fix clang-format error.

* Update the LSTM tests, since LSTM uses element-wise ops

Co-authored-by: Tian Jin <tjingrant@gmail.com>
Tung D. Le 2020-05-14 14:00:15 +09:00 committed by GitHub
parent 4dd3c809c7
commit d65a6e72dd
4 changed files with 162 additions and 129 deletions
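
In short, this patch lets the element-wise lowering skip krnl loop construction whenever every operand is a rank-0 (scalar) value, emitting a direct load/compute/store instead. For a scalar onnx.Exp, the lowered IR looks roughly like the sketch below (condensed from the CHECK lines of the new test_elementwise_op_with_scalar_values_1 test in the last file; the SSA value names are illustrative):

    %res = alloc() : memref<f32>
    %x = load %arg0[] : memref<f32>
    %e = exp %x : f32
    store %e, %res[] : memref<f32>
    return %res : memref<f32>

Previously even this scalar case was wrapped in an empty krnl.define_loops 0 / krnl.iterate() with () nest, which is why those CHECK lines disappear from the LSTM test below.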

View File

@@ -505,6 +505,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
     auto loc = op->getLoc();
+    auto X = operands[0];

     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -521,14 +522,16 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(
-          memRefType, loc, rewriter, insertDealloc, {operands[0]});
+      alloc =
+          insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, {X});

+    SmallVector<Value, 4> loopIVs;
+    if (!hasAllScalarValues(operands)) {
       std::vector<Value> originalLoops;
       KrnlOptimizeLoopsOp optimizedLoopsOp;
       KrnlIterateOp iterateOp;
       emitKrnlLoopsAndIterationForOperand(
-          rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
+          rewriter, loc, X, originalLoops, optimizedLoopsOp, iterateOp);

       Block &optimizationBlock = optimizedLoopsOp.region().front();
       Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -543,11 +546,11 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
       rewriter.setInsertionPointToStart(&iterationBlock);

       // Handle the operation:
-      SmallVector<Value, 4> loopIVs;
       for (auto arg : iterationBlock.getArguments())
         loopIVs.push_back(arg);
+    }

-    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    auto loadedVal = rewriter.create<LoadOp>(loc, X, loopIVs);
     auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
         rewriter, loc, op, memRefType.getElementType(), {loadedVal});
     // Store result in the resulting array.
@@ -589,9 +592,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
       alloc = insertAllocAndDealloc(
           memRefType, loc, rewriter, insertDealloc, operands);

+    SmallVector<Value, 4> loopIVs;
+    std::map<int, std::map<int, Value>> broadcastedDimInfo;
+    if (!hasAllScalarValues(operands)) {
       // Get run-time dimension information for unknown dimensions used for
       // broadcasting.
-      std::map<int, std::map<int, Value>> broadcastedDimInfo =
+      broadcastedDimInfo =
           getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

       std::vector<Value> originalLoops;
@@ -612,10 +618,9 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
       rewriter.setInsertionPointToStart(&iterationBlock);

       // Handle the operation:
-      SmallVector<Value, 4> loopIVs;
       for (auto arg : iterationBlock.getArguments())
         loopIVs.push_back(arg);
+    }

     // Fold over operands for each of their scalar values
     Value accumulated, next;
     auto accumulatedLoopIVs = getLoopIVsForBroadcasting(

View File

@@ -20,6 +20,15 @@ bool hasAllConstantDimensions(MemRefType type) {
   return true;
 }

+/// Check if all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values) {
+  for (Value value : values) {
+    if (value.getType().cast<ShapedType>().getRank() != 0)
+      return false;
+  }
+  return true;
+}
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type) {
   MemRefType memRefType;
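
As a point of reference (not part of the patch): "scalar" here means rank-0, so a memref<f32> satisfies hasAllScalarValues while a memref<1xf32> does not and still takes the loop-based path. A small illustrative contrast in MLIR:

    %s = alloc() : memref<f32>     // rank 0: scalar; loads/stores use empty indices, e.g. load %s[]
    %v = alloc() : memref<1xf32>   // rank 1: not scalar; krnl loops are still generated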

View File

@@ -35,6 +35,9 @@ using namespace mlir;
 /// Check is all dimensions are known at compile time.
 bool hasAllConstantDimensions(MemRefType type);

+/// Check if all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values);
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type);

View File

@@ -18,6 +18,52 @@ func @test_no_argument_2() -> tensor<*xf32> {

 // -----

+func @test_elementwise_op_with_scalar_values_1(%arg0 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Exp"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_1
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_2(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_2
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADD]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_3(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_3
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[LOAD3:%.+]] = load %arg2[] : memref<f32>
+  // CHECK: [[ADD2:%.+]] = addf [[ADD1]], [[LOAD3]] : f32
+  // CHECK: store [[ADD2]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
 func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -1854,49 +1900,49 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWi_GEMM]][] : memref<f32>

   // CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ri_GEMM]][] : memref<f32>

   // CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWo_GEMM]][] : memref<f32>

   // CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ro_GEMM]][] : memref<f32>

   // CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWf_GEMM]][] : memref<f32>

   // CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rf_GEMM]][] : memref<f32>

   // CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWc_GEMM]][] : memref<f32>

   // CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rc_GEMM]][] : memref<f32>

   // CHECK: }

   // CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
@@ -1905,11 +1951,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1918,7 +1959,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[It]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>

   // CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
@@ -1927,11 +1967,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1940,7 +1975,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[Ft]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>

   // CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
@@ -1949,11 +1983,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
   // CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
@@ -1963,7 +1992,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[ct]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>

   // CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
@@ -1977,11 +2005,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1990,16 +2013,10 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[Ot]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>

   // CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
@@ -2009,7 +2026,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[hCt]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>

   // CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32