Unify code in shape inference and conversion (#98)

* Use AffineMap

* Shared AffineMap

* AffineMap for Conv/Pooling

* Create helper files

* Remove changes for Relu

* Remove redundant includes

* Use AffineMap for AveragePool's shape inference

* Add MLIR tests for unknown dimension case

* Extract a method AffineMapIntConstant

* Comment style and include path

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-05-14 18:31:33 +09:00 committed by GitHub
parent d65a6e72dd
commit 4d8b855c17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 165 additions and 76 deletions

View File

@ -8,7 +8,6 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/IR/AffineExpr.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir; using namespace mlir;
@ -148,58 +147,30 @@ Value insertAllocAndDeallocForPooling(ConversionPatternRewriter &rewriter,
} }
} }
Value zero, one; // Obtain an affine map to compute the output dimension.
if (ceilMode) { AffineMap dimMap = getConvDimMap(rewriter, ceilMode);
zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
}
one = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
for (int i = kernelOffset; i < resultShape.size(); ++i) { for (int i = kernelOffset; i < resultShape.size(); ++i) {
if (resultShape[i] < 0) { if (resultShape[i] < 0) {
// dim =
// let numerator = (input + pad - (kernel - 1) * dilation - 1)
// in let denominator = stride
// in
// if (ceilMode)
// ceil(numerator / denominator) + 1
// else
// floor(numerator / denominator) + 1
int spatialIndex = i - kernelOffset; int spatialIndex = i - kernelOffset;
// Prepare arguments for the affine map.
SmallVector<Value, 4> dimArgs;
dimArgs.emplace_back(rewriter.create<DimOp>(loc, inputOperand, i));
dimArgs.emplace_back(emitConstantOp(
rewriter, loc, rewriter.getIndexType(), kernelShape[spatialIndex]));
dimArgs.emplace_back(
emitConstantOp(rewriter, loc, rewriter.getIndexType(),
(pads[spatialIndex] + pads[spatialIndex + kernelRank])));
dimArgs.emplace_back(emitConstantOp(
rewriter, loc, rewriter.getIndexType(), strides[spatialIndex]));
dimArgs.emplace_back(
emitConstantOp(rewriter, loc, rewriter.getIndexType(),
dilations.empty() ? 1 : dilations[spatialIndex]));
// numerator = (input + pad - (kernel - 1) * dilation - 1) // Apply the affine map.
int64_t dilation = dilations.empty() ? 1 : dilations[spatialIndex];
int64_t padKernelDilation =
(pads[spatialIndex] + pads[spatialIndex + kernelRank]) -
(kernelShape[spatialIndex] - 1) * dilation - 1;
auto padKernelDilationVal = emitConstantOp(
rewriter, loc, rewriter.getIntegerType(64), padKernelDilation);
auto inputDim = rewriter.create<DimOp>(loc, inputOperand, i);
auto inputDimVal = rewriter.create<IndexCastOp>(
loc, inputDim, rewriter.getIntegerType(64));
auto numeratorVal =
rewriter.create<AddIOp>(loc, inputDimVal, padKernelDilationVal);
// denominator
auto denominatorVal = emitConstantOp(
rewriter, loc, rewriter.getIntegerType(64), strides[spatialIndex]);
// numerator / denominator
Value dimVal = Value dimVal =
rewriter.create<SignedDivIOp>(loc, numeratorVal, denominatorVal); rewriter.create<AffineApplyOp>(loc, dimMap, ValueRange(dimArgs));
if (ceilMode) { allocOperands.emplace_back(dimVal);
auto remainder =
rewriter.create<SignedRemIOp>(loc, numeratorVal, denominatorVal);
auto isZero =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, remainder, zero);
auto dimPlusOne = rewriter.create<AddIOp>(loc, dimVal, one);
dimVal = rewriter.create<SelectOp>(loc, isZero, dimVal, dimPlusOne);
}
dimVal = rewriter.create<AddIOp>(loc, dimVal, one);
allocOperands.emplace_back(
rewriter.create<IndexCastOp>(loc, dimVal, rewriter.getIndexType()));
} }
} }
alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands); alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);

