onnx-mlir/src/Dialect/ONNX/ONNXOps.cpp

//===------------------ ONNXOps.cpp - ONNX Operations ---------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file provides definition of ONNX dialect operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "ONNXOps.hpp"
using namespace mlir;
using namespace mlir::OpTrait::util;
//===----------------------------------------------------------------------===//
// ONNX Helper functions
//===----------------------------------------------------------------------===//
static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }
static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
return a.getValue().size();
}
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}
static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}
// Returns the ConstantOp which defines an MLIR Value or null.
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
}
//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
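// Hypothetical example: reducing a 2x3x4 operand over axes = [1] yields a
// 2x1x4 result when keepdims == 1, and a 2x4 result when keepdims == 0.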
RankedTensorType getReductionOutputType(
RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
int64_t rank = operandTy.getRank();
SmallVector<int64_t, 4> axes;
if (axesAttrs != llvm::None) {
for (auto axisAttr : axesAttrs.getValue()) {
int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (rank + axis);
assert(axis >= -rank && axis <= rank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
}
} else {
for (decltype(rank) i = 0; i < rank; ++i) {
axes.emplace_back(i);
}
}
// Mark reduction axes.
SmallVector<bool, 4> isReductionAxis;
for (decltype(rank) i = 0; i < rank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end())
isReductionAxis.emplace_back(true);
else
isReductionAxis.emplace_back(false);
}
// KeepDims
bool isKeepdims = (keepdims == 1);
SmallVector<int64_t, 4> dims;
for (decltype(rank) i = 0; i < rank; ++i) {
if (isReductionAxis[i]) {
if (isKeepdims)
dims.emplace_back(1); // reduction dimension
} else {
dims.emplace_back(operandTy.getShape()[i]);
}
}
return RankedTensorType::get(dims, operandTy.getElementType());
}
//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations.
//
template <class T>
static void processConvDilationParam(T *op, Optional<ArrayAttr> kernelShape) {
auto builder = mlir::Builder(op->getContext());
auto kernelRank = ArrayAttrSize(kernelShape);
auto dilationsOpt = op->dilations();
if (dilationsOpt.hasValue()) {
if (ArrayAttrSize(dilationsOpt) != kernelRank)
op->emitError("dialation rank is not the same as the spatial rank");
// Test values to be greater than 0.
for (int i = 0; i < kernelRank; ++i) {
if (ArrayAttrIntVal(dilationsOpt, i) < 1)
op->emitError("dialation value must be nonzero positive");
}
} else {
// Default dilations are needed; initialize all dimensions with 1.
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs));
}
}
//===----------------------------------------------------------------------===//
// Support function that computes default values for strides.
//
template <class T>
static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
auto builder = mlir::Builder(op->getContext());
auto kernelRank = ArrayAttrSize(kernelShape);
auto stridesOpt = op->strides();
if (stridesOpt.hasValue()) {
if (ArrayAttrSize(stridesOpt) != kernelRank)
op->emitError("strides rank is not the same as the spatial rank");
// Check values to be greater than 0.
for (int i = 0; i < kernelRank; ++i) {
if (ArrayAttrIntVal(stridesOpt, i) < 1)
op->emitError("strides value must be nonzero positive");
}
} else {
// Default strides are needed; initialize all dimensions with 1.
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
op->stridesAttr(builder.getI64ArrayAttr(defaultRefs));
}
}
//===----------------------------------------------------------------------===//
// Support function that computes default values for pads.
//
template <class T>
static void processConvPadParam(T *op,
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None) {
auto builder = mlir::Builder(op->getContext());
auto inputRank = inputShape.size();
auto kernelRank = ArrayAttrSize(kernelShape);
auto kernelOffset = inputRank - kernelRank;
// Try to find padding, getting auto_pad attribute first.
auto autoPad = op->auto_pad();
// Then investigate the various cases. Prefill pad values with zeros, the
// most common case.
SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
bool updatedPad = false;
if (autoPad == "NOTSET") {
auto padsOpt = op->pads();
if (padsOpt.hasValue()) {
// Only case where pads are not updated. Pads consist of two entries
// for each spatial axis.
if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
op->emitError("pads rank is not twice the spatial rank");
// Check values, pads cannot be negative.
for (int i = 0; i < 2 * kernelRank; ++i) {
if (ArrayAttrIntVal(padsOpt, i) < 0)
op->emitError("pads value must be nonnegative");
}
} else {
// We have NOTSET with no pads; they are assumed to be all zeros.
updatedPad = true;
}
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// Reload dilations and strides as they may have been given default values.
updatedPad = true;
int64_t dilationVal = 1;
for (int i = 0; i < kernelRank; ++i) {
auto inputSize = inputShape[kernelOffset + i];
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
if (dilationsOpt.hasValue())
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
// Output size is the input size divided by the stride. When the stride is
// 1, input and output have the same size, which is the usual case. When the
// stride is greater than 1, take the ceiling to be sure every input value
// is used, as padding will fill the gaps.
int64_t outputSize = ceil((1.0 * inputSize) / (1.0 * strideVal));
// The formula is from ONNX MaxPool and can be explained as follows. The
// total pad is the number of values needed for the computation minus the
// number of input values. The number of values needed is the effective
// kernel size plus the total distance covered by jumping to the next kernel
// position: we jump (outputSize - 1) times, each jump being strideVal wide.
// As for the effective kernel size, it is kernelSize plus the dilation
// holes: there are (kernelSize - 1) gaps, each of size (dilationVal - 1).
// Thus the effective size is "kernelSize + (kernelSize - 1) * (dilationVal
// - 1)", which simplifies to "(kernelSize - 1) * dilationVal + 1".
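// Worked example with assumed values: inputSize = 5, kernelSize = 3,
// strideVal = 2, dilationVal = 1 give outputSize = ceil(5 / 2) = 3 and
// sumOfPad = (3 - 1) * 2 + ((3 - 1) * 1 + 1) - 5 = 2, i.e. one pad value at
// each end.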
auto sumOfPad = (outputSize - 1) * strideVal +
((kernelSize - 1) * dilationVal + 1) - inputSize;
// Pad values are assumed equal on both sides, each at half the total value.
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
// But if the total pad value is odd, we add 1 to the beginning or the end
// depending on the autoPad value.
if (sumOfPad % 2 != 0) {
if (autoPad == "SAME_UPPER") {
actualPads[kernelRank + i] += 1;
} else {
actualPads[i] += 1;
}
}
}
} else if (autoPad == "VALID") {
// No pad, default value was set to zero, we are all set.
updatedPad = true;
} else {
op->emitError("auto_pad of unknown / unsupported value");
}
// Set pad values in the pads attribute, if needed.
if (updatedPad) {
ArrayRef<int64_t> defaultRefs(actualPads);
op->padsAttr(builder.getI64ArrayAttr(defaultRefs));
}
// In all cases now, the actual pad values are found in the pads attribute.
op->auto_padAttr(builder.getStringAttr("NOTSET"));
}
//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations, strides, and
// pads.
template <class T>
static void processConvTypeParams(T *op, Value inputOperand) {
auto builder = mlir::Builder(op->getContext());
// 1) Get shape of input.
auto inputShape = inputOperand.getType().cast<RankedTensorType>().getShape();
auto inputRank = inputShape.size();
// 2) Get kernel_shape attribute.
auto kernelShape = op->kernel_shape();
// Dilation.
processConvDilationParam<T>(op, kernelShape);
auto dilationsOpt = op->dilations();
// Strides.
processConvStrideParam<T>(op, kernelShape);
auto stridesOpt = op->strides();
// Pads.
processConvPadParam<T>(op, inputShape, kernelShape, stridesOpt, dilationsOpt);
}
//===----------------------------------------------------------------------===//
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
//
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
auto xRank = xShape.size();
auto spatialRank = ArrayAttrSize(kernelShape);
auto spatialOffset = xRank - spatialRank;
int64_t dilationVal = 1;
for (int i = 0; i < spatialRank; ++i) {
auto inputSize = xShape[spatialOffset + i];
auto sumOfPads =
ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
if (dilationsOpt.hasValue())
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
// Number of useful values: input plus pad - effective size of kernel (see
// processConvTypeParams comments to see how this value is derived).
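// For instance, with assumed values inputSize = 224, sumOfPads = 2,
// kernelSize = 3, dilationVal = 1, and strideVal = 2, this gives
// floor((224 + 2 - 3) / 2) + 1 = 112 output values when ceilMode is off.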
double numerator =
inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1);
// Useful number is divided by the strides.
double denominator = strideVal;
int64_t res;
if (ceilMode) {
res = ceil(numerator / denominator) + 1;
} else {
res = floor(numerator / denominator) + 1;
}
outputDims->emplace_back(res);
}
}
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx) {
addOperations<
#define GET_OP_LIST
#include "src/onnx.cpp.inc"
>();
}
void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function, int numInputs,
int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
builder->getI32IntegerAttr(numInputs));
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
builder->getI32IntegerAttr(numOutputs));
}
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs, int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
Operation *op = mlir::Operation::create(state);
auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
return onnxEntryOp;
}
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Log
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
void ONNXSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Reciprocal
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
void ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softplus
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
void ONNXSoftplusOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softsign
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
void ONNXSoftsignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Abs
/// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface.
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
void ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Div
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
void ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
void ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// And
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
void ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
void ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
void ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
void ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().isa<RankedTensorType>())
return;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
void ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().isa<RankedTensorType>())
return;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Min
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
void ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().isa<RankedTensorType>())
return;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by
/// the shape inference interface.
void ONNXIdentityOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// MatMul
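// Shape inference follows numpy.matmul semantics. As an illustration with
// assumed shapes: (2x3x4) MATMUL (4x5) -> (2x3x5), and, with batch
// broadcasting, (1x3x4) MATMUL (7x4x5) -> (7x3x5).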
void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
return;
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape();
auto rhsShape = rhsTy.getShape();
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional. Following numpy rules
// they are promoted to 1xN and Nx1; the helper dimensions are removed
// after the multiplication, but the result here keeps a single dimension
// of size 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending a
// 1 to its dimensions. After matrix multiplication the prepended 1 is
// removed.
//
// N MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x P)
// Check legality of matrix multiplication.
unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
// If the second argument is 1-D, it is promoted to a matrix by appending a
// 1 to its dimensions. After matrix multiplication the appended 1 is
// removed.
//
// (s1 x s2 x... x sK x M x N) MATMUL N
// =>
// (s1 x s2 x... x sK x M)
// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(lhsShape[lhsRank - 2]);
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
// =>
// (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication.
unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
// =>
// (u1 x u2 x... x uK x M x P)
// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape;
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else {
// This case covers all remaining combinations of 1 and 2-D matrices.
int64_t lhsDim = lhsShape[0];
int64_t rhsDim = rhsShape[0];
if (lhsShape.size() > 1) {
lhsDim = lhsShape[1];
dims.emplace_back(lhsShape[0]);
}
// Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]);
}
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// Gemm
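// Gemm computes Y = alpha * A' * B' + beta * C, where A' is A (M x K)
// optionally transposed and B' is B (K x N) optionally transposed. Sketch
// with assumed shapes: A: 6x4, B: 4x8, transA = transB = 0 gives M = 6,
// N = 8, K = 4, and C must be unidirectionally broadcastable to 6x8
// (e.g. shape 8 or 1x8).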
void ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>()))
return;
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) && (K_B != -1) && (K_A != K_B)) {
emitError("Tensor shapes mismatched");
}
if (hasBias) {
// Check whether bias is unidirectional broadcasting or not.
auto biasTy = C().getType().cast<RankedTensorType>();
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
(rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched");
}
}
SmallVector<int64_t, 2> dims;
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
/// BatchNormalizationTestMode
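// Hypothetical example: for X of shape 2x64x28x28, C = 64, so scale, B,
// mean, and var must each be 1-D tensors of size 64 (or dynamic).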
void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() ||
!scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>())
return;
auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
auto biasTensorTy = B().getType().cast<RankedTensorType>();
auto meanTensorTy = mean().getType().cast<RankedTensorType>();
auto varianceTensorTy = var().getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (inputTensorTy.getShape().size() == 1) {
c = 1;
} else if (inputTensorTy.getShape().size() > 2) {
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else {
emitError("Wrong rank for the input");
}
if (c != -1) {
auto s = scaleTensorTy.getShape();
auto b = biasTensorTy.getShape();
auto m = meanTensorTy.getShape();
auto v = varianceTensorTy.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance");
}
// The output tensor of the same shape as the input.
getResult().setType(X().getType());
}
// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// Reshape
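// Hypothetical example: reshaping a 2x3x4 input with a constant shape tensor
// [0, -1, 4] copies dimension 0 from the input (2), keeps the trailing 4,
// and deduces the remaining dimension as 24 / (2 * 4) = 3, giving a 2x3x4
// output.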
void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked");
auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one");
int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape.
if (outputRank < 0)
emitError("Shape tensor must have constant shape");
// Compute total number of elements.
int64_t totalInputSize = 1;
for(auto inputDim : inputTensorTy.getShape())
totalInputSize *= inputDim;
// Check if second argument of ReshapeOp is a constant.
auto constantOp = getONNXConstantOp(shape());
SmallVector<int64_t, 2> dims(outputRank, -1);
if (constantOp) {
DenseElementsAttr valueAttribute =
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
if (!valueAttribute)
emitError("DenseElementsAttr expected");
// Get dims from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i=0; i<outputRank; ++i)
dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
emitError("Constant value must have same rank as output");
int64_t numberOfDynamicInputs = 0;
int64_t totalKnownDimsSize = 1;
int64_t dynamicValueIndex = -1;
for (int i=0; i<outputRank; ++i) {
// Set output dimension.
if (dims[i] == 0)
dims[i] = inputTensorTy.getShape()[i];
if (dims[i] < 0) {
numberOfDynamicInputs++;
dynamicValueIndex = i;
} else {
totalKnownDimsSize *= dims[i];
}
}
// If the number of dynamic inputs is 1 then deduce the missing value
// based on the total input size. The total input size must be greater
// than 0 i.e. all constant dimensions.
// TODO: Support dynamic input dimensions.
if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
totalInputSize > 0)
dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
}
getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// Transpose
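// Hypothetical example: a 2x3x4 input with perm = [2, 0, 1] yields 4x2x3;
// with no perm attribute the shape is simply reversed, yielding 4x3x2.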
void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
return;
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr();
if (permutation) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// ReduceMax
void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceMin
void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceProd
void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// ReduceSum
void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
}
//===----------------------------------------------------------------------===//
// Conv
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
// - kernelShape: inferred from weight matrix if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
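// Worked example with assumed shapes: X: 1x3x224x224, W: 64x3x7x7,
// strides = [2, 2], pads = [3, 3, 3, 3], and dilations = [1, 1] give spatial
// dims of floor((224 + 6 - 7) / 2) + 1 = 112, i.e. an output of
// 1x64x112x112.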
void ONNXConvNoBiasOp::inferShapes() {
// Generic shape for data input X and weight tensor W:
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>())
return;
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
auto weightTy = W().getType().cast<RankedTensorType>();
auto weightShape = weightTy.getShape();
auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3)
emitError("Data input shape must be at least (NxCxD1)");
// Check that shape of weight and data have same length.
if (xShape.size() != weightShape.size())
emitError("Weight size not compatible with data size");
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
// Check if the attribute actually exists. If it does not then add it.
if (!groupAttr())
groupAttr(builder.getI64IntegerAttr(group));
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch");
// Note: the value of the group attribute only impacts the way the
// computation is carried out and not the actual output size.
// Number of spatial dimensions.
auto spatialOffset = 2;
int32_t spatialRank = xShape.size() - spatialOffset;
// Use kernel_shape attribute if present otherwise use size from weight
// argument.
auto kernelShape = kernel_shape();
if (kernelShape.hasValue()) {
if (ArrayAttrSize(kernelShape) != spatialRank)
emitError("kernel_shape length incompatible with spatial dimensions");
// Have the right number of values, check them.
for (int i = 0; i < spatialRank; ++i)
if (ArrayAttrIntVal(kernelShape, i) < 1)
emitError("bad kernel_shape value");
} else {
// Deduce shape from weight input.
SmallVector<int64_t, 2> defaultVals;
for (int i = 0; i < spatialRank; ++i)
defaultVals.emplace_back(weightShape[spatialOffset + i]);
// Convert to ArrayRef, then build attribute, then store attribute.
ArrayRef<int64_t> defaultRefs(defaultVals);
auto builder = mlir::Builder(getContext());
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
kernelShape = kernel_shape();
}
// Process strides, dilations, and pads.
processConvTypeParams<>(this, X());
auto dilationsOpt = dilations();
auto stridesOpt = strides();
auto padsOpt = pads();
// First two output dimensions consist of the number of batches and the
// number of kernels being applied.
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
// Insert number of filters being applied (number of output channels).
outputDims.emplace_back(weightShape[0]);
// Compute and insert spatial dims.
insertConvSpatialDim(
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// AveragePool
// Infer shape attributes output:
// - auto_pad set to NOTSET;
// - strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXAveragePoolOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
return;
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
// Kernel shape.
auto kernelShape = kernel_shape();
if (!kernelShape)
emitError(
"kernel_shape is a mandatory attribute for which there is no default");
// Ceil mode.
auto ceilMode = ceil_mode().getSExtValue();
// Process strides and pads.
processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
auto stridesOpt = strides();
processConvPadParam<ONNXAveragePoolOp>(
this, xShape, kernelShape, stridesOpt, llvm::None);
auto padsOpt = pads();
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
outputDims.emplace_back(xShape[1]);
// Compute and insert spatial dims.
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt,
llvm::None, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// MaxPoolSingleOut
// Infer shape attributes output:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
return;
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
// Kernel shape.
auto kernelShape = kernel_shape();
if (!kernelShape)
emitError(
"kernel_shape is a mandatory attribute for which there is no default");
// Storage order.
auto storageOrder = storage_order().getSExtValue();
if (storageOrder != 0)
emitError("column major storage order not supported at this time");
// Process strides, dilations, and pads.
processConvTypeParams<>(this, X());
auto dilationsOpt = dilations();
auto stridesOpt = strides();
auto padsOpt = pads();
// Ceil mode.
auto ceilMode = ceil_mode().getSExtValue();
SmallVector<int64_t, 4> outputDims;
// Insert batch size.
outputDims.emplace_back(xShape[0]);
outputDims.emplace_back(xShape[1]);
// Compute and insert spatial dims.
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt,
dilationsOpt, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
}
//===----------------------------------------------------------------------===//
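// Helper for the Pad-style shape inference below. Hypothetical example: a
// 3x4 input with pads = [1, 2, 0, 1] (begin values first, then end values)
// yields a (3 + 1 + 0) x (4 + 2 + 1) = 4x7 output.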
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
// Cannot infer shape if no shape exists.
if (!data.getType().isa<RankedTensorType>())
return (Type)NULL;
auto dataTy = data.getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto dataRank = dataShape.size();
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
if (padsOpt) {
auto padsArray = padsOpt.getValue();
// Pads consist of two values for each axis of data.
// The two values specify the number of elements padded before and after
// respectively.
for (int i = 0; i < dataRank; ++i) {
int64_t p1 = (padsArray[i]).cast<IntegerAttr>().getInt();
int64_t p2 = (padsArray[i + dataRank]).cast<IntegerAttr>().getInt();
// Pad values must be non-negative constants.
if (p1 < 0 || p2 < 0)
return (Type)NULL;
if (outputShape[i] != -1)
outputShape[i] += p1 + p2;
}
return (RankedTensorType::get(outputShape, dataTy.getElementType()));
} else {
return (Type)NULL;
}
}
// PadConstantPad
void ONNXPadConstantPadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) {
getResult().setType(outputType);
}
return;
}
//===----------------------------------------------------------------------===//
// PadConstantValuePad
void ONNXPadConstantValuePadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) {
getResult().setType(outputType);
}
return;
}
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
Type outputType = padShapeInferenceHelper(data, pads);
if (!outputType) {
auto elementType = data.getType().cast<TensorType>().getElementType();
outputType = UnrankedTensorType::get(elementType);
}
build(builder, state, outputType, data, pads, constant_value, mode);
}
//===----------------------------------------------------------------------===//
// Unsqueeze
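// Hypothetical example: a 3x4 input with axes = [0, 2] has output rank 4 and
// yields a 1x3x1x4 result.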
void ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
return;
auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr();
SmallVector<int, 4> axes;
int outRank = 0;
if (axisAttrs) {
outRank = inRank + axisAttrs.getValue().size();
for (auto axisAttr : axisAttrs.getValue()) {
int axis = axisAttr.cast<IntegerAttr>().getInt();
axis = axis >= 0 ? axis : (outRank + axis);
// Valid range
assert(axis >= -outRank && axis <= outRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
else
emitError("Duplicated axes");
}
} else {
emitError("Axes attribute is required");
}
SmallVector<int64_t, 4> dims;
for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
dims.emplace_back(1);
} else {
dims.emplace_back(operandTy.getShape()[j++]);
}
}
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
}
//===----------------------------------------------------------------------===//
// Constant
void ONNXConstantOp::inferShapes() {
if ((sparse_value().hasValue() && value().hasValue()) ||
(!sparse_value().hasValue() && !value().hasValue()))
emitError("Require exactly one of the two attributes, either value or "
"sparse_value");
ElementsAttr valAttr;
if (sparse_value().hasValue())
valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
else
valAttr = valueAttr().cast<DenseElementsAttr>();
getResult().setType(valAttr.getType());
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/onnx.cpp.inc"