2020-03-19 16:48:09 +08:00
|
|
|
//===------------------ ONNXOps.cpp - ONNX Operations ---------------------===//
|
2019-11-02 05:09:48 +08:00
|
|
|
//
|
2020-03-19 16:48:09 +08:00
|
|
|
// Copyright 2019-2020 The IBM Research Authors.
|
2019-11-02 05:09:48 +08:00
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
2020-03-19 16:48:09 +08:00
|
|
|
// This file provides definition of ONNX dialect operations.
|
2019-11-02 05:09:48 +08:00
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-03-19 16:48:09 +08:00
|
|
|
|
2019-12-20 00:28:06 +08:00
|
|
|
#include "mlir/Dialect/Traits.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/Block.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/IntegerSet.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
2020-01-21 03:46:54 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2019-12-17 07:45:39 +08:00
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
|
#include "llvm/ADT/SmallBitVector.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
#include "ONNXOps.hpp"
|
2019-11-02 05:09:48 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
2019-12-20 00:28:06 +08:00
|
|
|
using namespace mlir::OpTrait::util;
|
2019-11-02 05:09:48 +08:00
|
|
|
|
2020-02-26 03:33:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNX Helper functions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Number of elements held by the array attribute `a`.
static size_t ArrayAttrSize(ArrayAttr a) {
  ArrayRef<Attribute> elements = a.getValue();
  return elements.size();
}
|
|
|
|
|
|
|
|
/// Number of elements held by an optional array attribute. The attribute must
/// be present: calling getValue() on an empty Optional is invalid.
static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
  ArrayAttr attr = a.getValue();
  return attr.getValue().size();
}
|
|
|
|
|
|
|
|
/// Returns element `i` of `a` as a 64-bit integer. The element must be an
/// IntegerAttr (cast<> asserts otherwise).
static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
  Attribute element = a.getValue()[i];
  IntegerAttr intElement = element.cast<IntegerAttr>();
  return intElement.getInt();
}
|
|
|
|
|
|
|
|
/// Returns element `i` of an optional array attribute as a 64-bit integer.
/// The attribute must be present and its element must be an IntegerAttr.
static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
  ArrayAttr attr = a.getValue();
  Attribute element = attr.getValue()[i];
  return element.cast<IntegerAttr>().getInt();
}
|
|
|
|
|
2020-03-14 05:18:46 +08:00
|
|
|
// Returns the ConstantOp which defines an MLIR Value or null.
|
|
|
|
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
|
|
|
|
return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
|
|
|
|
}
|
|
|
|
|
2020-02-10 21:38:19 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Get reduction type
|
|
|
|
//===----------------------------------------------------------------------===//
|
2020-02-26 03:33:48 +08:00
|
|
|
RankedTensorType getReductionOutputType(
|
|
|
|
RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
|
2020-02-10 21:38:19 +08:00
|
|
|
int64_t rank = operandTy.getRank();
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> axes;
|
|
|
|
if (axesAttrs != llvm::None) {
|
|
|
|
for (auto axisAttr : axesAttrs.getValue()) {
|
|
|
|
int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
|
|
|
|
axis = axis >= 0 ? axis : (rank + axis);
|
|
|
|
assert(axis >= -rank && axis <= rank - 1);
|
|
|
|
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
|
|
|
axes.emplace_back(axis);
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
for (decltype(rank) i = 0; i < rank; ++i) {
|
|
|
|
axes.emplace_back(i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Mark reduction axes.
|
|
|
|
SmallVector<bool, 4> isReductionAxis;
|
|
|
|
for (decltype(rank) i = 0; i < rank; ++i) {
|
|
|
|
if (std::find(axes.begin(), axes.end(), i) != axes.end())
|
|
|
|
isReductionAxis.emplace_back(true);
|
|
|
|
else
|
|
|
|
isReductionAxis.emplace_back(false);
|
|
|
|
}
|
|
|
|
|
|
|
|
// KeepDims
|
|
|
|
bool isKeepdims = (keepdims == 1) ? true : false;
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> dims;
|
|
|
|
for (decltype(rank) i = 0; i < rank; ++i) {
|
|
|
|
if (isReductionAxis[i]) {
|
|
|
|
if (isKeepdims)
|
|
|
|
dims.emplace_back(1); // reduction dimension
|
|
|
|
} else {
|
|
|
|
dims.emplace_back(operandTy.getShape()[i]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return RankedTensorType::get(dims, operandTy.getElementType());
|
|
|
|
}
|
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Support function that computes default values for dilations.
|
|
|
|
//
|
|
|
|
template <class T>
|
|
|
|
static void processConvDilationParam(T *op, Optional<ArrayAttr> kernelShape) {
|
|
|
|
auto builder = mlir::Builder(op->getContext());
|
|
|
|
auto kernelRank = ArrayAttrSize(kernelShape);
|
|
|
|
|
|
|
|
auto dilationsOpt = op->dilations();
|
|
|
|
if (dilationsOpt.hasValue()) {
|
|
|
|
if (ArrayAttrSize(dilationsOpt) != kernelRank)
|
|
|
|
op->emitError("dialation rank is not the same as the spatial rank");
|
|
|
|
// Test values to be greater than 0.
|
|
|
|
for (int i = 0; i < kernelRank; ++i) {
|
|
|
|
if (ArrayAttrIntVal(dilationsOpt, i) < 1)
|
|
|
|
op->emitError("dialation value must be nonzero positive");
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// Default dilatation is needed, all dimensions init with 1.
|
|
|
|
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
|
|
|
|
// Convert to ArrayRef, then build attribute, then store attribute.
|
|
|
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
|
|
|
op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Support function that computes default values for strides.
|
|
|
|
//
|
|
|
|
template <class T>
|
|
|
|
static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
|
|
|
|
auto builder = mlir::Builder(op->getContext());
|
|
|
|
auto kernelRank = ArrayAttrSize(kernelShape);
|
|
|
|
|
|
|
|
auto stridesOpt = op->strides();
|
|
|
|
if (stridesOpt.hasValue()) {
|
|
|
|
if (ArrayAttrSize(stridesOpt) != kernelRank)
|
|
|
|
op->emitError("strides rank is not the same as the spatial rank");
|
|
|
|
// Check values to be greater than 0.
|
|
|
|
for (int i = 0; i < kernelRank; ++i) {
|
|
|
|
if (ArrayAttrIntVal(stridesOpt, i) < 1)
|
|
|
|
op->emitError("strides value must be nonzero positive");
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// Default stride is needed, all dimensions init with 1.
|
|
|
|
SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
|
|
|
|
// Convert to ArrayRef, then build attribute, then store attribute.
|
|
|
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
|
|
|
op->stridesAttr(builder.getI64ArrayAttr(defaultRefs));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Support function that computes default values for pads.
|
|
|
|
//
|
|
|
|
template <class T>
|
|
|
|
static void processConvPadParam(T *op,
|
|
|
|
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
|
|
|
|
Optional<ArrayAttr> stridesOpt,
|
|
|
|
Optional<ArrayAttr> dilationsOpt = llvm::None) {
|
|
|
|
auto builder = mlir::Builder(op->getContext());
|
|
|
|
|
|
|
|
auto inputRank = inputShape.size();
|
|
|
|
auto kernelRank = ArrayAttrSize(kernelShape);
|
|
|
|
auto kernelOffset = inputRank - kernelRank;
|
|
|
|
|
|
|
|
// Try to find padding, getting auto_pad attribute first.
|
|
|
|
auto autoPad = op->auto_pad();
|
|
|
|
// And then investigate the various different cases. Prefill pad values with
|
|
|
|
// zeros, the most common case.
|
|
|
|
SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
|
|
|
|
bool updatedPad = false;
|
|
|
|
if (autoPad == "NOTSET") {
|
|
|
|
auto padsOpt = op->pads();
|
|
|
|
if (padsOpt.hasValue()) {
|
|
|
|
// Only option where pads are not updated. Pads consists of two entries
|
|
|
|
// for each spatial axis.
|
|
|
|
if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
|
|
|
|
op->emitError("pads rank is not twice the spatial rank");
|
|
|
|
// Check values, pads cannot be negative.
|
|
|
|
for (int i = 0; i < 2 * kernelRank; ++i) {
|
|
|
|
if (ArrayAttrIntVal(padsOpt, i) < 0)
|
|
|
|
op->emitError("pads value must be nonnegative");
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
// We have notset with no pads, they are assumed to be all zero.
|
|
|
|
updatedPad = true;
|
|
|
|
}
|
|
|
|
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
|
|
|
// Reload dialtion and strides as they may have gotten default values.
|
|
|
|
updatedPad = true;
|
|
|
|
int64_t dilationVal = 1;
|
|
|
|
for (int i = 0; i < kernelRank; ++i) {
|
|
|
|
auto inputSize = inputShape[kernelOffset + i];
|
|
|
|
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
|
|
|
|
if (dilationsOpt.hasValue())
|
|
|
|
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
|
|
|
|
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
|
|
|
|
// Output size is input size divided by stride. When stride is 1, then
|
|
|
|
// input and output are the same size, which is the usual case. When
|
|
|
|
// stride is greater than 1, take the ceil to be sure to have each input
|
|
|
|
// value used, as padding will be used to fill the gaps.
|
|
|
|
int64_t outputSize = ceil((1.0 * inputSize) / (1.0 * strideVal));
|
|
|
|
// Forumla is from ONNX MaxPool, and can be explained as follows. Pads is
|
|
|
|
// the difference between the needed values for the computations, minus
|
|
|
|
// the input values. The needed values for the computation is the
|
|
|
|
// effective side of the kernel plus the number of times we jump to the
|
|
|
|
// next kernel. Number of time we jump is (outputSize - 1). That number is
|
|
|
|
// multiplied with the size of the jump, namely strideVal. Now for the
|
|
|
|
// effective kernel size. It is the kernelSize + the number of times we
|
|
|
|
// have dilation holes time the dialtion. The number of dialtion holes is
|
|
|
|
// (kernelSize -1). Thus the effective size is "kernelSize +
|
|
|
|
// (kernelSize-1)*dialation". This simplifies to "(kernelSize
|
|
|
|
// -1)*dialation + 1".
|
|
|
|
auto sumOfPad = (outputSize - 1) * strideVal +
|
|
|
|
((kernelSize - 1) * dilationVal + 1) - inputSize;
|
|
|
|
// Pad values are assumed equal on both size, at half the total value.
|
|
|
|
actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
|
|
|
|
// But if the total pad value is odd, we add 1 to begining or end
|
|
|
|
// depending on autoPad value.
|
|
|
|
if (sumOfPad % 2 != 0) {
|
|
|
|
if (autoPad == "SAME_UPPER") {
|
|
|
|
actualPads[kernelRank + i] += 1;
|
|
|
|
} else {
|
|
|
|
actualPads[i] += 1;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else if (autoPad == "VALID") {
|
|
|
|
// No pad, default value was set to zero, we are all set.
|
|
|
|
updatedPad = true;
|
|
|
|
} else {
|
|
|
|
op->emitError("auto_pad of unknown / unsupported value");
|
|
|
|
}
|
|
|
|
// Set pads values in attributes, if it is needed.
|
|
|
|
if (updatedPad) {
|
|
|
|
ArrayRef<int64_t> defaultRefs(actualPads);
|
|
|
|
op->padsAttr(builder.getI64ArrayAttr(defaultRefs));
|
|
|
|
}
|
|
|
|
// In all cases now, the acutal pad values are found in the pads attribute.
|
|
|
|
op->auto_padAttr(builder.getStringAttr("NOTSET"));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Support function that computes default values for dilations, strides, and
|
|
|
|
// pads.
|
|
|
|
template <class T>
|
|
|
|
static void processConvTypeParams(T *op, Value inputOperand) {
|
|
|
|
auto builder = mlir::Builder(op->getContext());
|
|
|
|
|
|
|
|
// 1) Get shape of input.
|
|
|
|
auto inputShape = inputOperand.getType().cast<RankedTensorType>().getShape();
|
|
|
|
auto inputRank = inputShape.size();
|
|
|
|
|
|
|
|
// 2) Get kernel_shape attribute.
|
|
|
|
auto kernelShape = op->kernel_shape();
|
|
|
|
|
|
|
|
// Dilation.
|
|
|
|
processConvDilationParam<T>(op, kernelShape);
|
|
|
|
auto dilationsOpt = op->dilations();
|
|
|
|
|
|
|
|
// Strides.
|
|
|
|
processConvStrideParam<T>(op, kernelShape);
|
|
|
|
auto stridesOpt = op->strides();
|
|
|
|
|
|
|
|
// Pads.
|
|
|
|
processConvPadParam<T>(op, inputShape, kernelShape, stridesOpt, dilationsOpt);
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
|
|
|
|
//
|
|
|
|
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
|
|
|
ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
|
|
|
|
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
|
|
|
|
Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
|
|
|
|
auto xRank = xShape.size();
|
|
|
|
auto spatialRank = ArrayAttrSize(kernelShape);
|
|
|
|
auto spatialOffset = xRank - spatialRank;
|
|
|
|
|
|
|
|
int64_t dilationVal = 1;
|
|
|
|
for (int i = 0; i < spatialRank; ++i) {
|
|
|
|
auto inputSize = xShape[spatialOffset + i];
|
|
|
|
auto sumOfPads =
|
|
|
|
ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
|
|
|
|
auto kernelSize = ArrayAttrIntVal(kernelShape, i);
|
|
|
|
if (dilationsOpt.hasValue())
|
|
|
|
dilationVal = ArrayAttrIntVal(dilationsOpt, i);
|
|
|
|
auto strideVal = ArrayAttrIntVal(stridesOpt, i);
|
|
|
|
// Number of useful values: input plus pad - effective size of kernel (see
|
|
|
|
// processConvTypeParams comments to see how this value is derived).
|
|
|
|
double numerator =
|
|
|
|
inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1);
|
|
|
|
// Useful number is divided by the strides.
|
|
|
|
double denominator = strideVal;
|
|
|
|
int64_t res;
|
|
|
|
if (ceilMode) {
|
|
|
|
res = ceil(numerator / denominator) + 1;
|
|
|
|
} else {
|
|
|
|
res = floor(numerator / denominator) + 1;
|
|
|
|
}
|
|
|
|
outputDims->emplace_back(res);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNXOpsDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  // Register every ONNX operation. The operation list comes from the
  // generated file: defining GET_OP_LIST before including it makes it expand
  // to the list of op classes used as addOperations' template arguments.
  // The #define/#include pair must stay inside the template argument list.
  addOperations<
#define GET_OP_LIST
#include "src/onnx.cpp.inc"
      >();
}
|
|
|
|
|
2019-12-22 13:25:02 +08:00
|
|
|
void ONNXEntryPointOp::build(mlir::Builder *builder,
|
2020-02-26 03:33:48 +08:00
|
|
|
mlir::OperationState &state, mlir::FuncOp function, int numInputs,
|
|
|
|
int numOutputs) {
|
2019-12-22 13:25:02 +08:00
|
|
|
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
|
2020-02-26 03:33:48 +08:00
|
|
|
builder->getSymbolRefAttr(function));
|
2019-12-22 13:25:02 +08:00
|
|
|
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
|
2020-02-26 03:33:48 +08:00
|
|
|
builder->getI32IntegerAttr(numInputs));
|
2019-12-22 13:25:02 +08:00
|
|
|
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
|
2020-02-26 03:33:48 +08:00
|
|
|
builder->getI32IntegerAttr(numOutputs));
|
2019-12-22 13:25:02 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Creates a detached "onnx.EntryPoint" operation at `location` that refers
/// to `func` with the given input/output counts. The op is not inserted into
/// any block.
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
    mlir::FuncOp &func, int numInputs, int numOutputs) {
  mlir::OperationState state(location, "onnx.EntryPoint");
  Builder builder(location->getContext());
  mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
  return llvm::cast<mlir::ONNXEntryPointOp>(mlir::Operation::create(state));
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNX Operations
|
2019-12-06 09:08:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Exp
|
|
|
|
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-06 09:08:09 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Tanh
|
|
|
|
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-06 09:08:09 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sinh
|
|
|
|
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-06 09:08:09 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Cosh
|
|
|
|
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-06 09:08:09 +08:00
|
|
|
|
2020-01-08 11:11:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Cos
|
|
|
|
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2020-01-08 11:11:21 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Log
|
|
|
|
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2020-01-08 11:11:21 +08:00
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// HardSigmoid
|
|
|
|
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXHardSigmoidOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2019-12-06 09:08:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sigmoid
|
|
|
|
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSigmoidOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 09:08:09 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Elu
|
|
|
|
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
|
2019-12-06 13:31:17 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Relu
|
|
|
|
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-06 13:31:17 +08:00
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LeakyRelu
|
|
|
|
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXLeakyReluOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Selu
|
|
|
|
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2019-12-16 14:23:33 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Reciprocal
|
|
|
|
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXReciprocalOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2020-01-21 10:57:32 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Softmax
|
|
|
|
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXSoftmaxOp::inferShapes() {
|
|
|
|
getResult().setType(getOperand().getType());
|
|
|
|
}
|
|
|
|
|
2020-01-24 12:18:38 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Softplus
|
|
|
|
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXSoftplusOp::inferShapes() {
|
|
|
|
getResult().setType(getOperand().getType());
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Softsign
|
|
|
|
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXSoftsignOp::inferShapes() {
|
|
|
|
getResult().setType(getOperand().getType());
|
|
|
|
}
|
|
|
|
|
2020-01-29 00:10:47 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sqrt
|
|
|
|
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2020-01-29 00:10:47 +08:00
|
|
|
|
2020-02-04 22:27:17 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sign
|
|
|
|
/// Infer the output shape of the ONNXSignOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
2020-02-24 23:46:48 +08:00
|
|
|
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
2020-02-04 22:27:17 +08:00
|
|
|
|
2020-03-17 23:12:45 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Abs
|
|
|
|
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-13 02:37:46 +08:00
|
|
|
// Add
|
2019-11-19 10:08:21 +08:00
|
|
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2019-11-08 00:42:40 +08:00
|
|
|
void ONNXAddOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2019-11-08 00:42:40 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Mul
|
|
|
|
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMulOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Div
|
|
|
|
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXDivOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sub
|
|
|
|
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSubOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// And
|
|
|
|
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXAndOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Or
|
|
|
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2019-12-20 00:28:06 +08:00
|
|
|
void ONNXOrOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2019-12-20 00:28:06 +08:00
|
|
|
}
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Xor
|
|
|
|
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXXorOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sum
|
|
|
|
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSumOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Max
|
|
|
|
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMaxOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Min
|
|
|
|
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMinOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2019-12-17 07:45:39 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Identity
|
|
|
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXIdentityOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-17 07:45:39 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
// MatMul
|
|
|
|
|
|
|
|
void ONNXMatMulOp::inferShapes() {
|
2019-11-16 02:10:41 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!A().getType().isa<RankedTensorType>() ||
|
|
|
|
!B().getType().isa<RankedTensorType>())
|
2019-11-16 02:10:41 +08:00
|
|
|
return;
|
2020-01-10 04:30:57 +08:00
|
|
|
|
2020-02-26 04:46:11 +08:00
|
|
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
2020-01-10 04:30:57 +08:00
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
2020-01-10 04:30:57 +08:00
|
|
|
auto lhsShape = lhsTy.getShape();
|
|
|
|
auto rhsShape = rhsTy.getShape();
|
2020-01-28 01:08:23 +08:00
|
|
|
|
|
|
|
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
|
|
|
|
// Multiplication by scalars is not allowed.
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Multiplication by scalar arguments not allowed");
|
2020-01-28 01:08:23 +08:00
|
|
|
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
2020-01-10 04:30:57 +08:00
|
|
|
// Special case when both arrays are 1-dimensional and according to
|
|
|
|
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
|
|
|
// need to be removed after the multiplication but cannot be removed if all
|
2020-01-11 00:22:41 +08:00
|
|
|
// sizes are 1.
|
2020-02-24 23:46:48 +08:00
|
|
|
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-10 04:30:57 +08:00
|
|
|
dims.emplace_back(1);
|
2020-01-29 23:35:05 +08:00
|
|
|
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
|
|
|
|
// If the first argument is 1-D, it is promoted to a matrix by prepending a
|
|
|
|
// 1 to its dimensions. After matrix multiplication the prepended 1 is
|
|
|
|
// removed.
|
|
|
|
//
|
|
|
|
// N MATMUL (s1 x s2 x... x sK x N x P)
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
unsigned rhsRank = rhsShape.size();
|
|
|
|
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
|
|
|
lhsShape[0] != rhsShape[rhsRank - 2])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-29 23:35:05 +08:00
|
|
|
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
2020-01-29 23:35:05 +08:00
|
|
|
dims.emplace_back(rhsShape[i]);
|
|
|
|
dims.emplace_back(rhsShape[rhsRank - 1]);
|
|
|
|
} else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
|
|
|
|
// If the second argument is 1-D, it is promoted to a matrix by appending a
|
|
|
|
// 1 to its dimensions. After matrix multiplication the appended 1 is
|
|
|
|
// removed.
|
|
|
|
//
|
|
|
|
// (s1 x s2 x... x sK x M x N) MATMUL N
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x M)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
unsigned lhsRank = lhsShape.size();
|
|
|
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
|
|
|
lhsShape[lhsRank - 1] != rhsShape[0])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-29 23:35:05 +08:00
|
|
|
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
2020-01-29 23:35:05 +08:00
|
|
|
dims.emplace_back(lhsShape[i]);
|
|
|
|
dims.emplace_back(lhsShape[lhsRank - 2]);
|
2020-01-11 01:30:12 +08:00
|
|
|
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
|
|
|
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
2020-01-28 01:08:23 +08:00
|
|
|
unsigned lhsRank = lhsShape.size();
|
|
|
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
|
|
|
|
lhsShape[lhsRank - 1] != rhsShape[0])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-11 01:30:12 +08:00
|
|
|
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
|
2020-01-11 01:30:12 +08:00
|
|
|
dims.emplace_back(lhsShape[i]);
|
|
|
|
dims.emplace_back(rhsShape[1]);
|
|
|
|
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
|
|
|
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
2020-01-28 01:08:23 +08:00
|
|
|
unsigned rhsRank = rhsShape.size();
|
|
|
|
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
|
|
|
lhsShape[1] != rhsShape[rhsRank - 2])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-11 01:30:12 +08:00
|
|
|
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
2020-01-11 01:30:12 +08:00
|
|
|
dims.emplace_back(rhsShape[i]);
|
|
|
|
dims.emplace_back(lhsShape[0]);
|
2020-01-28 01:08:23 +08:00
|
|
|
dims.emplace_back(rhsShape[rhsRank - 1]);
|
2020-01-11 01:30:12 +08:00
|
|
|
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
|
|
|
|
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
|
|
|
|
// =>
|
|
|
|
// (u1 x u2 x... x uK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
2020-01-28 01:08:23 +08:00
|
|
|
unsigned lhsRank = lhsShape.size();
|
|
|
|
unsigned rhsRank = rhsShape.size();
|
|
|
|
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
|
|
|
|
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-11 01:30:12 +08:00
|
|
|
|
|
|
|
// Check and perform broadcasting for the shapes.
|
|
|
|
SmallVector<int64_t, 2> lhsBcastShape;
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
|
2020-01-11 01:30:12 +08:00
|
|
|
lhsBcastShape.emplace_back(lhsShape[i]);
|
|
|
|
SmallVector<int64_t, 2> rhsBcastShape;
|
2020-02-14 23:43:17 +08:00
|
|
|
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
|
2020-01-11 01:30:12 +08:00
|
|
|
rhsBcastShape.emplace_back(rhsShape[i]);
|
|
|
|
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Broadcasted dimensions are incompatible");
|
2020-01-11 01:30:12 +08:00
|
|
|
|
2020-01-28 01:08:23 +08:00
|
|
|
dims.emplace_back(lhsShape[lhsRank - 2]);
|
|
|
|
dims.emplace_back(rhsShape[rhsRank - 1]);
|
2020-01-10 04:30:57 +08:00
|
|
|
} else {
|
2020-01-11 01:30:12 +08:00
|
|
|
// This case covers all remaining combinations of 1 and 2-D matrices.
|
2020-01-11 04:26:29 +08:00
|
|
|
int64_t lhsDim = lhsShape[0];
|
|
|
|
int64_t rhsDim = rhsShape[0];
|
|
|
|
if (lhsShape.size() > 1) {
|
|
|
|
lhsDim = lhsShape[1];
|
2020-01-11 04:16:45 +08:00
|
|
|
dims.emplace_back(lhsShape[0]);
|
2020-01-11 04:26:29 +08:00
|
|
|
}
|
2020-01-11 04:16:45 +08:00
|
|
|
|
2020-01-11 04:26:29 +08:00
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Attempt to multiply incompatible matrices");
|
2020-01-11 04:26:29 +08:00
|
|
|
|
|
|
|
if (rhsShape.size() > 1)
|
2020-01-11 04:16:45 +08:00
|
|
|
dims.emplace_back(rhsShape[1]);
|
2020-01-10 04:30:57 +08:00
|
|
|
}
|
|
|
|
|
2020-01-23 05:09:19 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
2019-11-13 02:37:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Gemm
|
|
|
|
|
|
|
|
void ONNXGemmOp::inferShapes() {
|
2020-02-26 04:46:11 +08:00
|
|
|
bool hasBias = !C().getType().isa<NoneType>();
|
2019-11-16 02:10:41 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!A().getType().isa<RankedTensorType>() ||
|
|
|
|
!B().getType().isa<RankedTensorType>() ||
|
|
|
|
(hasBias && !C().getType().isa<RankedTensorType>()))
|
2019-11-16 02:10:41 +08:00
|
|
|
return;
|
2020-02-26 04:46:11 +08:00
|
|
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
2020-01-30 00:11:49 +08:00
|
|
|
|
|
|
|
int64_t M, N, K_A, K_B;
|
|
|
|
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
|
|
|
|
K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
|
|
|
|
N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
|
|
|
|
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
|
|
|
|
|
|
|
|
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Tensor shapes mismatched");
|
2020-01-30 00:11:49 +08:00
|
|
|
}
|
|
|
|
|
2020-02-24 23:46:48 +08:00
|
|
|
if (hasBias) {
|
|
|
|
// Check whether bias is unidirectional broadcasting or not.
|
2020-02-26 04:46:11 +08:00
|
|
|
auto biasTy = C().getType().cast<RankedTensorType>();
|
2020-02-24 23:46:48 +08:00
|
|
|
auto shape = biasTy.getShape();
|
|
|
|
int rank = shape.size();
|
|
|
|
if ((rank > 2) ||
|
|
|
|
(rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
|
2020-02-26 03:33:48 +08:00
|
|
|
N != shape[rank - 1] && shape[rank - 1] != 1) ||
|
2020-02-24 23:46:48 +08:00
|
|
|
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
|
2020-02-26 03:33:48 +08:00
|
|
|
M != shape[rank - 2] && shape[rank - 2] != 1)) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Bias shape mismatched");
|
2020-02-24 23:46:48 +08:00
|
|
|
}
|
2020-02-20 23:55:24 +08:00
|
|
|
}
|
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
2020-02-20 23:55:24 +08:00
|
|
|
dims.emplace_back(M);
|
|
|
|
dims.emplace_back(N);
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
2019-11-13 02:37:46 +08:00
|
|
|
}
|
|
|
|
|
2020-02-21 00:45:40 +08:00
|
|
|
/// BatchNormalizationTestMode
|
|
|
|
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape exists.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!X().getType().isa<RankedTensorType>() ||
|
|
|
|
!scale().getType().isa<RankedTensorType>() ||
|
|
|
|
!B().getType().isa<RankedTensorType>() ||
|
|
|
|
!mean().getType().isa<RankedTensorType>() ||
|
|
|
|
!var().getType().isa<RankedTensorType>())
|
2020-02-21 00:45:40 +08:00
|
|
|
return;
|
|
|
|
|
2020-02-26 04:46:11 +08:00
|
|
|
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
|
|
|
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
|
|
|
auto biasTensorTy = B().getType().cast<RankedTensorType>();
|
|
|
|
auto meanTensorTy = mean().getType().cast<RankedTensorType>();
|
|
|
|
auto varianceTensorTy = var().getType().cast<RankedTensorType>();
|
2020-02-21 00:45:40 +08:00
|
|
|
|
|
|
|
// Check whether the shapes of scale, bias, mean and variance are valid.
|
|
|
|
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
|
|
|
// In case of N, C is assumed to be 1.
|
|
|
|
// Shapes of scale, bias, mean and variance must be C.
|
|
|
|
int64_t c = -1;
|
2020-02-26 04:46:11 +08:00
|
|
|
if (inputTensorTy.getShape().size() == 1) {
|
2020-02-21 00:45:40 +08:00
|
|
|
c = 1;
|
2020-02-26 04:46:11 +08:00
|
|
|
} else if (inputTensorTy.getShape().size() > 2) {
|
|
|
|
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
|
2020-02-21 00:45:40 +08:00
|
|
|
} else {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Wrong rank for the input");
|
2020-02-21 00:45:40 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
if (c != -1) {
|
2020-02-26 04:46:11 +08:00
|
|
|
auto s = scaleTensorTy.getShape();
|
|
|
|
auto b = biasTensorTy.getShape();
|
|
|
|
auto m = meanTensorTy.getShape();
|
|
|
|
auto v = varianceTensorTy.getShape();
|
2020-02-21 00:45:40 +08:00
|
|
|
|
|
|
|
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Wrong rank for the scale");
|
2020-02-21 00:45:40 +08:00
|
|
|
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Wrong rank for the bias");
|
2020-02-21 00:45:40 +08:00
|
|
|
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Wrong rank for the mean");
|
2020-02-21 00:45:40 +08:00
|
|
|
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Wrong rank for the variance");
|
2020-02-21 00:45:40 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// The output tensor of the same shape as the input.
|
2020-02-26 04:46:11 +08:00
|
|
|
getResult().setType(X().getType());
|
2020-02-21 00:45:40 +08:00
|
|
|
}
|
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
// TODO:
|
|
|
|
// Verify that matrix sizes are valid for multiplication and addition.
|
|
|
|
// Take into account the dimensionality of the matrix.
|
|
|
|
|
2019-12-14 04:28:56 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Reshape
|
|
|
|
|
|
|
|
void ONNXReshapeOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape tensor is specified.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!shape().getType().isa<RankedTensorType>())
|
|
|
|
emitError("Shape tensor not ranked");
|
2019-12-14 04:28:56 +08:00
|
|
|
|
2020-02-26 04:46:11 +08:00
|
|
|
auto inputTensorTy = data().getType().cast<RankedTensorType>();
|
|
|
|
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
|
2019-12-14 04:28:56 +08:00
|
|
|
|
|
|
|
// Only rank 1 shape tensors are supported.
|
|
|
|
if (shapeTensorTy.getShape().size() != 1)
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor must have rank one");
|
2019-12-14 04:28:56 +08:00
|
|
|
|
|
|
|
int64_t outputRank = shapeTensorTy.getShape()[0];
|
|
|
|
|
|
|
|
// Shape tensor must have constant shape.
|
|
|
|
if (outputRank < 0)
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor must have constant shape");
|
2019-12-14 04:28:56 +08:00
|
|
|
|
2020-03-14 05:18:46 +08:00
|
|
|
// Compute total number of elements.
|
|
|
|
int64_t totalInputSize = 1;
|
|
|
|
for(auto inputDim : inputTensorTy.getShape())
|
|
|
|
totalInputSize *= inputDim;
|
|
|
|
|
2020-03-11 02:46:35 +08:00
|
|
|
// Check if second argument of ReshapeOp is a constant.
|
2020-03-14 05:18:46 +08:00
|
|
|
auto constantOp = getONNXConstantOp(shape());
|
2020-03-11 02:46:35 +08:00
|
|
|
|
|
|
|
SmallVector<int64_t, 2> dims(outputRank, -1);
|
|
|
|
if (constantOp) {
|
2020-03-16 23:17:28 +08:00
|
|
|
DenseElementsAttr valueAttribute =
|
|
|
|
constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
|
|
|
|
|
2020-03-11 02:46:35 +08:00
|
|
|
if (!valueAttribute)
|
2020-03-16 23:17:28 +08:00
|
|
|
emitError("DenseElementsAttr expected");
|
|
|
|
|
|
|
|
// Get dims from valueAttribute.
|
|
|
|
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
|
|
|
|
for (int i=0; i<outputRank; ++i)
|
|
|
|
dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
|
2020-03-11 02:46:35 +08:00
|
|
|
|
2020-03-16 23:17:28 +08:00
|
|
|
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
|
2020-03-11 02:46:35 +08:00
|
|
|
emitError("Constant value must have same rank as output");
|
|
|
|
|
2020-03-14 05:18:46 +08:00
|
|
|
int64_t numberOfDynamicInputs = 0;
|
|
|
|
int64_t totalKnownDimsSize = 1;
|
|
|
|
int64_t dynamicValueIndex = -1;
|
|
|
|
for (int i=0; i<outputRank; ++i) {
|
|
|
|
// Set output dimension.
|
|
|
|
if (dims[i] == 0)
|
|
|
|
dims[i] = inputTensorTy.getShape()[i];
|
|
|
|
|
|
|
|
if (dims[i] < 0) {
|
|
|
|
numberOfDynamicInputs++;
|
|
|
|
dynamicValueIndex = i;
|
|
|
|
} else {
|
|
|
|
totalKnownDimsSize *= dims[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// If the number of dynamic inputs is 1 then deduce the missing value
|
|
|
|
// based on the total input size. The total input size must be greater
|
|
|
|
// than 0 i.e. all constant dimensions.
|
|
|
|
// TODO: Support dynamic input dimensons.
|
|
|
|
if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
|
|
|
|
totalInputSize > 0)
|
|
|
|
dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
|
2020-03-11 02:46:35 +08:00
|
|
|
}
|
2019-12-14 04:28:56 +08:00
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(
|
2019-12-14 04:28:56 +08:00
|
|
|
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
2020-01-08 03:48:01 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Transpose
|
|
|
|
|
|
|
|
void ONNXTransposeOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape exists.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!data().getType().isa<RankedTensorType>())
|
2020-01-22 09:39:11 +08:00
|
|
|
return;
|
2020-01-08 03:48:01 +08:00
|
|
|
|
|
|
|
// Naive transposition which handles the default case of
|
|
|
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
2020-02-26 04:46:11 +08:00
|
|
|
auto arrayTy = data().getType().cast<RankedTensorType>();
|
2020-01-14 07:08:19 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
2020-01-27 23:09:14 +08:00
|
|
|
auto permutation = ONNXTransposeOp::permAttr();
|
|
|
|
if (permutation) {
|
2020-01-14 07:08:19 +08:00
|
|
|
// Perform transposition according to perm attribute.
|
2020-01-21 03:46:54 +08:00
|
|
|
for (auto perm : permutation.getValue())
|
|
|
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
2020-01-14 07:08:19 +08:00
|
|
|
} else {
|
|
|
|
// Default
|
2020-01-21 03:46:54 +08:00
|
|
|
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
|
|
|
dims.emplace_back(dim);
|
2020-01-14 07:08:19 +08:00
|
|
|
}
|
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
2020-01-08 03:48:01 +08:00
|
|
|
}
|
|
|
|
|
2020-01-21 00:16:27 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-02-10 21:38:19 +08:00
|
|
|
// ReduceMax
|
|
|
|
|
|
|
|
void ONNXReduceMaxOp::inferShapes() {
|
|
|
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor not ranked");
|
2020-02-10 21:38:19 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// ReduceMin
|
|
|
|
|
|
|
|
void ONNXReduceMinOp::inferShapes() {
|
|
|
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor not ranked");
|
2020-02-10 21:38:19 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// ReduceProd
|
|
|
|
|
|
|
|
void ONNXReduceProdOp::inferShapes() {
|
|
|
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor not ranked");
|
2020-02-10 21:38:19 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// ReduceSum
|
|
|
|
|
|
|
|
void ONNXReduceSumOp::inferShapes() {
|
|
|
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Shape tensor not ranked");
|
2020-02-10 21:38:19 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
|
|
|
}
|
|
|
|
|
2020-03-14 05:18:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-01-21 00:16:27 +08:00
|
|
|
// Conv
|
|
|
|
|
|
|
|
// For this operation, we define the attributes once in the original Conv
|
|
|
|
// operation class. There is no need to redefine the attribute names for the
|
|
|
|
// other classes based on Conv.
|
2020-03-12 06:36:02 +08:00
|
|
|
// Conv attributes output:
|
|
|
|
// - auto_pad set to NOTSET;
|
|
|
|
// - dilations, strides: set to 1 if not defined by user;
|
|
|
|
// - kernelShape: inferred from weight matrix if not defined by user;
|
|
|
|
// - pads: set to proper value, 0 if not defined by user.
|
|
|
|
|
2020-01-21 00:16:27 +08:00
|
|
|
void ONNXConvNoBiasOp::inferShapes() {
|
|
|
|
// Generic shape for data input X and weight tensor W:
|
|
|
|
// X: (N x C x D1 x D2 ... x Dn)
|
|
|
|
// W: (M x C/group x k1 x k2 x ... x kn)
|
|
|
|
|
|
|
|
// Cannot infer shape if no shape exists.
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!X().getType().isa<RankedTensorType>() ||
|
|
|
|
!W().getType().isa<RankedTensorType>())
|
2020-01-21 00:16:27 +08:00
|
|
|
return;
|
2020-01-22 09:39:11 +08:00
|
|
|
|
2020-03-12 06:36:02 +08:00
|
|
|
auto xTy = X().getType().cast<RankedTensorType>();
|
|
|
|
auto xShape = xTy.getShape();
|
2020-02-26 04:46:11 +08:00
|
|
|
auto weightTy = W().getType().cast<RankedTensorType>();
|
2020-01-21 00:16:27 +08:00
|
|
|
auto weightShape = weightTy.getShape();
|
2020-03-14 05:18:46 +08:00
|
|
|
auto builder = mlir::Builder(this->getContext());
|
2020-01-21 00:16:27 +08:00
|
|
|
|
2020-02-14 23:54:08 +08:00
|
|
|
// Lowest supported convolution is a one dimensional convolution.
|
2020-03-12 06:36:02 +08:00
|
|
|
if (xShape.size() < 3)
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Data input shape must be at least (NxCxD1)");
|
2020-02-08 05:51:32 +08:00
|
|
|
|
2020-01-21 07:50:21 +08:00
|
|
|
// Check that shape of weight and data have same length.
|
2020-03-12 06:36:02 +08:00
|
|
|
if (xShape.size() != weightShape.size())
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Weight size not compatible with data size");
|
2020-01-21 00:16:27 +08:00
|
|
|
|
|
|
|
// Group is a required attribute and should have default value of 1.
|
2020-03-12 06:36:02 +08:00
|
|
|
int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
|
2020-03-14 05:18:46 +08:00
|
|
|
|
|
|
|
// Check if the attribute actually exists. If it does not then add it.
|
|
|
|
if (!groupAttr())
|
|
|
|
groupAttr(builder.getI64IntegerAttr(group));
|
|
|
|
|
2020-01-21 00:16:27 +08:00
|
|
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
2020-03-12 06:36:02 +08:00
|
|
|
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
|
|
|
xShape[1] != (weightShape[1] * group))
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Channel dimension mismatch");
|
2020-01-21 00:16:27 +08:00
|
|
|
|
2020-01-23 04:05:56 +08:00
|
|
|
// Note: the value of the group attribut only impacts the way the
|
|
|
|
// computation is carried out and not the actual output size.
|
|
|
|
|
2020-01-21 07:50:21 +08:00
|
|
|
// Number of spatial dimensions.
|
2020-03-12 06:36:02 +08:00
|
|
|
auto spatialOffset = 2;
|
|
|
|
int32_t spatialRank = xShape.size() - spatialOffset;
|
2020-01-21 07:50:21 +08:00
|
|
|
|
|
|
|
// Use kernel_shape attribute if present otherwise use size from weight
|
|
|
|
// argument.
|
2020-03-12 06:36:02 +08:00
|
|
|
auto kernelShape = kernel_shape();
|
|
|
|
if (kernelShape.hasValue()) {
|
|
|
|
if (ArrayAttrSize(kernelShape) != spatialRank)
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("kernel_shape length incompatible with spatial dimensions");
|
2020-03-12 06:36:02 +08:00
|
|
|
// Have the right number of values, check them.
|
|
|
|
for (int i = 0; i < spatialRank; ++i)
|
|
|
|
if (ArrayAttrIntVal(kernelShape, i) < 1)
|
|
|
|
emitError("bad kernel_shape value");
|
2020-01-21 07:50:21 +08:00
|
|
|
} else {
|
2020-03-12 06:36:02 +08:00
|
|
|
// Deduce shape from weight input.
|
|
|
|
SmallVector<int64_t, 2> defaultVals;
|
|
|
|
for (int i = 0; i < spatialRank; ++i)
|
|
|
|
defaultVals.emplace_back(weightShape[spatialOffset + i]);
|
|
|
|
// Convert to ArrayRef, then build attribute, then store attribute.
|
|
|
|
ArrayRef<int64_t> defaultRefs(defaultVals);
|
|
|
|
auto builder = mlir::Builder(getContext());
|
|
|
|
kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
|
|
|
|
kernelShape = kernel_shape();
|
2020-01-21 07:50:21 +08:00
|
|
|
}
|
|
|
|
|
2020-03-12 06:36:02 +08:00
|
|
|
// Process strides, dilations, and pads.
|
|
|
|
processConvTypeParams<>(this, X());
|
|
|
|
auto dilationsOpt = dilations();
|
|
|
|
auto stridesOpt = strides();
|
|
|
|
auto padsOpt = pads();
|
2020-03-11 04:58:05 +08:00
|
|
|
|
2020-03-12 06:36:02 +08:00
|
|
|
// First two output dimensions consist of the number of batches and the
|
|
|
|
// number of kernels being applied.
|
|
|
|
SmallVector<int64_t, 4> outputDims;
|
|
|
|
// Insert batch size.
|
|
|
|
outputDims.emplace_back(xShape[0]);
|
|
|
|
// Insert number of filters being applied (number of output channels).
|
|
|
|
outputDims.emplace_back(weightShape[0]);
|
2020-03-13 21:59:16 +08:00
|
|
|
// Compute and insert spatial dims.
|
|
|
|
insertConvSpatialDim(
|
|
|
|
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
|
2020-03-12 06:36:02 +08:00
|
|
|
|
|
|
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
2020-01-21 00:16:27 +08:00
|
|
|
}
|
|
|
|
|
2020-01-29 23:46:02 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2020-01-31 03:30:28 +08:00
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
// AveragePool
|
2020-02-26 03:33:48 +08:00
|
|
|
// Infer shape attributes output:
|
|
|
|
// - auto_pad set to NOTSET;
|
2020-03-13 21:59:16 +08:00
|
|
|
// - strides: set to 1 if not defined by user;
|
2020-02-26 03:33:48 +08:00
|
|
|
// - pads: set to proper value, 0 if not defined by user.
|
2020-01-31 03:30:28 +08:00
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
void ONNXAveragePoolOp::inferShapes() {
|
2020-01-31 03:30:28 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
|
|
|
if (!X().getType().isa<RankedTensorType>())
|
|
|
|
return;
|
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
// Get shape of input.
|
2020-01-31 03:30:28 +08:00
|
|
|
auto xTy = X().getType().cast<RankedTensorType>();
|
|
|
|
auto xShape = xTy.getShape();
|
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
// Kernel shape.
|
2020-01-31 03:30:28 +08:00
|
|
|
auto kernelShape = kernel_shape();
|
|
|
|
if (!kernelShape)
|
2020-02-24 23:46:48 +08:00
|
|
|
emitError(
|
2020-02-26 03:33:48 +08:00
|
|
|
"kernel_shape is a mandatory attribute for which there is no default");
|
2020-01-31 03:30:28 +08:00
|
|
|
|
2020-02-26 03:33:48 +08:00
|
|
|
// Ceil mode.
|
2020-01-31 03:30:28 +08:00
|
|
|
auto ceilMode = ceil_mode().getSExtValue();
|
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
// Process strides and pads.
|
|
|
|
processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
|
|
|
|
auto stridesOpt = strides();
|
|
|
|
processConvPadParam<ONNXAveragePoolOp>(
|
|
|
|
this, xShape, kernelShape, stridesOpt, llvm::None);
|
|
|
|
auto padsOpt = pads();
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> outputDims;
|
|
|
|
// Insert batch size.
|
|
|
|
outputDims.emplace_back(xShape[0]);
|
|
|
|
outputDims.emplace_back(xShape[1]);
|
|
|
|
// Compute and insert spatial dims.
|
|
|
|
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt,
|
|
|
|
llvm::None, ceilMode);
|
|
|
|
|
|
|
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// MaxPoolSingleOut
|
|
|
|
// Infer shape attributes output:
|
|
|
|
// - auto_pad set to NOTSET;
|
|
|
|
// - dilations, strides: set to 1 if not defined by user;
|
|
|
|
// - pads: set to proper value, 0 if not defined by user.
|
|
|
|
|
|
|
|
void ONNXMaxPoolSingleOutOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape exists.
|
|
|
|
if (!X().getType().isa<RankedTensorType>())
|
|
|
|
return;
|
|
|
|
|
|
|
|
// Get shape of input.
|
|
|
|
auto xTy = X().getType().cast<RankedTensorType>();
|
|
|
|
auto xShape = xTy.getShape();
|
|
|
|
|
|
|
|
// Kernel shape.
|
|
|
|
auto kernelShape = kernel_shape();
|
|
|
|
if (!kernelShape)
|
|
|
|
emitError(
|
|
|
|
"kernel_shape is a mandatory attribute for which there is no default");
|
|
|
|
|
2020-02-26 03:33:48 +08:00
|
|
|
// Storage order.
|
|
|
|
auto storageOrder = storage_order().getSExtValue();
|
|
|
|
if (storageOrder != 0)
|
|
|
|
emitError("column major storage order not supported at this time");
|
2020-02-24 23:46:48 +08:00
|
|
|
|
2020-03-13 21:59:16 +08:00
|
|
|
// Process strides, dilations, and pads.
|
|
|
|
processConvTypeParams<>(this, X());
|
2020-03-12 06:36:02 +08:00
|
|
|
auto dilationsOpt = dilations();
|
|
|
|
auto stridesOpt = strides();
|
|
|
|
auto padsOpt = pads();
|
2020-03-13 21:59:16 +08:00
|
|
|
|
|
|
|
// Ceil mode.
|
|
|
|
auto ceilMode = ceil_mode().getSExtValue();
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> outputDims;
|
|
|
|
// Insert batch size.
|
|
|
|
outputDims.emplace_back(xShape[0]);
|
|
|
|
outputDims.emplace_back(xShape[1]);
|
|
|
|
// Compute and insert spatial dims.
|
|
|
|
insertConvSpatialDim(&outputDims, xShape, kernelShape, padsOpt, stridesOpt,
|
|
|
|
dilationsOpt, ceilMode);
|
|
|
|
|
|
|
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
2020-01-31 03:30:28 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-02-26 06:43:49 +08:00
|
|
|
/// Computes the ranked result type of a Pad op from its data operand and the
/// `pads` attribute.
///
/// `pads` holds two values per axis of `data`. Entry i is the padding before
/// axis i and entry i + rank is the padding after it. Known dimensions grow
/// by the sum of both; unknown dimensions (-1) stay unknown.
///
/// Returns a null Type when the shape cannot be inferred: data is unranked,
/// pads is absent, pads does not hold exactly 2 * rank values, or a pad
/// value is negative.
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
  // Cannot infer shape if no shape exists.
  if (!data.getType().isa<RankedTensorType>())
    return Type();
  if (!padsOpt)
    return Type();

  auto dataTy = data.getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  size_t dataRank = dataShape.size();
  auto padsArray = padsOpt.getValue();

  // Guard against malformed pads before indexing padsArray[i + dataRank].
  if (padsArray.size() != 2 * dataRank)
    return Type();

  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
  for (size_t i = 0; i < dataRank; ++i) {
    int64_t padBefore = padsArray[i].cast<IntegerAttr>().getInt();
    int64_t padAfter = padsArray[i + dataRank].cast<IntegerAttr>().getInt();
    // Pads have to be non-negative constants.
    if (padBefore < 0 || padAfter < 0)
      return Type();
    if (outputShape[i] != -1)
      outputShape[i] += padBefore + padAfter;
  }

  return RankedTensorType::get(outputShape, dataTy.getElementType());
}
|
|
|
|
|
2020-02-26 06:43:49 +08:00
|
|
|
// PadConstantPad
|
|
|
|
|
2020-03-12 06:36:02 +08:00
|
|
|
void ONNXPadConstantPadOp::inferShapes() {
|
2020-02-26 06:43:49 +08:00
|
|
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
|
|
|
if (outputType) {
|
|
|
|
getResult().setType(outputType);
|
2020-03-12 06:36:02 +08:00
|
|
|
}
|
2020-02-26 06:43:49 +08:00
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
2020-02-14 01:08:29 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// PadConstantValuePad
|
|
|
|
|
2020-03-12 06:36:02 +08:00
|
|
|
void ONNXPadConstantValuePadOp::inferShapes() {
|
2020-02-26 06:43:49 +08:00
|
|
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
|
|
|
if (outputType) {
|
|
|
|
getResult().setType(outputType);
|
|
|
|
}
|
|
|
|
return;
|
2020-02-14 01:08:29 +08:00
|
|
|
}
|
|
|
|
|
2020-03-10 08:15:58 +08:00
|
|
|
/// Convenience builder that derives the result type from `data` and `pads`.
/// When the shape cannot be inferred, it falls back to an unranked tensor
/// with the element type of `data`. Then it forwards to the builder that
/// takes an explicit result type.
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
    Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
  Type resultTy = padShapeInferenceHelper(data, pads);
  if (!resultTy)
    resultTy = UnrankedTensorType::get(
        data.getType().cast<TensorType>().getElementType());
  build(builder, state, resultTy, data, pads, constant_value, mode);
}
|
|
|
|
|
2020-02-14 01:08:29 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-01-29 23:46:02 +08:00
|
|
|
// Unsqueeze
|
|
|
|
|
|
|
|
void ONNXUnsqueezeOp::inferShapes() {
|
2020-02-26 04:46:11 +08:00
|
|
|
if (!data().getType().isa<RankedTensorType>())
|
2020-01-29 23:46:02 +08:00
|
|
|
return;
|
|
|
|
|
2020-02-26 04:46:11 +08:00
|
|
|
auto operandTy = data().getType().cast<RankedTensorType>();
|
2020-01-29 23:46:02 +08:00
|
|
|
int inRank = operandTy.getRank();
|
|
|
|
|
|
|
|
ArrayAttr axisAttrs = axesAttr();
|
|
|
|
SmallVector<int, 4> axes;
|
|
|
|
int outRank = 0;
|
|
|
|
if (axisAttrs) {
|
|
|
|
outRank = inRank + axisAttrs.getValue().size();
|
|
|
|
for (auto axisAttr : axisAttrs.getValue()) {
|
|
|
|
int axis = axisAttr.cast<IntegerAttr>().getInt();
|
|
|
|
axis = axis >= 0 ? axis : (outRank + axis);
|
|
|
|
// Valid range
|
|
|
|
assert(axis >= -outRank && axis <= outRank - 1);
|
|
|
|
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
|
|
|
axes.emplace_back(axis);
|
|
|
|
else
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Duplicated axes");
|
2020-01-29 23:46:02 +08:00
|
|
|
}
|
|
|
|
} else {
|
2020-02-26 04:46:11 +08:00
|
|
|
emitError("Axes attribute is required");
|
2020-01-29 23:46:02 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
SmallVector<int64_t, 4> dims;
|
|
|
|
for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
|
|
|
|
if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
|
|
|
|
dims.emplace_back(1);
|
|
|
|
} else {
|
|
|
|
dims.emplace_back(operandTy.getShape()[j++]);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
2020-03-12 22:58:42 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Constant
|
|
|
|
|
|
|
|
void ONNXConstantOp::inferShapes() {
|
|
|
|
if ((sparse_value().hasValue() && value().hasValue()) ||
|
|
|
|
(!sparse_value().hasValue() && !value().hasValue()))
|
|
|
|
emitError("Require exactly one of the two attributes, either value or "
|
|
|
|
"sparse_value");
|
|
|
|
|
|
|
|
ElementsAttr valAttr;
|
|
|
|
if (sparse_value().hasValue())
|
|
|
|
valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
|
|
|
|
else
|
|
|
|
valAttr = valueAttr().cast<DenseElementsAttr>();
|
|
|
|
getResult().setType(valAttr.getType());
|
|
|
|
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TableGen'd op method definitions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
2019-12-23 13:13:52 +08:00
|
|
|
#include "src/onnx.cpp.inc"
|