View File

@ -24,6 +24,7 @@
#include "src/Dialect/Krnl/KrnlHelper.hpp" #include "src/Dialect/Krnl/KrnlHelper.hpp"
#include "src/Dialect/Krnl/KrnlOps.hpp" #include "src/Dialect/Krnl/KrnlOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Pass/Passes.hpp" #include "src/Pass/Passes.hpp"
using namespace mlir; using namespace mlir;

View File

@ -6,7 +6,9 @@ add_public_tablegen_target(OMONNXOpsIncGen)
add_library(OMONNXOps add_library(OMONNXOps
ONNXOps.cpp ONNXOps.cpp
ONNXOps.hpp) ONNXOps.hpp
ONNXOpsHelper.cpp
ONNXOpsHelper.hpp)
target_include_directories(OMONNXOps target_include_directories(OMONNXOps
PRIVATE PRIVATE
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}

View File

@ -21,6 +21,7 @@
#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallBitVector.h"
#include "ONNXOps.hpp" #include "ONNXOps.hpp"
#include "ONNXOpsHelper.hpp"
using namespace mlir; using namespace mlir;
using namespace mlir::OpTrait::util; using namespace mlir::OpTrait::util;
@ -48,6 +49,30 @@ static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp()); return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
} }
// Evaluates an affine map on compile-time integer arguments.
//
// Every dimension #i of `map` is substituted with dimReplacements[i] and every
// symbol #j with symReplacements[j], each wrapped in a constant affine
// expression. The substituted map is then simplified and its single constant
// result is returned.
//
// `numResultDims` and `numResultSyms` are forwarded unchanged to
// AffineMap::replaceDimsAndSymbols.
//
// NOTE(review): getSingleConstantResult presumably requires the simplified
// map to consist of exactly one constant expression, i.e. every dimension and
// symbol used by `map` must have a replacement - confirm against the MLIR API.
int64_t AffineMapIntConstant(Builder &builder, AffineMap map,
    ArrayRef<int64_t> dimReplacements, ArrayRef<int64_t> symReplacements,
    unsigned numResultDims, unsigned numResultSyms) {
  // Wrap the dimension values in constant affine expressions.
  SmallVector<AffineExpr, 4> constantDims;
  for (int64_t value : dimReplacements)
    constantDims.emplace_back(builder.getAffineConstantExpr(value));

  // Wrap the symbol values in constant affine expressions.
  SmallVector<AffineExpr, 4> constantSyms;
  for (int64_t value : symReplacements)
    constantSyms.emplace_back(builder.getAffineConstantExpr(value));

  // Substitute the constants into the map, then fold the result down to a
  // single integer.
  AffineMap substitutedMap = map.replaceDimsAndSymbols(
      constantDims, constantSyms, numResultDims, numResultSyms);
  return simplifyAffineMap(substitutedMap).getSingleConstantResult();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Get reduction type // Get reduction type
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -267,33 +292,27 @@ static void processConvTypeParams(T *op, Value inputOperand) {
// Compute spatial dimensions given dilations, strides, pads, and ceil mode. // Compute spatial dimensions given dilations, strides, pads, and ceil mode.
// //
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims, static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape, Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt, Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) { Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
auto xRank = xShape.size();
auto spatialRank = ArrayAttrSize(kernelShape); auto spatialRank = ArrayAttrSize(kernelShape);
auto spatialOffset = xRank - spatialRank; auto spatialOffset = xShape.size() - spatialRank;
int64_t dilationVal = 1; // Get an affine map to compute the output dimension.
AffineMap dimMap = getConvDimMap(builder, ceilMode);
for (int i = 0; i < spatialRank; ++i) { for (int i = 0; i < spatialRank; ++i) {
int64_t res = -1;
if (xShape[spatialOffset + i] != -1) {
auto inputSize = xShape[spatialOffset + i]; auto inputSize = xShape[spatialOffset + i];
auto sumOfPads =
ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
auto kernelSize = ArrayAttrIntVal(kernelShape, i); auto kernelSize = ArrayAttrIntVal(kernelShape, i);
auto sumOfPads = ArrayAttrIntVal(padsOpt, i) +
ArrayAttrIntVal(padsOpt, spatialRank + i);
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
int64_t dilationVal = 1;
if (dilationsOpt.hasValue()) if (dilationsOpt.hasValue())
dilationVal = ArrayAttrIntVal(dilationsOpt, i); dilationVal = ArrayAttrIntVal(dilationsOpt, i);
auto strideVal = ArrayAttrIntVal(stridesOpt, i); res = AffineMapIntConstant(builder, dimMap, {inputSize},
// Number of useful values: input plus pad - effective size of kernel (see {kernelSize, sumOfPads, strideVal, dilationVal}, 1, 4);
// processConvTypeParams comments to see how this value is derived).
double numerator =
inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1);
// Useful number is divided by the strides.
double denominator = strideVal;
int64_t res;
if (ceilMode) {
res = ceil(numerator / denominator) + 1;
} else {
res = floor(numerator / denominator) + 1;
} }
outputDims->emplace_back(res); outputDims->emplace_back(res);
} }
@ -1343,8 +1362,8 @@ bool ONNXConvOp::inferShapes() {
// Insert number of filters being applied (number of output channels). // Insert number of filters being applied (number of output channels).
outputDims.emplace_back(weightShape[0]); outputDims.emplace_back(weightShape[0]);
// Compute and insert spatial dims. // Compute and insert spatial dims.
insertConvSpatialDim( insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt); stridesOpt, dilationsOpt);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true; return true;
@ -1365,6 +1384,8 @@ bool ONNXAveragePoolOp::inferShapes() {
return false; return false;
} }
auto builder = mlir::Builder(getContext());
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape(); auto xShape = xTy.getShape();
@ -1390,8 +1411,8 @@ bool ONNXAveragePoolOp::inferShapes() {
outputDims.emplace_back(xShape[0]); outputDims.emplace_back(xShape[0]);
outputDims.emplace_back(xShape[1]); outputDims.emplace_back(xShape[1]);
// Compute and insert spatial dims. // Compute and insert spatial dims.
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt, insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
llvm::None, ceilMode); stridesOpt, llvm::None, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true; return true;
@ -1412,6 +1433,8 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
return false; return false;
} }
auto builder = mlir::Builder(getContext());
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape(); auto xShape = xTy.getShape();
@ -1441,8 +1464,8 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
outputDims.emplace_back(xShape[0]); outputDims.emplace_back(xShape[0]);
outputDims.emplace_back(xShape[1]); outputDims.emplace_back(xShape[1]);
// Compute and insert spatial dims. // Compute and insert spatial dims.
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt, insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
dilationsOpt, ceilMode); stridesOpt, dilationsOpt, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true; return true;

