Merge branch 'master' of github.com:clang-ykt/ONNF into shapeinference-pad

Conflicts:
	src/pass/shape_inference_pass.cpp
This commit is contained in:
chentong 2020-02-21 09:30:40 -05:00
commit 2281cc060f
28 changed files with 2865 additions and 70 deletions

View File

@ -18,7 +18,7 @@ jobs:
git submodule update --init --recursive
# Use cached mlir installation if possible.
- restore_cache:
key: V4-LLVM-PROJECT-{{ arch }}
key: V6-LLVM-PROJECT-{{ arch }}
- run:
name: Install MLIR
command: |
@ -29,7 +29,7 @@ jobs:
source ONNF/utils/install-mlir.sh
fi
- save_cache:
key: V4-LLVM-PROJECT-{{ arch }}
key: V6-LLVM-PROJECT-{{ arch }}
paths:
- llvm-project
- run:

View File

@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopAnalysis)
find_mlir_lib(MLIRLoopToStandard)
find_mlir_lib(MLIRLoopOps)
find_mlir_lib(MLIRParser)
@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain)
find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIROptLib)
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils)
@ -117,12 +120,15 @@ set(MLIRLibsOnce
${MLIRAffineToStandard}
${MLIRAnalysis}
${MLIRDialect}
${MLIREDSC}
${MLIRExecutionEngine}
${MLIRIR}
${MLIRLLVMIR}
${MLIRLoopToStandard}
${MLIRLoopOps}
${MLIROptMain}
${MLIRLoopAnalysis}
${MLIRMlirOptMain}
${MLIROptLib}
${MLIRParser}
${MLIRPass}
${MLIRStandardOps}
@ -226,4 +232,4 @@ function(add_onnf_dialect_doc dialect dialect_tablegen_file)
add_dependencies(onnf-doc ${dialect}DocGen)
endfunction()
add_custom_target(onnf-doc)
add_custom_target(onnf-doc)

View File

@ -332,6 +332,42 @@ ONNX BatchNormalization operation
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values
### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode
#### Description:
"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
"there are multiple cases for the number of outputs, which we list below:"
""
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
"Output case #2: Y (test mode)"
""
"For previous (depreciated) non-spatial cases, implementors are suggested"
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
#### Operands:
1. `X`: memref of any type values or tensor of any type values
1. `scale`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `mean`: memref of any type values or tensor of any type values
1. `var`: memref of any type values or tensor of any type values
#### Attributes:
| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `epsilon` | `FloatAttr` | 32-bit float attribute attribute |
| `momentum` | `FloatAttr` | 32-bit float attribute attribute |
#### Results:
1. `o_Y`: memref of any type values or tensor of any type values
### onnx.BitShift (ONNXBitShiftOp)
ONNX BitShift operation

View File

@ -36,6 +36,7 @@ special_attr_defaults = dict([
special_op_handler = dict([
("Conv", "ImportNodeConv"),
("MaxPool", "ImportNodeMaxPool"),
("BatchNormalization", "ImportNodeBatchNormalization"),
("Gemm", "ImportNodeGemm"),
("Pad", "ImportNodePad"),
#("Transpose", "ImportNodeTranspose")

View File

@ -434,6 +434,21 @@ private:
}
}
/*!
* Special handle for BatchNormalization operations.
*/
void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
int nOuts = node.output().size();
if (nOuts == 1) {
// Test mode with one output.
ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
nOuts);
} else {
// Training mode with four trailing optional outputs. Not handled yet.
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
}
}
/*!
* Special handle for Gemm operations.
*/

View File

@ -30,7 +30,7 @@
}else if (OpName == "AveragePool") {
ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
}else if (OpName == "BatchNormalization") {
ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
ImportNodeBatchNormalization(node, 5, 5);
}else if (OpName == "BitShift") {
ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
}else if (OpName == "Cast") {

View File

@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
return true;
}
/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return MemRefType::get(type.getShape(), type.getElementType());
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
@ -396,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
@ -425,9 +434,13 @@ public:
struct TensorTypeConverter : public TypeConverter {
using TypeConverter::TypeConverter;
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type));
TensorTypeConverter() {
addConversion(convertType);
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}
@ -511,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
// Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
// Entry point
patterns.insert<ONNXEntryPointLowering>(&getContext());

View File

@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the
@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
auto numArgs = op->getNumOperands();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);

View File

