From aea6479ad362ad91078099fe55a6bfae4f2cce9d Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Fri, 21 Feb 2020 01:45:40 +0900 Subject: [PATCH] Lower BatchNormalization (test mode) to Krnl dialect (#70) * Add ONNXBatchNormalizationTestModeOp and its shape inference * Lower batchnormalization test mode * re-use scale, bias, mean, and variance * Add MLIR tests * Add e2e tests * fix typos * Fix a bug in MLIR tests * Change type from int to int64_t for indices * Uncomment e2e tests due to segmentation fault * Uncomment e2e tests due to segmentation fault * Revise the code * [Tian] Fix segmentation fault in e2e tests * Re-generate onnx.md to include BatchNormalizationTestModeOp * Reverse an unintentional change * Fix some typos in comments * Use convertToMemRefType from the master branch Co-authored-by: Gheorghe-Teodor Bercea --- doc/Dialects/onnx.md | 36 +++++ doc/gen_doc.py | 1 + src/builder/frontend_dialect_transformer.cpp | 15 ++ src/builder/op_build_table.inc | 2 +- .../onnx_to_krnl/convert_onnx_to_krnl.cpp | 2 + .../rewrite_patterns/nn/normalization.inc | 140 ++++++++++++++++++ src/dialect/onnx/onnx.td | 25 ++++ src/dialect/onnx/onnx_ops.cpp | 49 ++++++ src/pass/shape_inference_pass.cpp | 1 + src/runtime/dyn_memref.cpp | 2 +- test/backend/test.py | 4 + test/mlir/onnx/onnx_lowering.mlir | 60 ++++++++ 12 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md index ba5de60..d1da4d6 100644 --- a/doc/Dialects/onnx.md +++ b/doc/Dialects/onnx.md @@ -332,6 +332,42 @@ ONNX BatchNormalization operation 1. `saved_mean`: memref of any type values or tensor of any type values 1. `saved_var`: memref of any type values or tensor of any type values +### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp) +ONNX BatchNormalization operation in test mode + +#### Description: + + +"Carries out batch normalization as described in the paper" +"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run," +"there are multiple cases for the number of outputs, which we list below:" +"" +"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)" +"Output case #2: Y (test mode)" +"" +"For previous (depreciated) non-spatial cases, implementors are suggested" +"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op." +"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." + +#### Operands: + +1. `X`: memref of any type values or tensor of any type values +1. `scale`: memref of any type values or tensor of any type values +1. `B`: memref of any type values or tensor of any type values +1. `mean`: memref of any type values or tensor of any type values +1. `var`: memref of any type values or tensor of any type values + +#### Attributes: + +| Attribute | MLIR Type | Description | +| :-------: | :-------: | ----------- | +| `epsilon` | `FloatAttr` | 32-bit float attribute attribute | +| `momentum` | `FloatAttr` | 32-bit float attribute attribute | + +#### Results: + +1. 
`o_Y`: memref of any type values or tensor of any type values
+
 ### onnx.BitShift (ONNXBitShiftOp)
 ONNX BitShift operation
 
diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index 04d456b..df1337d 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -36,6 +36,7 @@ special_attr_defaults = dict([
 special_op_handler = dict([
     ("Conv", "ImportNodeConv"),
     ("MaxPool", "ImportNodeMaxPool"),
+    ("BatchNormalization", "ImportNodeBatchNormalization"),
     ("Gemm", "ImportNodeGemm"),
     ("Pad", "ImportNodePad"),
     #("Transpose", "ImportNodeTranspose")
diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 29b7798..9cadad8 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -434,6 +434,21 @@ private:
     }
   }
 
+  /*!
+   * Special handler for BatchNormalization operations.
+   */
+  void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
+    int nOuts = node.output().size();
+    if (nOuts == 1) {
+      // Test mode with one output.
+      ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
+                                                               nOuts);
+    } else {
+      // Training mode with four trailing optional outputs. Not handled yet.
+      ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
+    }
+  }
+
   /*!
    * Special handler for Gemm operations.
    */
diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc
index e6d97c5..c0b2ca6 100644
--- a/src/builder/op_build_table.inc
+++ b/src/builder/op_build_table.inc
@@ -30,7 +30,7 @@
     }else if (OpName == "AveragePool") {
        ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
     }else if (OpName == "BatchNormalization") {
-       ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
+       ImportNodeBatchNormalization(node, 5, 5);
     }else if (OpName == "BitShift") {
        ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
     }else if (OpName == "Cast") {
diff --git a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
index dec03cf..84d4be8 100644
--- a/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
+++ b/src/conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp
@@ -404,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
 #include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
 // Neural network
 #include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
+#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
 
 //===----------------------------------------------------------------------===//
 // EntryPoint Op lowering to Krnl Entry Point.
@@ -523,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
   populateLoweringONNXIdentityOpPattern(patterns, &getContext());
   // Neural network
   populateLoweringONNXConvOpPattern(patterns, &getContext());
+  populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
   // Entry point
   patterns.insert<ONNXEntryPointLowering>(&getContext());
 
diff --git a/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc
new file mode 100644
index 0000000..cb98b13
--- /dev/null
+++ b/src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc
@@ -0,0 +1,140 @@
+//===----- normalization.inc - Lowering Normalization Ops ----------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file lowers ONNX Normalization Operators to the Krnl dialect.
+//
+//===----------------------------------------------------------------------===//
+
+struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
+  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
+      : ConversionPattern(
+            mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
+            ctx) {}
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const final {
+    // batchnorm{epsilon}(x, scale, bias, mean, variance) =
+    //      scale * (x - mean) / sqrt(variance + epsilon) + bias
+    auto loc = op->getLoc();
+
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    auto epsilonAttr =
+        FloatAttr::get(memRefType.getElementType(),
+                       llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
+                           .epsilon()
+                           .convertToFloat());
+    auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
+
+    auto operand = operands[0];
+    auto scale = operands[1];
+    auto bias = operands[2];
+    auto mean = operands[3];
+    auto variance = operands[4];
+
+    // Insert an allocation and deallocation for the result of this operation.
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    {operand});
+
+    // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
+    // In case of N, C is assumed to be 1.
+    // Shapes of scale, bias, mean and variance must be C.
+    // Computation of BatchNormalization is done as if scale, bias, mean, and
+    // variance are reshaped to Cx1x1x...x1.
+
+    // rank
+    int64_t rank = memRefType.getRank();
+
+    std::vector<Value> originalLoops;
+    std::vector<Value> optimizedLoops;
+    Block *optimizationBlock =
+        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
+
+    // Create a KrnlIterateOp along the C dimension.
+    // This will be the outer-most loop in order to re-use scale, bias,
+    // mean and variance.
+
+    SmallVector<Value, 1> loopCIVs;
+    if (rank > 1) {
+      KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
+                                   optimizedLoops[1]);
+      addDimensionToPack(rewriter, loc, cPack, operand, 1);
+      auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
+      Block &cIterationBlock = cIterateOp.bodyRegion().front();
+      rewriter.setInsertionPointToStart(&cIterationBlock);
+      for (auto arg : cIterationBlock.getArguments())
+        loopCIVs.emplace_back(arg);
+    } else {
+      loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
+    }
+
+    auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
+    auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
+    auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
+    auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);
+
+    // Create a KrnlIterateOp along the other dimensions.
+    SmallVector<int64_t, 4> axes;
+    axes.emplace_back(0);
+    for (int64_t i = 2; i < rank; ++i)
+      axes.emplace_back(i);
+    std::vector<Value> packLoops, packOptimizedLoops;
+    for (int i = 0; i < axes.size(); ++i) {
+      packLoops.emplace_back(originalLoops[axes[i]]);
+      packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
+    }
+    KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
+    for (int i = 0; i < axes.size(); ++i) {
+      addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
+    }
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+
+    // No optimization
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+
+    Block &iterationBlock = iterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    SmallVector<Value, 4> loopIVs;
+    auto args = iterationBlock.getArguments();
+    if (args.size() > 1) {
+      loopIVs.emplace_back(args[0]);
+      loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
+      for (int i = 1; i < args.size(); ++i)
+        loopIVs.emplace_back(args[i]);
+    } else {
+      loopIVs.emplace_back(args[0]);
+    }
+
+    auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
+    // normalize
+    auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
+    auto adjustedVarianceVal =
+        rewriter.create<AddFOp>(loc, varianceVal, epsilon);
+    auto divisor = rewriter.create<KrnlSqrtOp>(
+        loc, memRefType.getElementType(), adjustedVarianceVal);
+    auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
+    // scale and shift
+    auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
+    auto shiftScaleNormVal =
+        rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
+    rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
+void populateLoweringONNXNormalizationOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
+}
diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index c18e620..43d4a10 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
 }
 
+def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
+  let summary = "ONNX BatchNormalization operation in test mode";
+  let description = [{
+    "Carries out batch normalization as described in the paper"
+    "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
+    "there are multiple cases for the number of outputs, which we list below:"
+    ""
+    "Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
+    "Output case #2: Y (test mode)"
+    ""
+    "For previous (depreciated) non-spatial cases, implementors are suggested"
+    "to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
+    "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
+  }];
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
+                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale,
+                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
+                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean,
+                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$var,
+                       DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
+                       DefaultValuedAttr<F32Attr, "0.9">:$momentum);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
+}
+
 def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
     [NoSideEffect ]> {
   let summary = "ONNX Pad operation with constant padding value";
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 4524467..f3cfeef 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -603,6 +603,55 @@ void ONNXGemmNoBiasOp::inferShapes() {
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 }
 
+/// BatchNormalizationTestMode
+void ONNXBatchNormalizationTestModeOp::inferShapes() {
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>() ||
+      !getOperand(2).getType().isa<RankedTensorType>() ||
+      !getOperand(3).getType().isa<RankedTensorType>() ||
+      !getOperand(4).getType().isa<RankedTensorType>())
+    return;
+
+  auto input = getOperand(0).getType().cast<RankedTensorType>();
+  auto scale = getOperand(1).getType().cast<RankedTensorType>();
+  auto bias = getOperand(2).getType().cast<RankedTensorType>();
+  auto mean = getOperand(3).getType().cast<RankedTensorType>();
+  auto variance = getOperand(4).getType().cast<RankedTensorType>();
+
+  // Check whether the shapes of scale, bias, mean and variance are valid.
+  // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
+  // In case of N, C is assumed to be 1.
+  // Shapes of scale, bias, mean and variance must be C.
+  int64_t c = -1;
+  if (input.getShape().size() == 1) {
+    c = 1;
+  } else if (input.getShape().size() > 2) {
+    c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
+  } else {
+    emitError("Wrong rank for the input.");
+  }
+
+  if (c != -1) {
+    auto s = scale.getShape();
+    auto b = bias.getShape();
+    auto m = mean.getShape();
+    auto v = variance.getShape();
+
+    if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
+      emitError("Wrong rank for the scale.");
+    if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
+      emitError("Wrong rank for the bias.");
+    if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
+      emitError("Wrong rank for the mean.");
+    if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
+      emitError("Wrong rank for the variance.");
+  }
+
+  // The output tensor has the same shape as the input.
+  getResult().setType(getOperand(0).getType());
+}
+
 // TODO:
 //   Verify that matrix sizes are valid for multiplication and addition.
 //   Take into account the dimensionality of the matrix.
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index d62069a..7ff0374 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -128,6 +128,7 @@ public: op->getName().getStringRef() != "onnx.Softmax" && op->getName().getStringRef() != "onnx.Sqrt" && op->getName().getStringRef() != "onnx.ConvNoBias" && + op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/src/runtime/dyn_memref.cpp b/src/runtime/dyn_memref.cpp index a20001c..454947b 100644 --- a/src/runtime/dyn_memref.cpp +++ b/src/runtime/dyn_memref.cpp @@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) { void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx, DynMemRef *tensor) { - if (tensorDict->orderedNames.capacity() <= idx) + if (tensorDict->orderedNames.size() <= idx) tensorDict->orderedNames.resize(idx + 1); // The dynamic memref is essentially anonymous, since we are storing it by diff --git a/test/backend/test.py b/test/backend/test.py index 495f1c6..abd27db 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -301,6 +301,10 @@ test_to_enable = [ "test_matmul_3d_cpu", "test_matmul_4d_cpu", + # BatchNormalization (test mode) + "test_batchnorm_epsilon_cpu", + "test_batchnorm_example_cpu", + ] # Extract name of all test cases. diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 08b2cf1..9da12ac 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1313,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : // CHECK: return [[RES]] : memref<1x5x14x29xf32> } + +func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> { + %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32> + return %0 : tensor<1x2x1x3xf32> + + // CHECK-LABEL: test_batchnorm_testmode_Nd + // CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32> + // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32 + // CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4 + // CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) { + // CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32> + // CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32> + // CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32> + // CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32> + // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) { + // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32> + // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32 + // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32 + // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32 + // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], 
[[DIVISOR]] : f32 + // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32 + // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32 + // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32> + // CHECK: } + // CHECK: } + // CHECK: return [[RES]] : memref<1x2x1x3xf32> +} + +func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> { + %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32> + return %0 : tensor<10xf32> + + // CHECK-LABEL: test_batchnorm_testmode_1d + // CHECK: [[RES:%.+]] = alloc() : memref<10xf32> + // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32 + // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index + // CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32> + // CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32> + // CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32> + // CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32> + // CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) { + // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32> + // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32 + // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32 + // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32 + // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32 + // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32 + // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32 + // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32> + // CHECK: } + // CHECK: return [[RES]] : memref<10xf32> +} +
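
---

Reviewer note (not part of the patch): the lowering above computes `y = scale * (x - mean) / sqrt(var + epsilon) + bias`, with `scale`, `B`, `mean`, and `var` of shape `C` broadcast over the remaining dimensions, and `C` taken to be 1 for rank-1 inputs. A minimal NumPy sketch of that formula is given below as a cross-check against the MLIR tests and the `test_batchnorm_*` backend tests; the function name, shapes, and example values are illustrative only.

```python
import numpy as np

def batchnorm_test_mode(x, scale, bias, mean, var, epsilon=1e-5):
    """Reference for BatchNormalizationTestMode:
    y = scale * (x - mean) / sqrt(var + epsilon) + bias.
    x is N or NxCxD1x...xDn; scale, bias, mean, var have shape (C,)."""
    if x.ndim == 1:
        shape = (1,)                            # rank-1 input: C is assumed to be 1
    else:
        shape = (1, -1) + (1,) * (x.ndim - 2)   # broadcast params along the C axis
    scale, bias, mean, var = (np.reshape(a, shape) for a in (scale, bias, mean, var))
    return scale * (x - mean) / np.sqrt(var + epsilon) + bias

# Example with the same shapes as test_batchnorm_testmode_Nd (1x2x1x3 input, C = 2).
x = np.random.randn(1, 2, 1, 3).astype(np.float32)
scale, bias = np.ones(2, np.float32), np.zeros(2, np.float32)
mean, var = x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3))
print(batchnorm_test_mode(x, scale, bias, mean, var))
```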