View File

@ -0,0 +1,42 @@
//===------- ONNXOpsHelper.cpp - Helper functions for ONNX dialects -------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file contains helper functions for ONNX ops that are shared by shape
// inference and by the lowering of ONNX ops to Krnl Dialect.
//
//===----------------------------------------------------------------------===//
#include "ONNXOpsHelper.hpp"
// Identity affine
using namespace mlir;
// Builds a map with one dimension and no symbols whose only result is that
// dimension itself.
AffineMap getIdentityDimMap(Builder &builder) {
  AffineExpr dim0 = builder.getAffineDimExpr(0);
  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {dim0});
}
// Affine map computing one output spatial dimension of a conv/pool op.
//
// Operands of the map:
//   d0 - input size
//   s0 - kernel size
//   s1 - pad
//   s2 - stride
//   s3 - dilation
//
// Result:
//   numerator = d0 + s1 - (s0 - 1) * s3 - 1
//   dim       = (numerator ceildiv  s2) + 1   when ceilMode is true
//   dim       = (numerator floordiv s2) + 1   otherwise
AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
  // The single dimension and the four symbols the map operates on.
  AffineExpr inputSize = builder.getAffineDimExpr(0);
  AffineExpr kernelSize = builder.getAffineSymbolExpr(0);
  AffineExpr padSize = builder.getAffineSymbolExpr(1);
  AffineExpr strideSize = builder.getAffineSymbolExpr(2);
  AffineExpr dilationSize = builder.getAffineSymbolExpr(3);

  // Both rounding modes divide the same numerator; only the division differs.
  AffineExpr numerator =
      inputSize + padSize - (kernelSize - 1) * dilationSize - 1;
  AffineExpr quotient = ceilMode ? numerator.ceilDiv(strideSize)
                                 : numerator.floorDiv(strideSize);

  return AffineMap::get(/*dimCount=*/1, /*symbolCount=*/4, {quotient + 1});
}