@ -8,33 +8,34 @@
//
//===----------------------------------------------------------------------===//
template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern {
ONNXGemmOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
: ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
auto has_bias = (operands.size() == 3);
Value A, B, C;
A = operands[0];
B = operands[1];
C = operands[2];
if (has_bias)
C = operands[2];
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
bool isTransA = (llvm::dyn_cast<GemmOp>(op).transA() != 0);
bool isTransB = (llvm::dyn_cast<GemmOp>(op).transB() != 0);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
@ -118,14 +119,16 @@ struct ONNXGemmOpLowering : public ConversionPattern {
// GemmOp supports unidirectional broadcasting from C to A*B.
// Hence, it must be enough to get broadcasting information for C only.
std::map<int, Value> broadcastedDimInfo;
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
if (has_bias) {
auto shape = C.getType().cast<MemRefType>().getShape();
for (int i = 0; i < shape.size(); ++i) {
if (shape[i] < 0) {
auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
auto one = rewriter.create<ConstantIndexOp>(loc, 1);
auto isBroadcasted =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
}
}
}
@ -157,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern {
auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
// Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
auto loopCIVs = getLoopIVsForBroadcasting(
loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
if (has_bias) {
auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
broadcastedDimInfo);
auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
} else {
rewriter.create<StoreOp>(loc, alphaAB, alloc, loopMNIVs);
}
// Insert instructions to do matrix multiplication: A*B
Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
@ -205,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
void populateLoweringONNXGemmOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXGemmOpLowering>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}

View File

@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
Value A = operands[0];
@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// - Both arguments are 1-D
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto elementType = memRefType.getElementType();
auto memRefShape = memRefType.getShape();

View File

@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
auto loc = op->getLoc();
auto memRefInType = operands[0].getType().cast<MemRefType>();
auto memRefInShape = memRefInType.getShape();
auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
auto memRefOutType = convertToMemRefType(*op->result_type_begin());
int64_t inRank = memRefInType.getRank();
int64_t outRank = tensorOutType.getRank();
int64_t outRank = memRefOutType.getRank();
// Get attributes
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information
auto memRefOutType = convertTensorToMemRef(tensorOutType);
auto memRefOutShape = memRefOutType.getShape();
auto elementOutType = memRefOutType.getElementType();
std::map<int64_t, int64_t> outInDimMap =

View File

@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in
// exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank();
auto memRefType = convertToMemRefType(*op->result_type_begin());
int64_t rank = memRefType.getRank();
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1);
@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType();
Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);

View File

@ -0,0 +1,140 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX Normalization Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
: ConversionPattern(
mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter & rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias
auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto epsilonAttr =
FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
.epsilon()
.convertToFloat());
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
auto operand = operands[0];
auto scale = operands[1];
auto bias = operands[2];
auto mean = operands[3];
auto variance = operands[4];
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operand});
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
// Computation of BatchNormalization is done as if scale, bias, mean, and
// variance are reshaped to Cx1x1x...x1.
// rank
int64_t rank = memRefType.getRank();
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
// Create a KrnlIterateOp along C dimension.
// This will be the outer-most loop in order to re-use scale, bias,
// mean and variance.
SmallVector<Value, 1> loopCIVs;
if (rank > 1) {
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
optimizedLoops[1]);
addDimensionToPack(rewriter, loc, cPack, operand, 1);
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
Block &cIterationBlock = cIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&cIterationBlock);
for (auto arg : cIterationBlock.getArguments())
loopCIVs.emplace_back(arg);
} else {
loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
}
auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);
// Create a KrnlIterateOp along the other dimensions.
SmallVector<int64_t, 4> axes;
axes.emplace_back(0);
for (int64_t i = 2; i < rank; ++i)
axes.emplace_back(i);
std::vector<Value> packLoops, packOptimizedLoops;
for (int i = 0; i < axes.size(); ++i) {
packLoops.emplace_back(originalLoops[axes[i]]);
packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
}
KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
for (int i = 0; i < axes.size(); ++i) {
addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
}
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
// No optimization
rewriter.setInsertionPointToEnd(optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
Block &iterationBlock = iterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&iterationBlock);
SmallVector<Value, 4> loopIVs;
auto args = iterationBlock.getArguments();
if (args.size() > 1) {
loopIVs.emplace_back(args[0]);
loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
for (int i = 1; i < args.size(); ++i)
loopIVs.emplace_back(args[i]);
} else {
loopIVs.emplace_back(args[0]);
}
auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
// normalize
auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
auto adjustedVarianceVal =
rewriter.create<AddFOp>(loc, varianceVal, epsilon);
auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
adjustedVarianceVal);
auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
// scale and shift
auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
auto shiftScaleNormVal =
rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
void populateLoweringONNXNormalizationOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
}

View File

@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto memRefShape = memRefType.getShape();
Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);

View File

@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
int outRank = tensorType.getRank();
auto memRefType = convertToMemRefType(*op->result_type_begin());
int outRank = memRefType.getRank();
// Assume that `axes` has been validated by shape inference.
// So, here we just get it.
@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
}
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value alloc;
// Compute size in bytes.

View File

