diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index 3113d0d..970588c 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -8,7 +8,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/AffineExpr.h" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; @@ -148,58 +147,30 @@ Value insertAllocAndDeallocForPooling(ConversionPatternRewriter &rewriter, } } - Value zero, one; - if (ceilMode) { - zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - } - one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); - + // Obtain an affine map to compute the output dimension. + AffineMap dimMap = getConvDimMap(rewriter, ceilMode); for (int i = kernelOffset; i < resultShape.size(); ++i) { if (resultShape[i] < 0) { - // dim = - // let numerator = (input + pad - (kernel - 1) * dilation - 1) - // in let denominator = stride - // in - // if (ceilMode) - // ceil(numerator / denominator) + 1 - // else - // floor(numerator / denominator) + 1 int spatialIndex = i - kernelOffset; + // Prepare arguments for the affine map. + SmallVector dimArgs; + dimArgs.emplace_back(rewriter.create(loc, inputOperand, i)); + dimArgs.emplace_back(emitConstantOp( + rewriter, loc, rewriter.getIndexType(), kernelShape[spatialIndex])); + dimArgs.emplace_back( + emitConstantOp(rewriter, loc, rewriter.getIndexType(), + (pads[spatialIndex] + pads[spatialIndex + kernelRank]))); + dimArgs.emplace_back(emitConstantOp( + rewriter, loc, rewriter.getIndexType(), strides[spatialIndex])); + dimArgs.emplace_back( + emitConstantOp(rewriter, loc, rewriter.getIndexType(), + dilations.empty() ? 1 : dilations[spatialIndex])); - // numerator = (input + pad - (kernel - 1) * dilation - 1) - int64_t dilation = dilations.empty() ? 1 : dilations[spatialIndex]; - int64_t padKernelDilation = - (pads[spatialIndex] + pads[spatialIndex + kernelRank]) - - (kernelShape[spatialIndex] - 1) * dilation - 1; - auto padKernelDilationVal = emitConstantOp( - rewriter, loc, rewriter.getIntegerType(64), padKernelDilation); - auto inputDim = rewriter.create(loc, inputOperand, i); - auto inputDimVal = rewriter.create( - loc, inputDim, rewriter.getIntegerType(64)); - auto numeratorVal = - rewriter.create(loc, inputDimVal, padKernelDilationVal); - // denominator - auto denominatorVal = emitConstantOp( - rewriter, loc, rewriter.getIntegerType(64), strides[spatialIndex]); - - // numerator / denominator + // Apply the affine map. Value dimVal = - rewriter.create(loc, numeratorVal, denominatorVal); + rewriter.create(loc, dimMap, ValueRange(dimArgs)); - if (ceilMode) { - auto remainder = - rewriter.create(loc, numeratorVal, denominatorVal); - auto isZero = - rewriter.create(loc, CmpIPredicate::eq, remainder, zero); - auto dimPlusOne = rewriter.create(loc, dimVal, one); - dimVal = rewriter.create(loc, isZero, dimVal, dimPlusOne); - } - - dimVal = rewriter.create(loc, dimVal, one); - allocOperands.emplace_back( - rewriter.create(loc, dimVal, rewriter.getIndexType())); + allocOperands.emplace_back(dimVal); } } alloc = rewriter.create(loc, memRefType, allocOperands); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 95907f8..97c019e 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -24,6 +24,7 @@ #include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" +#include "src/Dialect/ONNX/ONNXOpsHelper.hpp" #include "src/Pass/Passes.hpp" using namespace mlir; diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 4f26c0c..20c671d 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -6,7 +6,9 @@ add_public_tablegen_target(OMONNXOpsIncGen) add_library(OMONNXOps ONNXOps.cpp - ONNXOps.hpp) + ONNXOps.hpp + ONNXOpsHelper.cpp + ONNXOpsHelper.hpp) target_include_directories(OMONNXOps PRIVATE ${ONNX_MLIR_SRC_ROOT} diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 738c45c..16cd67b 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallBitVector.h" #include "ONNXOps.hpp" +#include "ONNXOpsHelper.hpp" using namespace mlir; using namespace mlir::OpTrait::util; @@ -48,6 +49,30 @@ static mlir::ONNXConstantOp getONNXConstantOp(Value value) { return dyn_cast_or_null(value.getDefiningOp()); } +// This method substitutes any uses of dimensions and symbols (e.g. +// dim#0 with dimReplacements[0]) in an affine map, simplifies the modified +// affine map, and returns an integer constant. +int64_t AffineMapIntConstant(Builder &builder, AffineMap map, + ArrayRef dimReplacements, ArrayRef symReplacements, + unsigned numResultDims, unsigned numResultSyms) { + // Prepare affine expressions. + SmallVector dimExprs, symExprs; + for (int64_t dim : dimReplacements) { + AffineExpr exp = builder.getAffineConstantExpr(dim); + dimExprs.emplace_back(exp); + } + for (int64_t sym : symReplacements) { + AffineExpr exp = builder.getAffineConstantExpr(sym); + symExprs.emplace_back(exp); + } + // Replace all the affine map's arguments with real values and evaluate the + // map. + AffineMap replacedDimMap = map.replaceDimsAndSymbols( + dimExprs, symExprs, numResultDims, numResultSyms); + AffineMap simplifiedMap = simplifyAffineMap(replacedDimMap); + return simplifiedMap.getSingleConstantResult(); +} + //===----------------------------------------------------------------------===// // Get reduction type //===----------------------------------------------------------------------===// @@ -267,33 +292,27 @@ static void processConvTypeParams(T *op, Value inputOperand) { // Compute spatial dimensions given dilations, strides, pads, and ceil mode. // static void insertConvSpatialDim(SmallVector *outputDims, - ArrayRef xShape, Optional kernelShape, + Builder &builder, ArrayRef xShape, Optional kernelShape, Optional padsOpt, Optional stridesOpt, Optional dilationsOpt = llvm::None, bool ceilMode = false) { - auto xRank = xShape.size(); auto spatialRank = ArrayAttrSize(kernelShape); - auto spatialOffset = xRank - spatialRank; + auto spatialOffset = xShape.size() - spatialRank; - int64_t dilationVal = 1; + // Get an affine map to compute the output dimension. + AffineMap dimMap = getConvDimMap(builder, ceilMode); for (int i = 0; i < spatialRank; ++i) { - auto inputSize = xShape[spatialOffset + i]; - auto sumOfPads = - ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i); - auto kernelSize = ArrayAttrIntVal(kernelShape, i); - if (dilationsOpt.hasValue()) - dilationVal = ArrayAttrIntVal(dilationsOpt, i); - auto strideVal = ArrayAttrIntVal(stridesOpt, i); - // Number of useful values: input plus pad - effective size of kernel (see - // processConvTypeParams comments to see how this value is derived). - double numerator = - inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1); - // Useful number is divided by the strides. - double denominator = strideVal; - int64_t res; - if (ceilMode) { - res = ceil(numerator / denominator) + 1; - } else { - res = floor(numerator / denominator) + 1; + int64_t res = -1; + if (xShape[spatialOffset + i] != -1) { + auto inputSize = xShape[spatialOffset + i]; + auto kernelSize = ArrayAttrIntVal(kernelShape, i); + auto sumOfPads = ArrayAttrIntVal(padsOpt, i) + + ArrayAttrIntVal(padsOpt, spatialRank + i); + auto strideVal = ArrayAttrIntVal(stridesOpt, i); + int64_t dilationVal = 1; + if (dilationsOpt.hasValue()) + dilationVal = ArrayAttrIntVal(dilationsOpt, i); + res = AffineMapIntConstant(builder, dimMap, {inputSize}, + {kernelSize, sumOfPads, strideVal, dilationVal}, 1, 4); } outputDims->emplace_back(res); } @@ -1343,8 +1362,8 @@ bool ONNXConvOp::inferShapes() { // Insert number of filters being applied (number of output channels). outputDims.emplace_back(weightShape[0]); // Compute and insert spatial dims. - insertConvSpatialDim( - &outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt); + insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt, + stridesOpt, dilationsOpt); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); return true; @@ -1365,6 +1384,8 @@ bool ONNXAveragePoolOp::inferShapes() { return false; } + auto builder = mlir::Builder(getContext()); + // Get shape of input. auto xTy = X().getType().cast(); auto xShape = xTy.getShape(); @@ -1390,8 +1411,8 @@ bool ONNXAveragePoolOp::inferShapes() { outputDims.emplace_back(xShape[0]); outputDims.emplace_back(xShape[1]); // Compute and insert spatial dims. - insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt, - llvm::None, ceilMode); + insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt, + stridesOpt, llvm::None, ceilMode); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); return true; @@ -1412,6 +1433,8 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() { return false; } + auto builder = mlir::Builder(getContext()); + // Get shape of input. auto xTy = X().getType().cast(); auto xShape = xTy.getShape(); @@ -1441,8 +1464,8 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() { outputDims.emplace_back(xShape[0]); outputDims.emplace_back(xShape[1]); // Compute and insert spatial dims. - insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt, - dilationsOpt, ceilMode); + insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt, + stridesOpt, dilationsOpt, ceilMode); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); return true; diff --git a/src/Dialect/ONNX/ONNXOpsHelper.cpp b/src/Dialect/ONNX/ONNXOpsHelper.cpp new file mode 100644 index 0000000..347336a --- /dev/null +++ b/src/Dialect/ONNX/ONNXOpsHelper.cpp @@ -0,0 +1,42 @@ +//===------- ONNXOpsHelper.cpp - Helper functions for ONNX dialects -------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains helper functions for lowering ONNX ops to Krnl Dialect. +// +//===----------------------------------------------------------------------===// + +#include "ONNXOpsHelper.hpp" + +// Identity affine +using namespace mlir; +AffineMap getIdentityDimMap(Builder &builder) { + return AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}); +} + +// Pool/conv affine +// dim = +// let numerator = (input + pad - (kernel - 1) * dilation - 1) +// in let denominator = stride +// in +// if (ceilMode) +// ceil(numerator / denominator) + 1 +// else +// floor(numerator / denominator) + 1 +AffineMap getConvDimMap(Builder &builder, bool ceilMode) { + AffineExpr input = builder.getAffineDimExpr(0); + AffineExpr kernel = builder.getAffineSymbolExpr(0); + AffineExpr pad = builder.getAffineSymbolExpr(1); + AffineExpr stride = builder.getAffineSymbolExpr(2); + AffineExpr dilation = builder.getAffineSymbolExpr(3); + + AffineExpr dimExp; + if (ceilMode) + dimExp = (input + pad - (kernel - 1) * dilation - 1).ceilDiv(stride) + 1; + else + dimExp = (input + pad - (kernel - 1) * dilation - 1).floorDiv(stride) + 1; + + return AffineMap::get(1, 4, {dimExp}); +} diff --git a/src/Dialect/ONNX/ONNXOpsHelper.hpp b/src/Dialect/ONNX/ONNXOpsHelper.hpp new file mode 100644 index 0000000..15d4f0e --- /dev/null +++ b/src/Dialect/ONNX/ONNXOpsHelper.hpp @@ -0,0 +1,33 @@ +//===------- ONNXOpsHelper.hpp - Helper functions for ONNX dialects -------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains helper functions for lowering ONNX ops to Krnl Dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" + +using namespace mlir; + +// Identity affine map: +// #map = affine_map<(d0)[] -> d0> +AffineMap getIdentityDimMap(Builder &builder); + +// Pool/conv affine map: +// #map0 = affine_map<(d0)[s0, s1, s2, s3] +// -> (d0 + s1 - (s0 - 1) * s3 - 1) floordiv s2 + 1> +// In the case of `ceilMode = true`: +// #map0 = affine_map<(d0)[s0, s1, s2, s3] +// -> (d0 + s1 - (s0 - 1) * s3 - 1) ceildiv s2 + 1> +// where: +// - d0: input dim +// - s0: kernel +// - s1: pad +// - s2: stride +// - s3: dilation +AffineMap getConvDimMap(Builder &builder, bool ceilMode); diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 21993a5..57d71f3 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1727,6 +1727,23 @@ func @test_pool_general_computation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf // ----- +func @test_pool_unknown_dimensions(%arg0 : tensor<1x3x?x32xf32>) -> tensor<*xf32> { + %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x?x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-DAG: #[[AFFINE_MAP:.+]] = affine_map<(d0)[s0, s1, s2, s3] -> ((d0 + s1 - (s0 - 1) * s3 - 1) floordiv s2 + 1)> + // CHECK-LABEL: test_pool_unknown_dimensions + // CHECK: [[DIM:%.+]] = dim %arg0, 2 : memref<1x3x?x32xf32> + // CHECK: [[KERNEL:%.+]] = constant 2 : index + // CHECK: [[PAD:%.+]] = constant 0 : index + // CHECK: [[STRIDE:%.+]] = constant 1 : index + // CHECK: [[DILATION:%.+]] = constant 1 : index + // CHECK: [[AFFINE_APPLY:%.+]] = affine.apply #[[AFFINE_MAP]]([[DIM]]){{.*}}[[KERNEL]], [[PAD]], [[STRIDE]], [[DILATION]]{{.*}} + // CHECK: [[RES:%.+]] = alloc([[AFFINE_APPLY]]) : memref<1x3x?x31xf32> +} + +// ----- + func @test_averagepool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> ()