Lower MaxPoolSingleOutOp to Krnl dialect (#1)
* Lower MaxPoolSingleOutOp to Krnl dialect
* Edit comments
* Update changes according to the new folder structure
* Add MLIR tests
* Support ceil_mode
* Merge the first two krnl loops into one krnl loop; remove attribute checks
* Dynamically allocate memory for the result if the result has unknown dimensions

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent e97df0b343
commit e4c23da4fd
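As context for the lowering below, here is a minimal standalone C++ sketch (not part of the diff) of the ONNX pooling output-size rule that the pattern materializes when a result dimension is unknown; the function name and signature are illustrative only.

#include <cstdint>

// Hypothetical helper, for illustration only: output size of one spatial
// dimension. The division is a floor for non-negative numerators; ceil_mode
// rounds it up instead, and the final +1 accounts for the initial window.
int64_t pooledDim(int64_t input, int64_t padBegin, int64_t padEnd,
                  int64_t kernel, int64_t dilation, int64_t stride,
                  bool ceilMode) {
  int64_t numerator = input + padBegin + padEnd - (kernel - 1) * dilation - 1;
  int64_t dim = numerator / stride;  // floor division
  if (ceilMode && numerator % stride != 0)
    dim += 1;                        // round the division up
  return dim + 1;
}
// Example: pooledDim(32, 0, 0, /*kernel=*/3, /*dilation=*/1, /*stride=*/2,
//                    /*ceilMode=*/true) == 16.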
@@ -85,6 +85,7 @@ add_library(onnf_lower_frontend
        conversion/onnx_to_krnl/math/softmax.cpp
        conversion/onnx_to_krnl/nn/conv.cpp
        conversion/onnx_to_krnl/nn/normalization.cpp
        conversion/onnx_to_krnl/nn/pooling.cpp
        conversion/onnx_to_krnl/tensor/identity.cpp
        conversion/onnx_to_krnl/tensor/reshape.cpp
        conversion/onnx_to_krnl/tensor/transpose.cpp
@@ -99,6 +99,7 @@ void FrontendToKrnlLoweringPass::runOnModule() {
  // Neural network
  populateLoweringONNXConvOpPattern(patterns, &getContext());
  populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
  populateLoweringONNXPoolingOpPattern(patterns, &getContext());
  // Entry point
  patterns.insert<ONNXEntryPointLowering>(&getContext());
@@ -0,0 +1,294 @@
//===----- pooling.cpp - Lowering Pooling Ops -----------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Pooling Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"

using namespace mlir;

// Identity values
template <>
float getIdentityValue<float, ONNXMaxPoolSingleOutOp>() {
  return (float)-std::numeric_limits<float>::infinity();
}

template <>
int getIdentityValue<int, ONNXMaxPoolSingleOutOp>() {
  return std::numeric_limits<int>::min();
}

template <>
Value mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(Operation *op,
    ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  auto loc = op->getLoc();
  Value lhs = operands[0];
  Value rhs = operands[1];
  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
  return result;
}

struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
  ONNXMaxPoolSingleOutOpLowering(MLIRContext *ctx)
      : ConversionPattern(
            mlir::ONNXMaxPoolSingleOutOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    // Match
    ONNXMaxPoolSingleOutOp poolOp = llvm::dyn_cast<ONNXMaxPoolSingleOutOp>(op);

    // Read kernel_shape attribute
    SmallVector<int, 4> kernelShape;
    auto kernelShapeAttribute = poolOp.kernel_shapeAttr();
    for (auto dim : kernelShapeAttribute.getValue())
      kernelShape.emplace_back(dim.cast<IntegerAttr>().getInt());

    // Read strides attribute
    SmallVector<int, 4> strides;
    auto stridesAttribute = poolOp.stridesAttr();
    for (auto stride : stridesAttribute.getValue())
      strides.emplace_back(stride.cast<IntegerAttr>().getInt());

    // Read ceil_mode attribute
    auto ceilMode = poolOp.ceil_mode().getSExtValue();

    // Read pads attribute
    SmallVector<int, 4> pads;
    auto padsAttribute = poolOp.padsAttr();
    for (auto pad : padsAttribute.getValue())
      pads.emplace_back(pad.cast<IntegerAttr>().getInt());

    // Read dilations attribute
    SmallVector<int, 4> dilations;
    auto dilationsAttribute = poolOp.dilationsAttr();
    for (auto dilation : dilationsAttribute.getValue())
      dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());

    // Type information about the input and result of this operation.
    auto &inputOperand = operands[0];
    auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    auto resultShape = memRefType.getShape();
    auto resultElementType = memRefType.getElementType();

    // Batch indices: N and C dimensions
    int batchRank = 2;

    // Insert an allocation and deallocation for the result of this operation.
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else {
      // Compute dimensions of the result of this operation.
      SmallVector<Value, 2> allocOperands;
      for (int i = 0; i < batchRank; ++i) {
        if (resultShape[i] < 0) {
          auto dim = rewriter.create<DimOp>(loc, inputOperand, i);
          allocOperands.emplace_back(dim);
        }
      }

      Value zero, one;
      if (ceilMode) {
        zero = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      }
      one = rewriter.create<ConstantOp>(
          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));

      int spatialRank = resultShape.size() - batchRank;
      for (int i = batchRank; i < resultShape.size(); ++i) {
        if (resultShape[i] < 0) {
          // dim =
          //   let numerator = input + pad - (kernel - 1) * dilation - 1
          //   in let denominator = stride
          //   in
          //     if (ceilMode)
          //       ceil(numerator / denominator) + 1
          //     else
          //       floor(numerator / denominator) + 1
          int spatialIndex = i - batchRank;

          // numerator = input + pad - (kernel - 1) * dilation - 1
          auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
          auto inputVal = rewriter.create<IndexCastOp>(
              loc, inputDim, rewriter.getIntegerType(64));
          int64_t padKernelDilation =
              (pads[spatialIndex] + pads[spatialIndex + spatialRank]) -
              (kernelShape[spatialIndex] - 1) * dilations[spatialIndex] - 1;
          auto padKernelDilationVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), padKernelDilation));
          auto numeratorVal =
              rewriter.create<AddIOp>(loc, inputVal, padKernelDilationVal);
          // denominator
          auto denominatorVal = rewriter.create<ConstantOp>(
              loc, rewriter.getIntegerAttr(
                       rewriter.getIntegerType(64), strides[spatialIndex]));

          // numerator / denominator
          Value dimVal =
              rewriter.create<SignedDivIOp>(loc, numeratorVal, denominatorVal);

          if (ceilMode) {
            auto remainder = rewriter.create<SignedRemIOp>(
                loc, numeratorVal, denominatorVal);
            auto isZero = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::eq, remainder, zero);
            auto dimPlusOne = rewriter.create<AddIOp>(loc, dimVal, one);
            dimVal = rewriter.create<SelectOp>(loc, isZero, dimVal, dimPlusOne);
          }

          dimVal = rewriter.create<AddIOp>(loc, dimVal, one);
          allocOperands.emplace_back(rewriter.create<IndexCastOp>(
              loc, dimVal, rewriter.getIndexType()));
        }
      }
      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
      if (insertDealloc) {
        auto *parentBlock = alloc.getDefiningOp()->getBlock();
        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }
    }

    // R = MaxPool(D)
    //
    // The input/output shapes will look like this:
    //
    // D (NxCxHxW) -> R (NxCxRHxRW)
    //
    // The loop nest will look as follows:
    //
    // strides = [s1, s2]
    //
    // for n = 0 .. N:
    //   for c = 0 .. C:
    //     for r1 = 0 .. RH:
    //       for r2 = 0 .. RW:
    //         R[n][c][r1][r2] = negative_infinity;
    //         for k1 = 0 .. KH:
    //           for k2 = 0 .. KW:
    //             t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
    //             R[n][c][r1][r2] = max(R[n][c][r1][r2], t);
    //
    // Naming:
    //   n, c, r1, r2: outer loop nest indices
    //   k1, k2: inner loop nest indices
    //
    // TODO: handle padding.
    //

    // 1. Define outer loops and emit empty optimization block.
    auto nOuterLoops = resultShape.size();
    BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
    outerLoops.createDefineOptimizeAndIterateOp(alloc);

    rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
    {
      // 2. Emit the body of the outer loop nest.
      SmallVector<Value, 4> resultIndices;
      for (int i = 0; i < nOuterLoops; ++i)
        resultIndices.emplace_back(outerLoops.getInductionVar(i));

      // 2.1 Emit: R[n][c][r1][r2] = negative_infinity;
      Value identity;
      if (resultElementType.isa<FloatType>()) {
        identity = rewriter.create<ConstantOp>(
            loc, FloatAttr::get(resultElementType,
                     getIdentityValue<float, ONNXMaxPoolSingleOutOp>()));
      } else if (resultElementType.isa<IntegerType>()) {
        identity = rewriter.create<ConstantOp>(
            loc, IntegerAttr::get(resultElementType,
                     getIdentityValue<int, ONNXMaxPoolSingleOutOp>()));
      } else {
        emitError(loc, "unsupported element type");
      }
      rewriter.create<StoreOp>(loc, identity, alloc, resultIndices);

      // 2.2 Define inner loops.
      int nInnerLoops = kernelShape.size();
      BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
      innerLoops.createDefineAndOptimizeOp();
      // for Kx = 0 .. KX
      for (int i = 0; i < nInnerLoops; ++i)
        innerLoops.pushBounds(0, kernelShape[i]);

      // 2.3 Emit inner loop nest.
      innerLoops.createIterateOp();
      rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
      {
        // 3. Emit inner loop body:
        //      t = D[n][c][s1 * r1 + k1][s2 * r2 + k2];
        //      R[n][c][r1][r2] = max(R[n][c][r1][r2], t);

        // 3.1 Prepare indices for accessing the data tensor.
        SmallVector<Value, 4> dataIndices;
        // Batch indices: n, c
        for (int i = 0; i < batchRank; ++i)
          dataIndices.emplace_back(outerLoops.getInductionVar(i));
        // Spatial indices: sX * rX + kX
        for (int i = batchRank; i < nOuterLoops; ++i) {
          Value spatialIndex = outerLoops.getInductionVar(i);
          // If strides are present then emit the correct access index.
          if (stridesAttribute && strides[i - batchRank] > 1) {
            spatialIndex = rewriter.create<MulIOp>(loc,
                rewriter.create<ConstantIndexOp>(loc, strides[i - batchRank]),
                outerLoops.getInductionVar(i));
          }
          spatialIndex = rewriter.create<AddIOp>(
              loc, spatialIndex, innerLoops.getInductionVar(i - batchRank));
          // If ceil mode is enabled, then the calculated access index may
          // exceed its dimension. In such a case, we will use the maximum
          // index, which causes multiple visits to the element of the
          // maximum index.
          // TODO: Avoid multiple visits.
          if (ceilMode) {
            Value inputIndex;
            if (inputShape[i] < 0) {
              Value inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
              Value one = rewriter.create<ConstantIndexOp>(loc, 1);
              inputIndex = rewriter.create<SubIOp>(loc, inputDim, one);
            } else {
              inputIndex =
                  rewriter.create<ConstantIndexOp>(loc, inputShape[i] - 1);
            }
            auto greaterCondition = rewriter.create<CmpIOp>(
                loc, CmpIPredicate::sgt, spatialIndex, inputIndex);
            spatialIndex = rewriter.create<SelectOp>(
                loc, greaterCondition, inputIndex, spatialIndex);
          }
          dataIndices.emplace_back(spatialIndex);
        }

        // 3.2 Do pooling.
        auto loadData = rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
        auto loadPartialResult =
            rewriter.create<LoadOp>(loc, alloc, resultIndices);
        Value result = mapToLowerScalarOp<ONNXMaxPoolSingleOutOp>(
            op, resultElementType, {loadPartialResult, loadData}, rewriter);
        rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
      }
    }
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};

void populateLoweringONNXPoolingOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXMaxPoolSingleOutOpLowering>(ctx);
}
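For reference, the loop nest that this pattern emits corresponds to the following plain C++ max pooling, a sketch under the assumptions of the comment above (NCHW layout, no padding, no dilation); the function and variable names are illustrative, not part of the commit.

#include <algorithm>
#include <limits>
#include <vector>

// Sketch of the scalar computation: R (NxCxRHxRW) from D (NxCxHxW) with a
// KH x KW kernel and strides s1, s2, matching the pseudocode in pooling.cpp.
void maxPoolRef(const std::vector<float> &D, std::vector<float> &R, int N,
                int C, int H, int W, int RH, int RW, int KH, int KW, int s1,
                int s2) {
  for (int n = 0; n < N; ++n)
    for (int c = 0; c < C; ++c)
      for (int r1 = 0; r1 < RH; ++r1)
        for (int r2 = 0; r2 < RW; ++r2) {
          // Initialize with the identity of max, then fold the window in.
          float acc = -std::numeric_limits<float>::infinity();
          for (int k1 = 0; k1 < KH; ++k1)
            for (int k2 = 0; k2 < KW; ++k2)
              acc = std::max(
                  acc, D[((n * C + c) * H + s1 * r1 + k1) * W + s2 * r2 + k2]);
          R[((n * C + c) * RH + r1) * RW + r2] = acc;
        }
}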
@@ -202,6 +202,9 @@ void populateLoweringONNXConvOpPattern(
void populateLoweringONNXNormalizationOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

void populateLoweringONNXPoolingOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx);

// `tensor` directory methods:

void populateLoweringONNXUnsqueezeOpPattern(
@@ -305,6 +305,13 @@ test_to_enable = [
    "test_batchnorm_epsilon_cpu",
    "test_batchnorm_example_cpu",

    # Pooling
    "test_maxpool_1d_default_cpu",
    "test_maxpool_2d_ceil_cpu",
    "test_maxpool_2d_default_cpu",
    "test_maxpool_2d_strides_cpu",
    "test_maxpool_3d_default_cpu",

]

# Extract name of all test cases.
@@ -1344,3 +1344,169 @@ func @test_batchnorm_testmode_1d(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>, %a
  // CHECK: return [[RES]] : memref<10xf32>
}

func @test_maxpooling_singleout_no_pad(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_maxpooling_singleout_no_pad
  // CHECK: [[RES:%.+]] = alloc() : memref<1x3x31x31xf32>
  // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
  // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 31, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 31) {
  // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
  // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32>
  // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) {
  // CHECK: [[H:%.+]] = addi %arg3, %arg5 : index
  // CHECK: [[W:%.+]] = addi %arg4, %arg6 : index
  // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32>
  // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x31x31xf32>
  // CHECK: }
  // CHECK: }
  // CHECK: return [[RES]] : memref<1x3x31x31xf32>
}
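// Note on the shape checked above: with kernel 2 and the default stride 1,
// each spatial dimension is floor((32 - 2) / 1) + 1 = 31, hence 1x3x31x31.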

func @test_maxpooling_singleout_no_pad_w_strides(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2], strides = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides
  // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32>
  // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
  // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) {
  // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
  // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 2, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 2) {
  // CHECK: [[STRIDE_0:%.+]] = constant 2 : index
  // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
  // CHECK: [[H:%.+]] = addi [[MUL_0]], %arg5 : index
  // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
  // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
  // CHECK: [[W:%.+]] = addi [[MUL_1]], %arg6 : index
  // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: [[COMPARE:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: [[SELECT:%.+]] = select [[COMPARE]], [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: }
  // CHECK: }
  // CHECK: return [[RES]] : memref<1x3x16x16xf32>
}

func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
  %0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode
  // CHECK: [[RES:%.+]] = alloc() : memref<1x3x16x16xf32>
  // CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
  // CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to 1, [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to 16, [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) {
  // CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
  // CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
  // CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
  // CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
  // CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) {
  // CHECK: [[STRIDE_0:%.+]] = constant 2 : index
  // CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
  // CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
  // CHECK: [[INPUT_INDEX_0:%.+]] = constant 31 : index
  // CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
  // CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
  // CHECK: [[STRIDE_1:%.+]] = constant 2 : index
  // CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
  // CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
  // CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
  // CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
  // CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
  // CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<1x3x32x32xf32>
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32
  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<1x3x16x16xf32>
  // CHECK: }
  // CHECK: }
  // CHECK: return [[RES]] : memref<1x3x16x16xf32>
}
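// Note on the shape checked above: with kernel 3, stride 2, and ceil_mode = 1,
// each spatial dimension is ceil((32 - 3) / 2) + 1 = 16, and the select against
// index 31 clamps the last (partial) window inside the input.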
|
||||
|
||||
func @test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims(%arg0 : tensor<?x3x?x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.MaxPoolSingleOut"(%arg0) {auto_pad = "NOTSET", kernel_shape = [3, 3], strides = [2, 2], ceil_mode = 1} : (tensor<?x3x?x32xf32>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_maxpooling_singleout_no_pad_w_strides_w_ceil_mode_w_unknown_dims
|
||||
|
||||
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x3x?x32xf32>
|
||||
// CHECK: [[ZERO:%.+]] = constant 0 : i64
|
||||
// CHECK: [[ONE:%.+]] = constant 1 : i64
|
||||
// CHECK: [[DIM_1:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
|
||||
// CHECK: [[DIM_1_i64:%.+]] = index_cast [[DIM_1]] : index to i64
|
||||
// CHECK: [[KERNEL_PAD_DILATION:%.+]] = constant -1 : i64
|
||||
// CHECK: [[NUMERATOR:%.+]] = addi [[DIM_1_i64]], [[KERNEL_PAD_DILATION]] : i64
|
||||
// CHECK: [[DENOMINATOR:%.+]] = constant 2 : i64
|
||||
// CHECK: [[DIV:%.+]] = divi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64
|
||||
// CHECK: [[REMAINDER:%.+]] = remi_signed [[NUMERATOR]], [[DENOMINATOR]] : i64
|
||||
// CHECK: [[IS_ZERO:%.+]] = cmpi "eq", [[REMAINDER]], [[ZERO]] : i64
|
||||
// CHECK: [[DIV_PLUS_ONE:%.+]] = addi [[DIV]], [[ONE]] : i64
|
||||
// CHECK: [[SELECT:%.+]] = select [[IS_ZERO]], [[DIV]], [[DIV_PLUS_ONE]] : i64
|
||||
// CHECK: [[SELECT_PLUS_ONE:%.+]] = addi [[SELECT]], [[ONE]] : i64
|
||||
// CHECK: [[DIM_1_FINAL:%.+]] = index_cast [[SELECT_PLUS_ONE]] : i64 to index
|
||||
// CHECK: [[RES:%.+]] = alloc([[DIM_0]], [[DIM_1_FINAL]]) : memref<?x3x?x16xf32>
|
||||
|
||||
// CHECK: [[DEF_LOOPS_0:%.+]]:4 = krnl.define_loops 4
|
||||
// CHECK: [[OPT_LOOPS_0:%.+]]:4 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS_0]]#0, [[DEF_LOOPS_0]]#1, [[DEF_LOOPS_0]]#2, [[DEF_LOOPS_0]]#3
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x3x?x16xf32>
|
||||
// CHECK: [[DIM_3:%.+]] = dim [[RES]], 2 : memref<?x3x?x16xf32>
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_0]]#0, [[OPT_LOOPS_0]]#1, [[OPT_LOOPS_0]]#2, [[OPT_LOOPS_0]]#3) with ([[DEF_LOOPS_0]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS_0]]#1 -> %arg2 = 0 to 3, [[DEF_LOOPS_0]]#2 -> %arg3 = 0 to [[DIM_3]], [[DEF_LOOPS_0]]#3 -> %arg4 = 0 to 16) {
|
||||
// CHECK: [[NEGATIVE_INFINITY:%.+]] = constant 0xFF800000 : f32
|
||||
// CHECK: store [[NEGATIVE_INFINITY]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>
|
||||
// CHECK: [[DEF_LOOPS_1:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[OPT_LOOPS_1:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DEF_LOOPS_1]]#0, [[DEF_LOOPS_1]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[OPT_LOOPS_1]]#0, [[OPT_LOOPS_1]]#1) with ([[DEF_LOOPS_1]]#0 -> %arg5 = 0 to 3, [[DEF_LOOPS_1]]#1 -> %arg6 = 0 to 3) {
|
||||
// CHECK: [[STRIDE_0:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_0:%.+]] = muli [[STRIDE_0]], %arg3 : index
|
||||
// CHECK: [[SPATIAL_H:%.+]] = addi [[MUL_0]], %arg5 : index
|
||||
// CHECK: [[DIM_0_0:%.+]] = dim %arg0, 2 : memref<?x3x?x32xf32>
|
||||
// CHECK: [[ONE_INDEX:%.+]] = constant 1 : index
|
||||
// CHECK: [[INPUT_INDEX_0:%.+]] = subi [[DIM_0_0]], [[ONE_INDEX]] : index
|
||||
// CHECK: [[CMP_0:%.+]] = cmpi "sgt", [[SPATIAL_H]], [[INPUT_INDEX_0]] : index
|
||||
// CHECK: [[H:%.+]] = select [[CMP_0]], [[INPUT_INDEX_0]], [[SPATIAL_H]] : index
|
||||
|
||||
// CHECK: [[STRIDE_1:%.+]] = constant 2 : index
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[STRIDE_1]], %arg4 : index
|
||||
// CHECK: [[SPATIAL_W:%.+]] = addi [[MUL_1]], %arg6 : index
|
||||
// CHECK: [[INPUT_INDEX_1:%.+]] = constant 31 : index
|
||||
// CHECK: [[CMP_1:%.+]] = cmpi "sgt", [[SPATIAL_W]], [[INPUT_INDEX_1]] : index
|
||||
// CHECK: [[W:%.+]] = select [[CMP_1]], [[INPUT_INDEX_1]], [[SPATIAL_W]] : index
|
||||
|
||||
// CHECK: [[LOAD_X:%.+]] = load %arg0[%arg1, %arg2, [[H]], [[W]]] : memref<?x3x?x32xf32>
|
||||
// CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>
|
||||
// CHECK: [[CMP_2:%.+]] = cmpf "ogt", [[LOAD_Y]], [[LOAD_X]] : f32
|
||||
// CHECK: [[SELECT:%.+]] = select [[CMP_2]], [[LOAD_Y]], [[LOAD_X]] : f32
|
||||
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2, %arg3, %arg4] : memref<?x3x?x16xf32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: return [[RES]] : memref<?x3x?x16xf32>
|
||||
}
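// Note: the i64 arithmetic checked above evaluates the same output-size rule
// at run time; e.g. for an actual height of 32: 32 - 3 = 29, divi_signed gives
// 14, the nonzero remainder selects 14 + 1 = 15, and the final addi yields 16,
// matching the static width dimension.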
|
||||
|
|
Loading…
Reference in New Issue