@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX BatchNormalization operation in test mode";
let description = [{
"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
"there are multiple cases for the number of outputs, which we list below:"
""
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
"Output case #2: Y (test mode)"
""
"For previous (depreciated) non-spatial cases, implementors are suggested"
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
}];
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean,
AnyTypeOf<[AnyMemRef, AnyTensor]>:$var,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
DefaultValuedAttr<F32Attr, "0.9">:$momentum);
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}
def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
[NoSideEffect ]> {
let summary = "ONNX Pad operation with constant padding value";

View File

@ -586,12 +586,72 @@ void ONNXGemmNoBiasOp::inferShapes() {
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>())
return;
auto input = getOperand(0).getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (input.getShape().size() == 1) {
c = 1;
} else if (input.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
} else {
emitError("Wrong rank for the input.");
}
if (c != -1) {
auto s = scale.getShape();
auto b = bias.getShape();
auto m = mean.getShape();
auto v = variance.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale.");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias.");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean.");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance.");
}
// The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType());
}
// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.

View File

@ -24,6 +24,7 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
}
int main(int argc, char *argv[]) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
mlir::registerDialect<mlir::ONNXOpsDialect>();
mlir::registerDialect<mlir::KrnlOpsDialect>();

File diff suppressed because it is too large Load Diff

View File

@ -130,6 +130,7 @@ public:
op->getName().getStringRef() != "onnx.ConvNoBias" &&
op->getName().getStringRef() != "onnx.PadConstantPad" &&
op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
op->getName().getStringRef() != "onnx.Unsqueeze")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {

View File

@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {
void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
DynMemRef *tensor) {
if (tensorDict->orderedNames.capacity() <= idx)
if (tensorDict->orderedNames.size() <= idx)
tensorDict->orderedNames.resize(idx + 1);
// The dynamic memref is essentially anonymous, since we are storing it by

View File

@ -9,6 +9,7 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/InitAllDialects.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
llvm::cl::init(true));
int main(int argc, char **argv) {
mlir::registerDialect<mlir::AffineOpsDialect>();
mlir::registerDialect<mlir::LLVM::LLVMDialect>();
mlir::registerDialect<mlir::loop::LoopOpsDialect>();
mlir::registerDialect<mlir::StandardOpsDialect>();
llvm::InitLLVM y(argc, argv);
mlir::registerDialect<mlir::ONNXOpsDialect>();

View File

@ -224,7 +224,10 @@ public:
// Based on the static entry point type signature, unpack dynamic memory
// refs to corresponding static memory refs.
auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
auto wrappedStaticEntryPointFuncName =
"_mlir_ciface_" + staticEntryPointFuncName.lower();
auto *staticEntryPointFunc =
module.lookupSymbol(wrappedStaticEntryPointFuncName);
assert(staticEntryPointFunc &&
isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
"entry point func must exist and be an llvm func op");
@ -268,7 +271,8 @@ public:
// Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
staticInputs);
// Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false,
/*emitCWrapper=*/true);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,

View File

@ -98,7 +98,7 @@ test_to_enable = [
"test_gemm_alpha_cpu",
"test_gemm_beta_cpu",
"test_gemm_default_matrix_bias_cpu",
# "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
"test_gemm_default_no_bias_cpu",
"test_gemm_default_scalar_bias_cpu",
"test_gemm_default_single_elem_vector_bias_cpu",
"test_gemm_default_vector_bias_cpu",
@ -301,6 +301,10 @@ test_to_enable = [
"test_matmul_3d_cpu",
"test_matmul_4d_cpu",
# BatchNormalization (test mode)
"test_batchnorm_epsilon_cpu",
"test_batchnorm_example_cpu",
]
# Extract name of all test cases.

View File

@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
// CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
// CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1

View File

@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: }
// CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
// CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
// CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
// CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
%0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_gemm_no_bias
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
// CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
// CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
// CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
// CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
// CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
// CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
// CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
// CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10x10xf32>
// CHECK: }
}
@ -1283,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 :
// CHECK: return [[RES]] : memref<1x5x14x29xf32>
}
func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> {
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32>
return %0 : tensor<1x2x1x3xf32>
// CHECK-LABEL: test_batchnorm_testmode_Nd
// CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32>
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
// CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4
// CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) {
// CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32>
// CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32>
// CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32>
// CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) {
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
// CHECK: }
// CHECK: }
// CHECK: return [[RES]] : memref<1x2x1x3xf32>
}
func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> {
%0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32>
return %0 : tensor<10xf32>
// CHECK-LABEL: test_batchnorm_testmode_1d
// CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
// CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
// CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index
// CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) {
// CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
// CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
// CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
// CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
// CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
// CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
// CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
// CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32>
// CHECK: }
// CHECK: return [[RES]] : memref<10xf32>
}