Merge branch 'master' of github.com:clang-ykt/ONNF into shapeinference-pad
Conflicts: src/pass/shape_inference_pass.cpp
commit 2281cc060f
@@ -18,7 +18,7 @@ jobs:
          git submodule update --init --recursive
      # Use cached mlir installation if possible.
      - restore_cache:
-         key: V4-LLVM-PROJECT-{{ arch }}
+         key: V6-LLVM-PROJECT-{{ arch }}
      - run:
          name: Install MLIR
          command: |
@@ -29,7 +29,7 @@ jobs:
            source ONNF/utils/install-mlir.sh
          fi
      - save_cache:
-         key: V4-LLVM-PROJECT-{{ arch }}
+         key: V6-LLVM-PROJECT-{{ arch }}
          paths:
            - llvm-project
      - run:
MLIR.cmake
@@ -58,9 +58,11 @@ find_mlir_lib(MLIRAffineOps)
find_mlir_lib(MLIRAffineToStandard)
find_mlir_lib(MLIRAnalysis)
find_mlir_lib(MLIRDialect)
find_mlir_lib(MLIREDSC)
find_mlir_lib(MLIRExecutionEngine)
find_mlir_lib(MLIRIR)
find_mlir_lib(MLIRLLVMIR)
find_mlir_lib(MLIRLoopAnalysis)
find_mlir_lib(MLIRLoopToStandard)
find_mlir_lib(MLIRLoopOps)
find_mlir_lib(MLIRParser)
@@ -71,7 +73,8 @@ find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransforms)
find_mlir_lib(MLIRTransformUtils)
find_mlir_lib(MLIRSupport)
find_mlir_lib(MLIROptMain)
find_mlir_lib(MLIRMlirOptMain)
find_mlir_lib(MLIROptLib)
find_mlir_lib(MLIRTargetLLVMIRModuleTranslation)
find_mlir_lib(MLIRTargetLLVMIR)
find_mlir_lib(MLIRTransformUtils)
@@ -117,12 +120,15 @@ set(MLIRLibsOnce
        ${MLIRAffineToStandard}
        ${MLIRAnalysis}
        ${MLIRDialect}
        ${MLIREDSC}
        ${MLIRExecutionEngine}
        ${MLIRIR}
        ${MLIRLLVMIR}
        ${MLIRLoopToStandard}
        ${MLIRLoopOps}
        ${MLIROptMain}
        ${MLIRLoopAnalysis}
        ${MLIRMlirOptMain}
        ${MLIROptLib}
        ${MLIRParser}
        ${MLIRPass}
        ${MLIRStandardOps}
@@ -332,6 +332,42 @@ ONNX BatchNormalization operation
1. `saved_mean`: memref of any type values or tensor of any type values
1. `saved_var`: memref of any type values or tensor of any type values

### onnx.BatchNormalizationTestMode (ONNXBatchNormalizationTestModeOp)
ONNX BatchNormalization operation in test mode

#### Description:

"Carries out batch normalization as described in the paper"
"https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
"there are multiple cases for the number of outputs, which we list below:"
""
"Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
"Output case #2: Y (test mode)"
""
"For previous (depreciated) non-spatial cases, implementors are suggested"
"to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
"This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."

#### Operands:

1. `X`: memref of any type values or tensor of any type values
1. `scale`: memref of any type values or tensor of any type values
1. `B`: memref of any type values or tensor of any type values
1. `mean`: memref of any type values or tensor of any type values
1. `var`: memref of any type values or tensor of any type values

#### Attributes:

| Attribute | MLIR Type | Description |
| :-------: | :-------: | ----------- |
| `epsilon` | `FloatAttr` | 32-bit float attribute attribute |
| `momentum` | `FloatAttr` | 32-bit float attribute attribute |

#### Results:

1. `o_Y`: memref of any type values or tensor of any type values

### onnx.BitShift (ONNXBitShiftOp)
ONNX BitShift operation
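For reference, the test-mode computation documented in the onnx.BatchNormalizationTestMode section added above is Y = scale * (X - mean) / sqrt(var + epsilon) + B, applied per channel C. The NumPy sketch below is illustrative only and is not part of this commit; the helper name `batchnorm_test_mode` is made up.

```python
import numpy as np

# Illustrative sketch of BatchNormalization in test mode (assumed semantics):
# Y = scale * (X - mean) / sqrt(var + epsilon) + B, with scale, B, mean and var
# of shape (C,) broadcast over an input of shape N x C x D1 x ... x Dn (or just N).
def batchnorm_test_mode(x, scale, b, mean, var, epsilon=1e-5):
    if x.ndim > 1:
        shape = (1, x.shape[1]) + (1,) * (x.ndim - 2)  # C x 1 x ... x 1 broadcast shape
    else:
        shape = (1,)                                   # rank-1 input: C is taken to be 1
    scale, b, mean, var = (v.reshape(shape) for v in (scale, b, mean, var))
    return scale * (x - mean) / np.sqrt(var + epsilon) + b

x = np.random.randn(1, 2, 1, 3).astype(np.float32)
scale = np.random.randn(2).astype(np.float32)
b = np.random.randn(2).astype(np.float32)
mean = np.random.randn(2).astype(np.float32)
var = np.abs(np.random.randn(2)).astype(np.float32)  # variance must be non-negative
y = batchnorm_test_mode(x, scale, b, mean, var)
assert y.shape == x.shape  # the output keeps the input shape
```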
@@ -36,6 +36,7 @@ special_attr_defaults = dict([
special_op_handler = dict([
    ("Conv", "ImportNodeConv"),
    ("MaxPool", "ImportNodeMaxPool"),
+   ("BatchNormalization", "ImportNodeBatchNormalization"),
    ("Gemm", "ImportNodeGemm"),
    ("Pad", "ImportNodePad"),
    #("Transpose", "ImportNodeTranspose")
@@ -434,6 +434,21 @@ private:
    }
  }

  /*!
   * Special handle for BatchNormalization operations.
   */
  void ImportNodeBatchNormalization(onnx::NodeProto node, int nIn, int nOut) {
    int nOuts = node.output().size();
    if (nOuts == 1) {
      // Test mode with one output.
      ImportNodeOneOut<mlir::ONNXBatchNormalizationTestModeOp>(node, nIn,
                                                               nOuts);
    } else {
      // Training mode with four trailing optional outputs. Not handled yet.
      ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, nIn, nOuts);
    }
  }

  /*!
   * Special handle for Gemm operations.
   */
@@ -30,7 +30,7 @@
    }else if (OpName == "AveragePool") {
      ImportNodeOneOut<mlir::ONNXAveragePoolOp>(node, 1, 1);
    }else if (OpName == "BatchNormalization") {
-     ImportNodeMultipleOuts<mlir::ONNXBatchNormalizationOp>(node, 5, 5);
+     ImportNodeBatchNormalization(node, 5, 5);
    }else if (OpName == "BitShift") {
      ImportNodeOneOut<mlir::ONNXBitShiftOp>(node, 2, 1);
    }else if (OpName == "Cast") {
@@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
  return true;
}

-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
-  return MemRefType::get(type.getShape(), type.getElementType());
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+static MemRefType convertToMemRefType(Type type) {
+  MemRefType memRefType;
+  auto tensorType = type.dyn_cast<TensorType>();
+  if (tensorType) {
+    assert(tensorType.hasRank() && "expected only ranked shapes");
+    memRefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  } else {
+    memRefType = type.dyn_cast<MemRefType>();
+  }
+  return memRefType;
}

/// Insert an allocation and deallocation for the given MemRefType.
@@ -396,6 +404,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
// Neural network
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
+#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"

//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
@@ -425,9 +434,13 @@ public:
struct TensorTypeConverter : public TypeConverter {
  using TypeConverter::TypeConverter;

-  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
-    if (auto tensor_type = t.dyn_cast<TensorType>()) {
-      results.push_back(convertTensorToMemRef(tensor_type));
+  TensorTypeConverter() {
+    addConversion(convertType);
+  }
+
+  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
+    if (auto type = convertToMemRefType(t)) {
+      results.push_back(type);
      return success();
    }
@@ -511,6 +524,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
  populateLoweringONNXIdentityOpPattern(patterns, &getContext());
  // Neural network
  populateLoweringONNXConvOpPattern(patterns, &getContext());
+ populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
  // Entry point
  patterns.insert<ONNXEntryPointLowering>(&getContext());
@@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result of
    // the same type. This should have been verified by the verifier.
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
@@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
@@ -8,33 +8,34 @@
//
//===----------------------------------------------------------------------===//

+template <typename GemmOp>
struct ONNXGemmOpLowering : public ConversionPattern {
  ONNXGemmOpLowering(MLIRContext *ctx)
-      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
+      : ConversionPattern(GemmOp::getOperationName(), 1, ctx) {}

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto has_bias = (operands.size() == 3);

    Value A, B, C;
    A = operands[0];
    B = operands[1];
    if (has_bias)
      C = operands[2];

-   auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
-       llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
-   auto betaAttr = FloatAttr::get(tensorType.getElementType(),
-       llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
+
+   auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
+       llvm::dyn_cast<GemmOp>(op).alpha().convertToFloat());
+   auto betaAttr = FloatAttr::get(memRefType.getElementType(),
+       llvm::dyn_cast<GemmOp>(op).beta().convertToFloat());
    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);

-   bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
-   bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
-
-   // Result type
-   auto memRefType = convertTensorToMemRef(tensorType);
+   bool isTransA = (llvm::dyn_cast<GemmOp>(op).transA() != 0);
+   bool isTransB = (llvm::dyn_cast<GemmOp>(op).transB() != 0);

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
@@ -118,6 +119,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
    // GemmOp supports unidirectional broadcasting from C to A*B.
    // Hence, it must be enough to get broadcasting information for C only.
    std::map<int, Value> broadcastedDimInfo;
+   if (has_bias) {
      auto shape = C.getType().cast<MemRefType>().getShape();
      for (int i = 0; i < shape.size(); ++i) {
        if (shape[i] < 0) {
@@ -128,6 +130,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
          broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
        }
      }
+   }

    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

@@ -157,14 +160,18 @@ struct ONNXGemmOpLowering : public ConversionPattern {
    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);

    // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
-   auto loopCIVs = getLoopIVsForBroadcasting(
-       loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
-   auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
+   if (has_bias) {
+     auto loopCIVs = getLoopIVsForBroadcasting(loc, rewriter, loopMNIVs, C,
+                                               broadcastedDimInfo);
+     auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
      auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
      auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
      rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
+   } else {
+     rewriter.create<StoreOp>(loc, alphaAB, alloc, loopMNIVs);
+   }

    // Insert instructions to do matrix multiplication: A*B
    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
@@ -205,5 +212,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {

void populateLoweringONNXGemmOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ONNXGemmOpLowering>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
+  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
}
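As context for the templated Gemm lowering above (which now also covers ONNXGemmNoBiasOp via the has_bias path), a rough NumPy sketch of the computation being lowered; this is an assumed reading for illustration, not code from the repository:

```python
import numpy as np

def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=0, trans_b=0):
    # Y = alpha * A' @ B' (+ beta * C), where A'/B' are optionally transposed and
    # C, when present, is unidirectionally broadcast to the shape of A' @ B'.
    ap = a.T if trans_a else a
    bp = b.T if trans_b else b
    y = alpha * (ap @ bp)
    if c is not None:       # mirrors the has_bias branch in the lowering
        y = y + beta * c    # NumPy broadcasting stands in for the Krnl loop nest
    return y

a = np.random.randn(5, 10).astype(np.float32)
b = np.random.randn(5, 10).astype(np.float32)
c = np.random.randn(10).astype(np.float32)
print(gemm(a, b, c, alpha=1.0, beta=5.0, trans_a=1).shape)  # (10, 10), with bias
print(gemm(a, b, alpha=1.0, beta=5.0, trans_a=1).shape)     # (10, 10), no bias
```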
@@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    Value A = operands[0];
@@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
    // - Both arguments are 1-D

    // Result type
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto elementType = memRefType.getElementType();
    auto memRefShape = memRefType.getShape();
@@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
    auto loc = op->getLoc();
    auto memRefInType = operands[0].getType().cast<MemRefType>();
    auto memRefInShape = memRefInType.getShape();
-   auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
+   auto memRefOutType = convertToMemRefType(*op->result_type_begin());
    int64_t inRank = memRefInType.getRank();
-   int64_t outRank = tensorOutType.getRank();
+   int64_t outRank = memRefOutType.getRank();

    // Get attributes
    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
@@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
    bool isKeepdims = (keepdims == 1) ? true : false;

    // Get type information
-   auto memRefOutType = convertTensorToMemRef(tensorOutType);
    auto memRefOutShape = memRefOutType.getShape();
    auto elementOutType = memRefOutType.getElementType();
    std::map<int64_t, int64_t> outInDimMap =
@@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
    // let exp_x = exp(x - max_x) in
    // let sum = sum(exp_x) in
    // exp_x / sum
-   auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
-   int64_t rank = tensorType.getRank();
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
+   int64_t rank = memRefType.getRank();
    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= -rank && axis <= rank - 1);
@@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();

    Value alloc;
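The comment in the Softmax lowering above spells out the numerically stable recipe (subtract the max, exponentiate, normalize by the sum). A quick NumPy sketch of that recipe, for illustration only:

```python
import numpy as np

def softmax(x, axis=-1):
    max_x = np.max(x, axis=axis, keepdims=True)  # subtract max_x so exp() cannot overflow
    exp_x = np.exp(x - max_x)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

print(softmax(np.array([[1.0, 2.0, 3.0]])).round(3))  # [[0.09  0.245 0.665]]
```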
@@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
@@ -0,0 +1,140 @@
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers ONNX Normalization Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
            ctx) {}
  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // batchnorm{epsilon}(x, scale, bias, mean, variance) =
    //      scale * (x - mean) / sqrt(variance + epsilon) + bias
    auto loc = op->getLoc();

    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto epsilonAttr =
        FloatAttr::get(memRefType.getElementType(),
                       llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
                           .epsilon()
                           .convertToFloat());
    auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);

    auto operand = operands[0];
    auto scale = operands[1];
    auto bias = operands[2];
    auto mean = operands[3];
    auto variance = operands[4];

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operand});

    // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
    // In case of N, C is assumed to be 1.
    // Shapes of scale, bias, mean and variance must be C.
    // Computation of BatchNormalization is done as if scale, bias, mean, and
    // variance are reshaped to Cx1x1x...x1.

    // rank
    int64_t rank = memRefType.getRank();

    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);

    // Create a KrnlIterateOp along C dimension.
    // This will be the outer-most loop in order to re-use scale, bias,
    // mean and variance.

    SmallVector<Value, 1> loopCIVs;
    if (rank > 1) {
      KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
                                   optimizedLoops[1]);
      addDimensionToPack(rewriter, loc, cPack, operand, 1);
      auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
      Block &cIterationBlock = cIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&cIterationBlock);
      for (auto arg : cIterationBlock.getArguments())
        loopCIVs.emplace_back(arg);
    } else {
      loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
    }

    auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);
    auto biasVal = rewriter.create<LoadOp>(loc, bias, loopCIVs);
    auto meanVal = rewriter.create<LoadOp>(loc, mean, loopCIVs);
    auto varianceVal = rewriter.create<LoadOp>(loc, variance, loopCIVs);

    // Create a KrnlIterateOp along the other dimensions.
    SmallVector<int64_t, 4> axes;
    axes.emplace_back(0);
    for (int64_t i = 2; i < rank; ++i)
      axes.emplace_back(i);
    std::vector<Value> packLoops, packOptimizedLoops;
    for (int i = 0; i < axes.size(); ++i) {
      packLoops.emplace_back(originalLoops[axes[i]]);
      packOptimizedLoops.emplace_back(optimizedLoops[axes[i]]);
    }
    KrnlIterateOperandPack pack(rewriter, packLoops, packOptimizedLoops);
    for (int i = 0; i < axes.size(); ++i) {
      addDimensionToPack(rewriter, loc, pack, operand, axes[i]);
    }
    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);

    // No optimization
    rewriter.setInsertionPointToEnd(optimizationBlock);
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    Block &iterationBlock = iterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&iterationBlock);

    SmallVector<Value, 4> loopIVs;
    auto args = iterationBlock.getArguments();
    if (args.size() > 1) {
      loopIVs.emplace_back(args[0]);
      loopIVs.emplace_back(loopCIVs[0]); // Insert C back.
      for (int i = 1; i < args.size(); ++i)
        loopIVs.emplace_back(args[i]);
    } else {
      loopIVs.emplace_back(args[0]);
    }

    auto xVal = rewriter.create<LoadOp>(loc, operand, loopIVs);
    // normalize
    auto dividend = rewriter.create<SubFOp>(loc, xVal, meanVal);
    auto adjustedVarianceVal =
        rewriter.create<AddFOp>(loc, varianceVal, epsilon);
    auto divisor = rewriter.create<KrnlSqrtOp>(loc, memRefType.getElementType(),
                                               adjustedVarianceVal);
    auto normVal = rewriter.create<DivFOp>(loc, dividend, divisor);
    // scale and shift
    auto scaleNormVal = rewriter.create<MulFOp>(loc, scaleVal, normVal);
    auto shiftScaleNormVal =
        rewriter.create<AddFOp>(loc, scaleNormVal, biasVal);
    rewriter.create<StoreOp>(loc, shiftScaleNormVal, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXNormalizationOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXBatchNormalizationTestModeOpLowering>(ctx);
}
@@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto memRefShape = memRefType.getShape();
    Value alloc;
@@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
@@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();
-   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-   int outRank = tensorType.getRank();
+   auto memRefType = convertToMemRefType(*op->result_type_begin());
+   int outRank = memRefType.getRank();

    // Assume that `axes` has been validated by shape inference.
    // So, here we just get it.
@@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
    }

    // Insert an allocation and deallocation for the result of this operation.
-   auto memRefType = convertTensorToMemRef(tensorType);
    Value alloc;

    // Compute size in bytes.
@@ -146,6 +146,31 @@ def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}

def ONNXBatchNormalizationTestModeOp: ONNX_Op<"BatchNormalizationTestMode",
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "ONNX BatchNormalization operation in test mode";
  let description = [{
    "Carries out batch normalization as described in the paper"
    "https://arxiv.org/abs/1502.03167. Depending on the mode it is being run,"
    "there are multiple cases for the number of outputs, which we list below:"
    ""
    "Output case #1: Y, mean, var, saved_mean, saved_var (training mode)"
    "Output case #2: Y (test mode)"
    ""
    "For previous (depreciated) non-spatial cases, implementors are suggested"
    "to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op."
    "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted."
  }];
  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X,
                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale,
                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean,
                       AnyTypeOf<[AnyMemRef, AnyTensor]>:$var,
                       DefaultValuedAttr<F32Attr, "1e-05">:$epsilon,
                       DefaultValuedAttr<F32Attr, "0.9">:$momentum);
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
}

def ONNXPadConstantValueOp : ONNX_Op<"PadConstantValue",
    [NoSideEffect ]> {
  let summary = "ONNX Pad operation with constant padding value";
@@ -586,12 +586,72 @@ void ONNXGemmNoBiasOp::inferShapes() {
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();

  int64_t M, N, K_A, K_B;
  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
    emitError("Tensor shapes mismatched.");
  }

  SmallVector<int64_t, 2> dims;
-  dims.emplace_back(lhsTy.getShape()[0]);
-  dims.emplace_back(rhsTy.getShape()[1]);
+  dims.emplace_back(M);
+  dims.emplace_back(N);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>() ||
      !getOperand(2).getType().isa<RankedTensorType>() ||
      !getOperand(3).getType().isa<RankedTensorType>() ||
      !getOperand(4).getType().isa<RankedTensorType>())
    return;

  auto input = getOperand(0).getType().cast<RankedTensorType>();
  auto scale = getOperand(1).getType().cast<RankedTensorType>();
  auto bias = getOperand(2).getType().cast<RankedTensorType>();
  auto mean = getOperand(3).getType().cast<RankedTensorType>();
  auto variance = getOperand(4).getType().cast<RankedTensorType>();

  // Check whether the shapes of scale, bias, mean and variance are valid.
  // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
  // In case of N, C is assumed to be 1.
  // Shapes of scale, bias, mean and variance must be C.
  int64_t c = -1;
  if (input.getShape().size() == 1) {
    c = 1;
  } else if (input.getShape().size() > 2) {
    c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
  } else {
    emitError("Wrong rank for the input.");
  }

  if (c != -1) {
    auto s = scale.getShape();
    auto b = bias.getShape();
    auto m = mean.getShape();
    auto v = variance.getShape();

    if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
      emitError("Wrong rank for the scale.");
    if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
      emitError("Wrong rank for the bias.");
    if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
      emitError("Wrong rank for the mean.");
    if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
      emitError("Wrong rank for the variance.");
  }

  // The output tensor of the same shape as the input.
  getResult().setType(getOperand(0).getType());
}

// TODO:
//   Verify that matrix sizes are valid for multiplication and addition.
//   Take into account the dimensionality of the matrix.
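The shape checks added in ONNXBatchNormalizationTestModeOp::inferShapes above boil down to: the input is N x C x D1 x ... x Dn (or just N, in which case C is taken to be 1), scale, bias, mean and variance must each be 1-D of size C, and the result keeps the input type. A small Python sketch of that rule follows; the helper name is hypothetical and -1 stands for a dynamic dimension, as in the MLIR shapes.

```python
def check_batchnorm_testmode_shapes(x_shape, scale, bias, mean, var):
    # x_shape: N x C x D1 x ... x Dn, or (N,) in which case C is assumed to be 1.
    if len(x_shape) == 1:
        c = 1
    elif len(x_shape) > 2:
        c = x_shape[1]
    else:
        raise ValueError("wrong rank for the input")
    if c != -1:  # only check when the channel dimension is static
        for name, s in (("scale", scale), ("bias", bias), ("mean", mean), ("var", var)):
            if len(s) != 1 or (s[0] != -1 and s[0] != c):
                raise ValueError(f"wrong shape for {name}: expected ({c},)")
    return x_shape  # the output has the same shape (and type) as the input

print(check_batchnorm_testmode_shapes((1, 2, 1, 3), (2,), (2,), (2,), (2,)))  # (1, 2, 1, 3)
```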
@@ -24,6 +24,7 @@
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
+#include "mlir/InitAllDialects.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
@@ -69,6 +70,10 @@ void EmitLLVMBitCode(const mlir::OwningModuleRef &module) {
}

int main(int argc, char *argv[]) {
+  mlir::registerDialect<mlir::AffineOpsDialect>();
+  mlir::registerDialect<mlir::LLVM::LLVMDialect>();
+  mlir::registerDialect<mlir::loop::LoopOpsDialect>();
+  mlir::registerDialect<mlir::StandardOpsDialect>();
  mlir::registerDialect<mlir::ONNXOpsDialect>();
  mlir::registerDialect<mlir::KrnlOpsDialect>();
File diff suppressed because it is too large
@@ -130,6 +130,7 @@ public:
        op->getName().getStringRef() != "onnx.ConvNoBias" &&
        op->getName().getStringRef() != "onnx.PadConstantPad" &&
        op->getName().getStringRef() != "onnx.PadConstantValuePad" &&
+       op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" &&
        op->getName().getStringRef() != "onnx.Unsqueeze")
      return false;
    return llvm::any_of(op->getResultTypes(), [](Type result_type) {
@@ -35,7 +35,7 @@ DynMemRef *getDynMemRef(OrderedDynMemRefDict *tensorDict, int idx) {

void setDynMemRef(OrderedDynMemRefDict *tensorDict, int idx,
                  DynMemRef *tensor) {
-  if (tensorDict->orderedNames.capacity() <= idx)
+  if (tensorDict->orderedNames.size() <= idx)
    tensorDict->orderedNames.resize(idx + 1);

  // The dynamic memref is essentially anonymous, since we are storing it by
@@ -9,6 +9,7 @@
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/InitLLVM.h>
#include <llvm/Support/ToolOutputFile.h>
+#include <mlir/InitAllDialects.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
@@ -46,6 +47,10 @@ static llvm::cl::opt<bool> verify_passes(
    llvm::cl::init(true));

int main(int argc, char **argv) {
+  mlir::registerDialect<mlir::AffineOpsDialect>();
+  mlir::registerDialect<mlir::LLVM::LLVMDialect>();
+  mlir::registerDialect<mlir::loop::LoopOpsDialect>();
+  mlir::registerDialect<mlir::StandardOpsDialect>();
  llvm::InitLLVM y(argc, argv);

  mlir::registerDialect<mlir::ONNXOpsDialect>();
@@ -224,7 +224,10 @@ public:

    // Based on the static entry point type signature, unpack dynamic memory
    // refs to corresponding static memory refs.
-   auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
+   auto wrappedStaticEntryPointFuncName =
+       "_mlir_ciface_" + staticEntryPointFuncName.lower();
+   auto *staticEntryPointFunc =
+       module.lookupSymbol(wrappedStaticEntryPointFuncName);
    assert(staticEntryPointFunc &&
           isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
           "entry point func must exist and be an llvm func op");
@@ -268,7 +271,8 @@ public:
    // Call static entry point with the memref ptrs created, and get output.
    auto outputMemRefs = rewriter.create<LLVM::CallOp>(
        loc, staticEntryPointTy.getFunctionResultType(),
-       rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);
+       rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
+       staticInputs);

    // Create wrapped output.
    auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
@@ -563,7 +567,9 @@ void KrnlToLLVMLoweringPass::runOnModule() {
    OwningRewritePatternList patterns;
    populateAffineToStdConversionPatterns(patterns, &getContext());
    populateLoopToStdConversionPatterns(patterns, &getContext());
-   populateStdToLLVMConversionPatterns(typeConverter, patterns);
+   populateStdToLLVMConversionPatterns(typeConverter, patterns,
+                                       /*useAlloca=*/false,
+                                       /*emitCWrapper=*/true);

    // Lower from the `krnl` dialect i.e. the Reshape operation.
    patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
@@ -98,7 +98,7 @@ test_to_enable = [
    "test_gemm_alpha_cpu",
    "test_gemm_beta_cpu",
    "test_gemm_default_matrix_bias_cpu",
-   # "test_gemm_default_no_bias_cpu", <- error, need support for optional operands
+   "test_gemm_default_no_bias_cpu",
    "test_gemm_default_scalar_bias_cpu",
    "test_gemm_default_single_elem_vector_bias_cpu",
    "test_gemm_default_vector_bias_cpu",
@@ -301,6 +301,10 @@ test_to_enable = [
    "test_matmul_3d_cpu",
    "test_matmul_4d_cpu",

+   # BatchNormalization (test mode)
+   "test_batchnorm_epsilon_cpu",
+   "test_batchnorm_example_cpu",
+
]

# Extract name of all test cases.
@@ -5,10 +5,18 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
+ // CHECK: [[TMP:%.+]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg0, %0[0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg1, %1[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg2, %2[2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg3, %3[3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg5, %4[4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: llvm.insertvalue %arg4, %5[3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: [[TMP1:%.+]] = llvm.insertvalue %arg6, %6[4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
- // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+ // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue [[TMP1]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
  // CHECK: [[VOLATILE:%.+]] = llvm.mlir.constant(0 : i1) : !llvm.i1
@@ -795,12 +795,42 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
  // CHECK: store [[SUM]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
  // CHECK: }
- // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg3, %arg4] : memref<10x10xf32>
  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
+ // CHECK: [[C:%.+]] = load %arg2[%arg4] : memref<10xf32>
  // CHECK: [[BETA_C:%.+]] = mulf [[BETA]], [[C]] : f32
  // CHECK: [[Y_RES:%.+]] = addf [[ALPHA_AB]], [[BETA_C]] : f32
  // CHECK: store [[Y_RES]], [[RES]][%arg3, %arg4] : memref<10x10xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<10x10xf32>
  // CHECK: }
}

func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_gemm_no_bias
  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: }
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<10x10xf32>
  // CHECK: }
}
@@ -1283,3 +1313,63 @@ func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 :

  // CHECK: return [[RES]] : memref<1x5x14x29xf32>
}

func @test_batchnorm_testmode_Nd(%arg0: tensor<1x2x1x3xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>) -> tensor<1x2x1x3xf32> {
  %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<1x2x1x3xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<1x2x1x3xf32>
  return %0 : tensor<1x2x1x3xf32>

  // CHECK-LABEL: test_batchnorm_testmode_Nd
  // CHECK: [[RES:%.+]] = alloc() : memref<1x2x1x3xf32>
  // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
  // CHECK: [[DEF_LOOPS:%.+]]:4 = krnl.define_loops 4
  // CHECK: [[OPT_LOOPS:%.+]]:4 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2, [[DEF_LOOPS]]#3
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg5 = 0 to 2) {
  // CHECK: [[SCALE:%.+]] = load %arg1[%arg5] : memref<2xf32>
  // CHECK: [[BIAS:%.+]] = load %arg2[%arg5] : memref<2xf32>
  // CHECK: [[MEAN:%.+]] = load %arg3[%arg5] : memref<2xf32>
  // CHECK: [[VARIANCE:%.+]] = load %arg4[%arg5] : memref<2xf32>
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#2, [[OPT_LOOPS]]#3) with ([[DEF_LOOPS]]#0 -> %arg6 = 0 to 1, [[DEF_LOOPS]]#2 -> %arg7 = 0 to 1, [[DEF_LOOPS]]#3 -> %arg8 = 0 to 3) {
  // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
  // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
  // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
  // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
  // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
  // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
  // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
  // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg6, %arg5, %arg7, %arg8] : memref<1x2x1x3xf32>
  // CHECK: }
  // CHECK: }
  // CHECK: return [[RES]] : memref<1x2x1x3xf32>
}

func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<10xf32> {
  %0 = "onnx.BatchNormalizationTestMode"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<10xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<10xf32>
  return %0 : tensor<10xf32>

  // CHECK-LABEL: test_batchnorm_testmode_1d
  // CHECK: [[RES:%.+]] = alloc() : memref<10xf32>
  // CHECK: [[EPSILON:%.+]] = constant 9.99999974E-6 : f32
  // CHECK: [[DEF_LOOPS:%.+]] = krnl.define_loops 1
  // CHECK: [[OPT_LOOPS:%.+]] = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS]]
  // CHECK: } : () -> !krnl.loop
  // CHECK: %[[ZERO_INDEX:.+]] = constant 0 : index
  // CHECK: [[SCALE:%.+]] = load %arg1[%[[ZERO_INDEX]]] : memref<1xf32>
  // CHECK: [[BIAS:%.+]] = load %arg2[%[[ZERO_INDEX]]] : memref<1xf32>
  // CHECK: [[MEAN:%.+]] = load %arg3[%[[ZERO_INDEX]]] : memref<1xf32>
  // CHECK: [[VARIANCE:%.+]] = load %arg4[%[[ZERO_INDEX]]] : memref<1xf32>
  // CHECK: krnl.iterate([[OPT_LOOPS]]) with ([[DEF_LOOPS]] -> %arg5 = 0 to 10) {
  // CHECK: [[LOADED_VAL:%.+]] = load %arg0[%arg5] : memref<10xf32>
  // CHECK: [[DIVIDEND:%.+]] = subf [[LOADED_VAL]], [[MEAN]] : f32
  // CHECK: [[ADJUSTED_VARIANCE:%.+]] = addf [[VARIANCE]], [[EPSILON]] : f32
  // CHECK: [[DIVISOR:%.+]] = "krnl.sqrt"([[ADJUSTED_VARIANCE]]) : (f32) -> f32
  // CHECK: [[NORM:%.+]] = divf [[DIVIDEND]], [[DIVISOR]] : f32
  // CHECK: [[SCALE_NORM:%.+]] = mulf [[SCALE]], [[NORM]] : f32
  // CHECK: [[SHIFT_SCALE_NORM:%.+]] = addf [[SCALE_NORM]], [[BIAS]] : f32
  // CHECK: store [[SHIFT_SCALE_NORM]], [[RES]][%arg5] : memref<10xf32>
  // CHECK: }
  // CHECK: return [[RES]] : memref<10xf32>
}