View File

@ -0,0 +1,33 @@
//===------- ONNXOpsHelper.hpp - Helper functions for ONNX dialects -------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file declares helper functions for ONNX ops that are shared by shape
// inference and by the lowering of ONNX ops to Krnl Dialect.
//
//===----------------------------------------------------------------------===//
// Guard against multiple inclusion: this header is pulled in both directly
// and through other headers.
#pragma once

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"

// NOTE(review): a using-directive at header scope leaks into every file that
// includes this header. It is kept only because existing includers may rely
// on it; the declarations below are fully qualified and do not need it.
using namespace mlir;

// Identity affine map:
// #map = affine_map<(d0)[] -> d0>
mlir::AffineMap getIdentityDimMap(mlir::Builder &builder);

// Pool/conv affine map:
// #map0 = affine_map<(d0)[s0, s1, s2, s3]
// -> (d0 + s1 - (s0 - 1) * s3 - 1) floordiv s2 + 1>
// In the case of `ceilMode = true`:
// #map0 = affine_map<(d0)[s0, s1, s2, s3]
// -> (d0 + s1 - (s0 - 1) * s3 - 1) ceildiv s2 + 1>
// where:
// - d0: input dim
// - s0: kernel
// - s1: pad
// - s2: stride
// - s3: dilation
mlir::AffineMap getConvDimMap(mlir::Builder &builder, bool ceilMode);

View File

@ -1727,6 +1727,23 @@ func @test_pool_general_computation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf
// ----- // -----
// AveragePool with a 2x2 kernel on an input whose dimension 2 is unknown
// (tensor<1x3x?x32xf32>). The expected lowering reads that dimension with
// `dim`, materializes kernel, pad, stride and dilation as index constants,
// feeds them to the floordiv pool/conv affine map through affine.apply, and
// uses the result as the dynamic size operand of the output alloc. The static
// dimension 32 is expected to become 31 in the result type.
func @test_pool_unknown_dimensions(%arg0 : tensor<1x3x?x32xf32>) -> tensor<*xf32> {
%0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x?x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-DAG: #[[AFFINE_MAP:.+]] = affine_map<(d0)[s0, s1, s2, s3] -> ((d0 + s1 - (s0 - 1) * s3 - 1) floordiv s2 + 1)>
// CHECK-LABEL: test_pool_unknown_dimensions
// CHECK: [[DIM:%.+]] = dim %arg0, 2 : memref<1x3x?x32xf32>
// CHECK: [[KERNEL:%.+]] = constant 2 : index
// CHECK: [[PAD:%.+]] = constant 0 : index
// CHECK: [[STRIDE:%.+]] = constant 1 : index
// CHECK: [[DILATION:%.+]] = constant 1 : index
// CHECK: [[AFFINE_APPLY:%.+]] = affine.apply #[[AFFINE_MAP]]([[DIM]]){{.*}}[[KERNEL]], [[PAD]], [[STRIDE]], [[DILATION]]{{.*}}
// CHECK: [[RES:%.+]] = alloc([[AFFINE_APPLY]]) : memref<1x3x?x31xf32>
}
// -----
func @test_averagepool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> { func @test_averagepool_identity_value(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
%0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32> %0 = "onnx.AveragePool"(%arg0) {auto_pad = "NOTSET", kernel_shape = [2, 2]} : (tensor<1x3x32x32xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()