Lower BatchNormalization (test mode) to Krnl dialect (#70)
* Add ONNXBatchNormalizationTestModeOp and its shape inference * Lower batchnormalization test mode * re-use scale, bias, mean, and variance * Add MLIR tests * Add e2e tests * fix typos * Fix a bug in MLIR tests * Change type from int to int64_t for indices * Uncomment e2e tests due to segmentation fault * Uncomment e2e tests due to segmentation fault * Revise the code * [Tian] Fix segmentation fault in e2e tests * Re-generate onnx.md to include BatchNormalizationTestModeOp * Reverse an unintentional change * Fix some typos in comments * Use convertToMemRefType from the master branch Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
f1d20e368f
commit
aea6479ad3
|
@ -332,6 +332,42 @@ ONNX BatchNormalization operation
|
|||
1. `saved_mean`: memref of any type values or tensor of any type values
|
||||
1. `saved_var`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
|
||||
ONNX BatchNormalization operation in test mode
|
||||
|
||||
#### Description:
|
||||
|
||||
|
||||
"Carries out batch normalization as described in the paper"
|
||||
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
|
||||
"there are multiple cases for the number of outputs, which we list below:"
|
||||
""
|
||||
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
|
||||
"Output case #2: Y (test mode)"
|
||||
""
|
||||
"For previous (depreciated) non-spatial cases, implementors are suggested"
|
||||
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
|
||||
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
|
||||
|
||||
#### Operands:
|
||||
|
||||
1. `X`: memref of any type values or tensor of any type values
|
||||
1. `scale`: memref of any type values or tensor of any type values
|
||||
1. `B`: memref of any type values or tensor of any type values
|
||||
1. `mean`: memref of any type values or tensor of any type values
|
||||
1. `var`: memref of any type values or tensor of any type values
|
||||
|
||||
#### Attributes:
|
||||
|
||||
| Attribute | MLIR Type | Description |
|
||||
| :-------: | :-------: | ----------- |
|
||||
| `epsilon` | `FloatAttr` | 32-bit float attribute attribute |
|
||||
| `momentum` | `FloatAttr` | 32-bit float attribute attribute |
|
||||
|
||||
#### Results:
|
||||
|
||||
1. `o_Y`: memref of any type values or tensor of any type values
|
||||
|
||||
### onnx.BitShift (ONNXBitShiftOp)
|
||||
ONNX BitShift operation
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ special_attr_defaults = dict([
|
|||
special_op_handler = dict([
|
||||
("Conv", "ImportNodeConv"),
|
||||
("MaxPool", "ImportNodeMaxPool"),
|
||||
("BatchNormalization", "ImportNodeBatchNormalization"),
|
||||
("Gemm", "ImportNodeGemm"),
|
||||
("Pad", "ImportNodePad"),
|
||||
#("Transpose", "ImportNodeTranspose")
|
||||
|
|
|
@ -434,6 +434,21 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for BatchNormalization operations.
|
||||
*/
|
||||
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
|
||||
int nOuts = node.output().size();
|
||||
if (nOuts == 1) {
|
||||
// Test mode with one output.
|
||||
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
|
||||
nOuts);
|
||||
} else {
|
||||
// Training mode with four trailing optional outputs. Not handled yet.
|
||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Special handle for Gemm operations.
|
||||
*/
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
}else if (OpName == "AveragePool") {
|
||||
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
|
||||
}else if (OpName == "BatchNormalization") {
|
||||
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
|
||||
ImportNodeBatchNormalization(node, 5, 5);
|
||||
}else if (OpName == "BitShift") {
|
||||
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
|
||||
}else if (OpName == "Cast") {
|
||||
|
|
|
@ -404,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
|
|||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
|
||||
// Neural network
|
||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
|
||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// EntryPoint Op lowering to Krnl Entry Point.
|
||||
|
@ -523,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
|||
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
|
||||
// Neural network
|
||||
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
||||
// Entry point
|
||||
patterns.insert<ONNXEntryPointLowering>(&getContext());
|
||||
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file lowers ONNX Normalization Operators to Krnl dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
||||
ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(
|
||||
mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
|
||||
ctx) {}
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter & rewriter) const final {
|
||||
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
|
||||
// scale * (x - mean) / sqrt(variance + epsilon) + bias
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto epsilonAttr =
|
||||
FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
|
||||
.epsilon()
|
||||
.convertToFloat());
|
||||
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
|
||||
|
||||
auto operand = operands[0];
|
||||
auto scale = operands[1];
|
||||
auto bias = operands[2];
|
||||
auto mean = operands[3];
|
||||
auto variance = operands[4];
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||
{operand});
|
||||
|
||||
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
||||
// In case of N, C is assumed to be 1.
|
||||
// Shapes of scale, bias, mean and variance must be C.
|
||||
// Computation of BatchNormalization is done as if scale, bias, mean, and
|
||||
// variance are reshaped to Cx1x1x...x1.
|
||||
|
||||
// rank
|
||||
int64_t rank = memRefType.getRank();
|
||||
|
||||
std::vector<Value> originalLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
Block *optimizationBlock =
|
||||
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
|
||||
|
||||
// Create a KrnlIterateOp along C dimension.
|
||||
// This will be the outer-most loop in order to re-use scale, bias,
|
||||
// mean and variance.
|
||||
|
||||
SmallVector<Value, 1> loopCIVs;
|
||||
if (rank > 1) {
|
||||
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
|
||||
optimizedLoops[1]);
|
||||
addDimensionToPack(rewriter, loc, cPack, operand, 1);
|
||||
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
|
||||
Block &cIterationBlock = cIterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&cIterationBlock);
|
||||
for (auto arg : cIterationBlock.getArguments())
|
||||
loopCIVs.emplace_back(arg);
|
||||
} else {
|
||||
loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
|
||||
}
|
||||
|
||||
auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
|
||||
auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
|
||||
auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
|
||||
auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);
|
||||
|
||||
// Create a KrnlIterateOp along the other dimensions.
|
||||
SmallVector<int64_t, 4> axes;
|
||||
axes.emplace_back(0);
|
||||
for (int64_t i = 2; i < rank; ++i)
|
||||
axes.emplace_back(i);
|
||||
std::vector<Value> packLoops, packOptimizedLoops;
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
packLoops.emplace_back(originalLoops[axes[i]]);
|
||||
packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
|
||||
}
|
||||
KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
|
||||
for (int i = 0; i < axes.size(); ++i) {
|
||||
addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
|
||||
}
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||
|
||||
// No optimization
|
||||
rewriter.setInsertionPointToEnd(optimizationBlock);
|
||||
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
|
||||
|
||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||
rewriter.setInsertionPointToStart(&iterationBlock);
|
||||
|
||||
SmallVector<Value, 4> loopIVs;
|
||||
auto args = iterationBlock.getArguments();
|
||||
if (args.size() > 1) {
|
||||
loopIVs.emplace_back(args[0]);
|
||||
loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
|
||||
for (int i = 1; i < args.size(); ++i)
|
||||
loopIVs.emplace_back(args[i]);
|
||||
} else {
|
||||
loopIVs.emplace_back(args[0]);
|
||||
}
|
||||
|
||||
auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
|
||||
// normalize
|
||||
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
|
||||
auto adjustedVarianceVal =
|
||||
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
|
||||
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
|
||||
adjustedVarianceVal);
|
||||
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
|
||||
// scale and shift
|
||||
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
|
||||
auto shiftScaleNormVal =
|
||||
rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
|
||||
rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLoweringONNXNormalizationOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
|
||||
}
|
|
@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
|||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||
}
|
||||
|
||||
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX BatchNormalization operation in test mode";
|
||||
let description = [{
|
||||
"Carries out batch normalization as described in the paper"
|
||||
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
|
||||
"there are multiple cases for the number of outputs, which we list below:"
|
||||
""
|
||||
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
|
||||
"Output case #2: Y (test mode)"
|
||||
""
|
||||
"For previous (depreciated) non-spatial cases, implementors are suggested"
|
||||
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
|
||||
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean,
|
||||
AnyTypeOf<[AnyMemRef, AnyTensor]>:$var,
|
||||
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
|
||||
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
|
||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
|
||||
}
|
||||
|
||||
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
|
||||
[NoSideEffect ]> {
|
||||
let summary = "ONNX Pad operation with constant padding value";
|
||||
|
|
|
@ -603,6 +603,55 @@ void ONNXGemmNoBiasOp::inferShapes() {
|
|||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
}
|
||||
|
||||
/// BatchNormalizationTestMode
|
||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(2).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(3).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(4).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
auto input = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto scale = getOperand(1).getType().cast<RankedTensorType>();
|
||||
auto bias = getOperand(2).getType().cast<RankedTensorType>();
|
||||
auto mean = getOperand(3).getType().cast<RankedTensorType>();
|
||||
auto variance = getOperand(4).getType().cast<RankedTensorType>();
|
||||
|
||||
// Check whether the shapes of scale, bias, mean and variance are valid.
|
||||
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
||||
// In case of N, C is assumed to be 1.
|
||||
// Shapes of scale, bias, mean and variance must be C.
|
||||
int64_t c = -1;
|
||||
if (input.getShape().size() == 1) {
|
||||
c = 1;
|
||||
} else if (input.getShape().size() > 2) {
|
||||
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
|
||||
} else {
|
||||
emitError("Wrong rank for the input.");
|
||||
}
|
||||
|
||||
if (c != -1) {
|
||||
auto s = scale.getShape();
|
||||
auto b = bias.getShape();
|
||||
auto m = mean.getShape();
|
||||
auto v = variance.getShape();
|
||||
|
||||
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
|
||||
emitError("Wrong rank for the scale.");
|
||||
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
|
||||
emitError("Wrong rank for the bias.");
|
||||
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
|
||||
emitError("Wrong rank for the mean.");
|
||||
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
|
||||
emitError("Wrong rank for the variance.");
|
||||
}
|
||||
|
||||
// The output tensor of the same shape as the input.
|
||||
getResult().setType(getOperand(0).getType());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Verify that matrix sizes are valid for multiplication and addition.
|
||||
// Take into account the dimensionality of the matrix.
|
||||
|
|
|
@ -128,6 +128,7 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Softmax" &&
|
||||
op->getName().getStringRef() != "onnx.Sqrt" &&
|
||||
op->getName().getStringRef() != "onnx.ConvNoBias" &&
|
||||
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
|
|
|
@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
|
|||
|
||||
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
|
||||
DynMemRef *tensor) {
|
||||
if (tensorDict->orderedNames.capacity() <= idx)
|
||||
if (tensorDict->orderedNames.size() <= idx)
|
||||
tensorDict->orderedNames.resize(idx + 1);
|
||||
|
||||
// The dynamic memref is essentially anonymous, since we are storing it by
|
||||
|
|
|
@ -301,6 +301,10 @@ test_to_enable = [
|
|||
"test_matmul_3d_cpu",
|
||||
"test_matmul_4d_cpu",
|
||||
|
||||
# BatchNormalization (test mode)
|
||||
"test_batchnorm_epsilon_cpu",
|
||||
"test_batchnorm_example_cpu",
|
||||
|
||||
]
|
||||
|
||||
# Extract name of all test cases.
|
||||
|
|
|
@ -1313,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 :
|
|||
|
||||
// CHECK: return [[RES]] : memref<1x5x14x29xf32>
|
||||
}
|
||||
|
||||
func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> {
|
||||
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32>
|
||||
return %0 : tensor<1x2x1x3xf32>
|
||||
|
||||
// CHECK-LABEL: test_batchnorm_testmode_Nd
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32>
|
||||
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
|
||||
// CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) {
|
||||
// CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32>
|
||||
// CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32>
|
||||
// CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32>
|
||||
// CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) {
|
||||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
|
||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<1x2x1x3xf32>
|
||||
}
|
||||
|
||||
func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> {
|
||||
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32>
|
||||
return %0 : tensor<10xf32>
|
||||
|
||||
// CHECK-LABEL: test_batchnorm_testmode_1d
|
||||
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
|
||||
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
|
||||
// CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index
|
||||
// CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32>
|
||||
// CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32>
|
||||
// CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32>
|
||||
// CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) {
|
||||
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
|
||||
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
|
||||
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
|
||||
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
|
||||
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
|
||||
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
|
||||
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
|
||||
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32>
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<10xf32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue