Lower BatchNormalization (test mode) to Krnl dialect (#70)

* Add ONNXBatchNormalizationTestModeOp and its shape inference

* Lower batchnormalization test mode

* re-use scale, bias, mean, and variance

* Add MLIR tests

* Add e2e tests

* fix typos

* Fix a bug in MLIR tests

* Change type from int to int64_t for indices

* Uncomment e2e tests due to segmentation fault

* Uncomment e2e tests due to segmentation fault

* Revise the code

* [Tian] Fix segmentation fault in e2e tests

* Re-generate onnx.md to include BatchNormalizationTestModeOp

* Reverse an unintentional change

* Fix some typos in comments

* Use convertToMemRefType from the master branch

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-21 01:45:40 +09:00 committed by GitHub
parent f1d20e368f
commit aea6479ad3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 335 additions and 2 deletions

View File

@ -332,6 +332,42 @@ ONNX BatchNormalization operation
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
#### Description:
"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
"there are multiple cases for the number of outputs, which we list below:"
""
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
"Output case #2: Y (test mode)"
""
"For previous (depreciated) non-spatial cases, implementors are suggested"
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `scale`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `mean`: memref of any type values or tensor of any type values
1. `var`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `epsilon` | `FloatAttr` | 32-bit float attribute attribute |
| `momentum` | `FloatAttr` | 32-bit float attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.BitShift (ONNXBitShiftOp)
ONNX BitShift operation

View File

@ -36,6 +36,7 @@ special_attr_defaults = dict([
special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"),
("Gemm", "ImportNodeGemm"),
("Pad", "ImportNodePad"),
#("Transpose", "ImportNodeTranspose")

View File

@ -434,6 +434,21 @@ private:
}
}
/*!
* Special handle for BatchNormalization operations.
*/
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/

View File

@ -30,7 +30,7 @@
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
}else if (OpName == "BatchNormalization") {
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
ImportNodeBatchNormalization(node, 5, 5);
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
}else if (OpName == "Cast") {

View File

@ -404,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
@ -523,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
// Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
// Entry point
patterns.insert<ONNXEntryPointLowering>(&getContext());

View File

@ -0,0 +1,140 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX Normalization Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
: ConversionPattern(
mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter & rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias
auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto epsilonAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
.epsilon()
.convertToFloat());
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
auto operand = operands[0];
auto scale = operands[1];
auto bias = operands[2];
auto mean = operands[3];
auto variance = operands[4];
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operand});
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
// Computation of BatchNormalization is done as if scale, bias, mean, and
// variance are reshaped to Cx1x1x...x1.
// rank
int64_t rank = memRefType.getRank();
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
// Create a KrnlIterateOp along C dimension.
// This will be the outer-most loop in order to re-use scale, bias,
// mean and variance.
SmallVector<Value, 1> loopCIVs;
if (rank > 1) {
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
optimizedLoops[1]);
addDimensionToPack(rewriter, loc, cPack, operand, 1);
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
Block &cIterationBlock = cIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&cIterationBlock);
for (auto arg : cIterationBlock.getArguments())
loopCIVs.emplace_back(arg);
} else {
loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
}
auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);
// Create a KrnlIterateOp along the other dimensions.
SmallVector<int64_t, 4> axes;
axes.emplace_back(0);
for (int64_t i = 2; i < rank; ++i)
axes.emplace_back(i);
std::vector<Value> packLoops, packOptimizedLoops;
for (int i = 0; i < axes.size(); ++i) {
packLoops.emplace_back(originalLoops[axes[i]]);
packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
}
KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
for (int i = 0; i < axes.size(); ++i) {
addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
Block &iterationBlock = iterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&iterationBlock);
SmallVector<Value, 4> loopIVs;
auto args = iterationBlock.getArguments();
if (args.size() > 1) {
loopIVs.emplace_back(args[0]);
loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
for (int i = 1; i < args.size(); ++i)
loopIVs.emplace_back(args[i]);
} else {
loopIVs.emplace_back(args[0]);
}
auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
// normalize
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
auto adjustedVarianceVal =
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
adjustedVarianceVal);
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
// scale and shift
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
auto shiftScaleNormVal =
rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
void populateLoweringONNXNormalizationOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
}

View File

@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX BatchNormalization operation in test mode";
let description = [{
"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
"there are multiple cases for the number of outputs, which we list below:"
""
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
"Output case #2: Y (test mode)"
""
"For previous (depreciated) non-spatial cases, implementors are suggested"
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$var,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
[NoSideEffect ]> {
let summary = "ONNX Pad operation with constant padding value";

View File

@ -603,6 +603,55 @@ void ONNXGemmNoBiasOp::inferShapes() {
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>())
return;
auto input = getOperand(0).getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (input.getShape().size() == 1) {
c = 1;
} else if (input.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
} else {
emitError("Wrong rank for the input.");
}
if (c != -1) {
auto s = scale.getShape();
auto b = bias.getShape();
auto m = mean.getShape();
auto v = variance.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale.");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias.");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean.");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance.");
}
// The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType());
}
// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.

View File

@ -128,6 +128,7 @@ public:
op->getName().getStringRef() != "onnx.Softmax" &&
op->getName().getStringRef() != "onnx.Sqrt" &&
op->getName().getStringRef() != "onnx.ConvNoBias" &&
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
op->getName().getStringRef() != "onnx.Unsqueeze")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {

View File

@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *tensor) {
if (tensorDict->orderedNames.capacity() <= idx)
if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1);
// The dynamic memref is essentially anonymous, since we are storing it by

View File

@ -301,6 +301,10 @@ test_to_enable = [
"test_matmul_3d_cpu",
"test_matmul_4d_cpu",
# BatchNormalization (test mode)
"test_batchnorm_epsilon_cpu",
"test_batchnorm_example_cpu",
]
# Extract name of all test cases.

View File

@ -1313,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 :
// CHECK: return [[RES]] : memref<1x5x14x29xf32>
}
func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> {
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32>
return %0 : tensor<1x2x1x3xf32>
// CHECK-LABEL: test_batchnorm_testmode_Nd
// CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32>
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
// CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) {
// CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32>
// CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32>
// CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32>
// CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) {
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<1x2x1x3xf32>
}
func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> {
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32>
return %0 : tensor<10xf32>
// CHECK-LABEL: test_batchnorm_testmode_1d
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
// CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index
// CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) {
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10xf32>
}