Specialize the op lowering logic for element-wise operations (#118)
* Specialize the op lowering logic for elementwise operations
* Fix clang-format error.
* Update tests for LSTM since LSTM uses element-wise ops

Co-authored-by: Tian Jin <tjingrant@gmail.com>
commit d65a6e72dd
parent 4dd3c809c7
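In short: when every operand of an element-wise op is a rank-0 (scalar) tensor, the lowering now skips the krnl loop construction entirely and emits a direct scalar load/compute/store sequence. As a hypothetical before/after for a scalar onnx.Exp, reconstructed from the FileCheck expectations added in the tests below (the function name and the exact lowered signature are illustrative, not part of the commit):

    // Before lowering: an element-wise op whose only operand is a rank-0 tensor.
    func @scalar_exp(%arg0 : tensor<f32>) -> tensor<*xf32> {
      %0 = "onnx.Exp"(%arg0) : (tensor<f32>) -> tensor<*xf32>
      "std.return"(%0) : (tensor<*xf32>) -> ()
    }

    // After lowering with this change: no krnl.define_loops / krnl.iterate,
    // just a scalar alloc, load, exp, store, and return.
    func @scalar_exp(%arg0 : memref<f32>) -> memref<f32> {
      %0 = alloc() : memref<f32>
      %1 = load %arg0[] : memref<f32>
      %2 = exp %1 : f32
      store %2, %0[] : memref<f32>
      return %0 : memref<f32>
    }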
@@ -505,6 +505,7 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
   // An element-wise unary operation must have all operands and the result of
   // the same type. This should have been verified by the verifier.
   auto loc = op->getLoc();
+  auto X = operands[0];

   // Insert an allocation and deallocation for the result of this operation.
   auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -521,14 +522,16 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
   if (hasAllConstantDimensions(memRefType))
     alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
   else
-    alloc = insertAllocAndDealloc(
-        memRefType, loc, rewriter, insertDealloc, {operands[0]});
+    alloc =
+        insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, {X});

+  SmallVector<Value, 4> loopIVs;
+  if (!hasAllScalarValues(operands)) {
     std::vector<Value> originalLoops;
     KrnlOptimizeLoopsOp optimizedLoopsOp;
     KrnlIterateOp iterateOp;
     emitKrnlLoopsAndIterationForOperand(
-        rewriter, loc, operands[0], originalLoops, optimizedLoopsOp, iterateOp);
+        rewriter, loc, X, originalLoops, optimizedLoopsOp, iterateOp);
     Block &optimizationBlock = optimizedLoopsOp.region().front();
     Block &iterationBlock = iterateOp.bodyRegion().front();

@@ -543,11 +546,11 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(&iterationBlock);

     // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
+  }

-  auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+  auto loadedVal = rewriter.create<LoadOp>(loc, X, loopIVs);
   auto loweredOpResult = emitScalarOpFor<ElementwiseUnaryOp>(
       rewriter, loc, op, memRefType.getElementType(), {loadedVal});
   // Store result in the resulting array.
@ -589,9 +592,12 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
alloc = insertAllocAndDealloc(
|
alloc = insertAllocAndDealloc(
|
||||||
memRefType, loc, rewriter, insertDealloc, operands);
|
memRefType, loc, rewriter, insertDealloc, operands);
|
||||||
|
|
||||||
|
SmallVector<Value, 4> loopIVs;
|
||||||
|
std::map<int, std::map<int, Value>> broadcastedDimInfo;
|
||||||
|
if (!hasAllScalarValues(operands)) {
|
||||||
// Get run-time dimension information for unknown dimensions used for
|
// Get run-time dimension information for unknown dimensions used for
|
||||||
// broadcasting.
|
// broadcasting.
|
||||||
std::map<int, std::map<int, Value>> broadcastedDimInfo =
|
broadcastedDimInfo =
|
||||||
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
|
||||||
|
|
||||||
std::vector<Value> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
|
@@ -612,10 +618,9 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     rewriter.setInsertionPointToStart(&iterationBlock);

     // Handle the operation:
-    SmallVector<Value, 4> loopIVs;
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
-
+  }
   // Fold over operands for each of their scalar values
   Value accumulated, next;
   auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
@@ -20,6 +20,15 @@ bool hasAllConstantDimensions(MemRefType type) {
   return true;
 }

+/// Check is all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values) {
+  for (Value value : values) {
+    if (value.getType().cast<ShapedType>().getRank() != 0)
+      return false;
+  }
+  return true;
+}
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type) {
   MemRefType memRefType;
@@ -35,6 +35,9 @@ using namespace mlir;
 /// Check is all dimensions are known at compile time.
 bool hasAllConstantDimensions(MemRefType type);

+/// Check is all operands are scalar values at compile time.
+bool hasAllScalarValues(ArrayRef<Value> values);
+
 /// Get the corresponding MemRefType of a given TensorType/MemRefType.
 MemRefType convertToMemRefType(Type type);

@@ -18,6 +18,52 @@ func @test_no_argument_2() -> tensor<*xf32> {

 // -----

+func @test_elementwise_op_with_scalar_values_1(%arg0 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Exp"(%arg0) : (tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_1
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: store [[EXP]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_2(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_2
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADD]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
+func @test_elementwise_op_with_scalar_values_3(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1, %arg2) : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elementwise_op_with_scalar_values_3
+  // CHECK: [[RES:%.+]] = alloc() : memref<f32>
+  // CHECK: [[LOAD1:%.+]] = load %arg0[] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[] : memref<f32>
+  // CHECK: [[ADD1:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[LOAD3:%.+]] = load %arg2[] : memref<f32>
+  // CHECK: [[ADD2:%.+]] = addf [[ADD1]], [[LOAD3]] : f32
+  // CHECK: store [[ADD2]], [[RES]][] : memref<f32>
+  // CHECK: return [[RES]] : memref<f32>
+}
+
+// -----
+
 func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Add"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
@@ -1854,49 +1900,49 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWi_GEMM]][] : memref<f32>

   // CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ri_GEMM]][] : memref<f32>

   // CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWo_GEMM]][] : memref<f32>

   // CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Ro_GEMM]][] : memref<f32>

   // CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWf_GEMM]][] : memref<f32>

   // CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rf_GEMM]][] : memref<f32>

   // CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
   // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
   // CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[XtWc_GEMM]][] : memref<f32>

   // CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
   // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
   // CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
-  // CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: store {{.*}}, [[Ht1Rc_GEMM]][] : memref<f32>
   // CHECK: }

   // CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
@@ -1905,11 +1951,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12

   // CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1918,7 +1959,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[It]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>

   // CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
@@ -1927,11 +1967,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12

   // CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1940,7 +1975,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[Ft]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>

   // CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
@@ -1949,11 +1983,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12

   // CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
   // CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
@@ -1963,7 +1992,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[ct]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>

   // CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
@@ -1977,11 +2005,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12

   // CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = constant 1.000000e+00 : f32
@@ -1990,16 +2013,10 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[Ot]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>

   // CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
   // CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
-  // CHECK: krnl.define_loops 0
-  // CHECK: krnl.optimize_loops {
-  // CHECK: krnl.return_loops
-  // CHECK: } : () -> ()
-  // CHECK: krnl.iterate() with () {
   // CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
   // CHECK: {{.*}} = constant 0.000000e+00 : f32
   // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
@@ -2009,7 +2026,6 @@ func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12
   // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
   // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
   // CHECK: store {{.*}}, [[hCt]][] : memref<f32>
-  // CHECK: }
   // CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>

   // CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32
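The LSTM test changes above follow from the same specialization: the LSTM lowering applies element-wise ops (sigmoid, tanh, additions, multiplications) to scalar memref<f32> intermediates, and each of those previously went through a degenerate zero-dimensional loop nest. The removed CHECK lines correspond to scaffolding of roughly this shape (a sketch assembled from those removed lines; %buf and %v are illustrative names, not from the test file):

    krnl.define_loops 0
    krnl.optimize_loops {
      krnl.return_loops
    } : () -> ()
    krnl.iterate() with () {
      %v = load %buf[] : memref<f32>
      // ... scalar sigmoid/tanh computation, then a store back into %buf ...
    }

With the scalar fast path, the loads, arithmetic, and stores are emitted directly, so the expected LSTM output no longer contains these loop constructs.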