//===------------------ ONNXOps.cpp - ONNX Operations ---------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file provides definition of ONNX dialect operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/Block.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
2020-06-26 04:34:37 +08:00
|
|
|
#include "mlir/IR/DialectImplementation.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/IntegerSet.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
2020-01-21 03:46:54 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2019-12-17 07:45:39 +08:00
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
|
#include "llvm/ADT/SmallBitVector.h"
|
2020-06-09 14:55:49 +08:00
|
|
|
#include "llvm/Support/FormatVariadic.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
#include "ONNXOps.hpp"
|
2019-11-02 05:09:48 +08:00
|
|
|
|
|
|
|
using namespace mlir;
using namespace mlir::OpTrait::util;
using namespace mlir::onnxmlir;

//===----------------------------------------------------------------------===//
// ONNX Helper functions
//===----------------------------------------------------------------------===//

static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }

static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
  return a.getValue().size();
}

static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}

static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
  return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}

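// Usage sketch (illustrative values): for an I64ArrayAttr a holding [2, 3],
// ArrayAttrSize(a) == 2 and ArrayAttrIntVal(a, 1) == 3.
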
// Returns the ConstantOp which defines an MLIR Value or null.
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
  return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
}

// This method substitutes any uses of dimensions and symbols (e.g.
// dim#0 with dimReplacements[0]) in an affine map, simplifies the modified
// affine map, and returns an integer constant.
int64_t AffineMapIntConstant(Builder &builder, AffineMap map,
    ArrayRef<int64_t> dimReplacements, ArrayRef<int64_t> symReplacements,
    unsigned numResultDims, unsigned numResultSyms) {
  // Prepare affine expressions.
  SmallVector<AffineExpr, 4> dimExprs, symExprs;
  for (int64_t dim : dimReplacements) {
    AffineExpr exp = builder.getAffineConstantExpr(dim);
    dimExprs.emplace_back(exp);
  }
  for (int64_t sym : symReplacements) {
    AffineExpr exp = builder.getAffineConstantExpr(sym);
    symExprs.emplace_back(exp);
  }
  // Replace all the affine map's arguments with real values and evaluate the
  // map.
  AffineMap replacedDimMap = map.replaceDimsAndSymbols(
      dimExprs, symExprs, numResultDims, numResultSyms);
  AffineMap simplifiedMap = simplifyAffineMap(replacedDimMap);
  return simplifiedMap.getSingleConstantResult();
}

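// Minimal usage sketch (hypothetical map and values): for the map
// (d0)[s0] -> (d0 floordiv s0), evaluating at d0 = 10 and s0 = 2 returns 5:
//   int64_t c = AffineMapIntConstant(builder, map, /*dimReplacements=*/{10},
//       /*symReplacements=*/{2}, /*numResultDims=*/0, /*numResultSyms=*/0);
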
//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
RankedTensorType getReductionOutputType(
    RankedTensorType operandTy, Optional<ArrayAttr> axesAttrs, APInt keepdims) {
  int64_t rank = operandTy.getRank();

  SmallVector<int64_t, 4> axes;
  if (axesAttrs != llvm::None) {
    for (auto axisAttr : axesAttrs.getValue()) {
      int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (rank + axis);
      assert(axis >= -rank && axis <= rank - 1);
      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
        axes.emplace_back(axis);
    }
  } else {
    for (decltype(rank) i = 0; i < rank; ++i) {
      axes.emplace_back(i);
    }
  }

  // Mark reduction axes.
  SmallVector<bool, 4> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.emplace_back(true);
    else
      isReductionAxis.emplace_back(false);
  }

  // KeepDims
  bool isKeepdims = (keepdims == 1);

  SmallVector<int64_t, 4> dims;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (isReductionAxis[i]) {
      if (isKeepdims)
        dims.emplace_back(1); // reduction dimension
    } else {
      dims.emplace_back(operandTy.getShape()[i]);
    }
  }

  return RankedTensorType::get(dims, operandTy.getElementType());
}

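// For example (illustrative only): reducing tensor<3x4x5xf32> over axes = [1]
// yields tensor<3x1x5xf32> when keepdims == 1, and tensor<3x5xf32> when
// keepdims == 0.
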
//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvDilationParam(
    T *op, Optional<ArrayAttr> kernelShape) {
  auto builder = mlir::Builder(op->getContext());
  auto kernelRank = ArrayAttrSize(kernelShape);

  auto dilationsOpt = op->dilations();
  if (dilationsOpt.hasValue()) {
    if (ArrayAttrSize(dilationsOpt) != kernelRank) {
      return op->emitError(
          "dilation rank is not the same as the spatial rank");
    }
    // Test values to be greater than 0.
    for (int i = 0; i < kernelRank; ++i) {
      if (ArrayAttrIntVal(dilationsOpt, i) < 1) {
        return op->emitError("dilation value must be nonzero positive");
      }
    }
  } else {
    // Default dilation is needed; all dimensions are initialized to 1.
    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  return success();
}

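// For example (illustrative only): a 2-D convolution with kernel_shape = [3, 3]
// and no dilations attribute gets dilations = [1, 1] stored back on the op.
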
//===----------------------------------------------------------------------===//
// Support function that computes default values for strides.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvStrideParam(
    T *op, Optional<ArrayAttr> kernelShape) {
  auto builder = mlir::Builder(op->getContext());
  auto kernelRank = ArrayAttrSize(kernelShape);

  auto stridesOpt = op->strides();
  if (stridesOpt.hasValue()) {
    if (ArrayAttrSize(stridesOpt) != kernelRank)
      return op->emitError("strides rank is not the same as the spatial rank");
    // Check values to be greater than 0.
    for (int i = 0; i < kernelRank; ++i) {
      if (ArrayAttrIntVal(stridesOpt, i) < 1)
        return op->emitError("strides value must be nonzero positive");
    }
  } else {
    // Default stride is needed; all dimensions are initialized to 1.
    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    op->stridesAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Support function that computes default values for pads.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
    Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None) {
  auto builder = mlir::Builder(op->getContext());

  auto inputRank = inputShape.size();
  auto kernelRank = ArrayAttrSize(kernelShape);
  auto kernelOffset = inputRank - kernelRank;

  // Try to find padding, getting auto_pad attribute first.
  auto autoPad = op->auto_pad();
  // And then investigate the various different cases. Prefill pad values with
  // zeros, the most common case.
  SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
  bool updatedPad = false;
  if (autoPad == "NOTSET") {
    auto padsOpt = op->pads();
    if (padsOpt.hasValue()) {
      // Only option where pads are not updated. Pads consist of two entries
      // for each spatial axis.
      if (ArrayAttrSize(padsOpt) != 2 * kernelRank) {
        return op->emitError("pads rank is not twice the spatial rank");
      }
      // Check values; pads cannot be negative.
      for (int i = 0; i < 2 * kernelRank; ++i) {
        if (ArrayAttrIntVal(padsOpt, i) < 0) {
          return op->emitError("pads value must be nonnegative");
        }
      }
    } else {
      // We have NOTSET with no pads; they are assumed to be all zero.
      updatedPad = true;
    }
  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
    // Reload dilations and strides as they may have gotten default values.
    updatedPad = true;
    int64_t dilationVal = 1;
    for (int i = 0; i < kernelRank; ++i) {
      auto inputSize = inputShape[kernelOffset + i];
      auto kernelSize = ArrayAttrIntVal(kernelShape, i);
      if (dilationsOpt.hasValue())
        dilationVal = ArrayAttrIntVal(dilationsOpt, i);
      auto strideVal = ArrayAttrIntVal(stridesOpt, i);
      // Output size is input size divided by stride. When stride is 1, then
      // input and output are the same size, which is the usual case. When
      // stride is greater than 1, take the ceil to be sure to have each input
      // value used, as padding will be used to fill the gaps.
      int64_t outputSize = ceil((1.0 * inputSize) / (1.0 * strideVal));
      // Formula is from ONNX MaxPool, and can be explained as follows. Pads is
      // the difference between the needed values for the computations, minus
      // the input values. The needed values for the computation is the
      // effective size of the kernel plus the number of times we jump to the
      // next kernel. The number of times we jump is (outputSize - 1). That
      // number is multiplied with the size of the jump, namely strideVal. Now
      // for the effective kernel size: it is the kernelSize plus the number of
      // dilation holes times the dilation. The number of dilation holes is
      // (kernelSize - 1). Thus the effective size is "kernelSize +
      // (kernelSize - 1) * dilation", which simplifies to
      // "(kernelSize - 1) * dilation + 1".
      auto sumOfPad = (outputSize - 1) * strideVal +
                      ((kernelSize - 1) * dilationVal + 1) - inputSize;
      // Pad values are assumed equal on both sides, at half the total value.
      actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
      // But if the total pad value is odd, we add 1 to the beginning or end
      // depending on the autoPad value.
      if (sumOfPad % 2 != 0) {
        if (autoPad == "SAME_UPPER") {
          actualPads[kernelRank + i] += 1;
        } else {
          actualPads[i] += 1;
        }
      }
    }
  } else if (autoPad == "VALID") {
    // No pad; the default value was set to zero, so we are all set.
    updatedPad = true;
  } else {
    return op->emitError("auto_pad of unknown / unsupported value");
  }
  // Set pads values in attributes, if needed.
  if (updatedPad) {
    ArrayRef<int64_t> defaultRefs(actualPads);
    op->padsAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  // In all cases now, the actual pad values are found in the pads attribute.
  op->auto_padAttr(builder.getStringAttr("NOTSET"));
  return success();
}

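// Worked check of the SAME_* formula (hypothetical values): with
// inputSize = 5, kernelSize = 3, strideVal = 1, and dilationVal = 1:
//   outputSize = ceil(5 / 1) = 5
//   sumOfPad   = (5 - 1) * 1 + ((3 - 1) * 1 + 1) - 5 = 2,
// i.e., one pad value on each side, as expected for a SAME 1-D convolution.
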
//===----------------------------------------------------------------------===//
// Support function computing default values for dilations, strides, and pads.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
  auto builder = mlir::Builder(op->getContext());

  // 1) Get shape of input.
  auto inputShape = inputOperand.getType().cast<RankedTensorType>().getShape();
  auto inputRank = inputShape.size();

  // 2) Get kernel_shape attribute.
  auto kernelShape = op->kernel_shape();

  // Dilation.
  LogicalResult res = processConvDilationParam<T>(op, kernelShape);
  if (failed(res))
    return res;
  auto dilationsOpt = op->dilations();

  // Strides.
  res = processConvStrideParam<T>(op, kernelShape);
  if (failed(res))
    return res;
  auto stridesOpt = op->strides();

  // Pads.
  return processConvPadParam<T>(
      op, inputShape, kernelShape, stridesOpt, dilationsOpt);
}

//===----------------------------------------------------------------------===//
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
//===----------------------------------------------------------------------===//
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
    Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
    Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
  auto spatialRank = ArrayAttrSize(kernelShape);
  auto spatialOffset = xShape.size() - spatialRank;

  // Get an affine map to compute the output dimension.
  AffineMap dimMap = getConvDimMap(builder, ceilMode);
  for (int i = 0; i < spatialRank; ++i) {
    int64_t res = -1;
    if (xShape[spatialOffset + i] != -1) {
      auto inputSize = xShape[spatialOffset + i];
      auto kernelSize = ArrayAttrIntVal(kernelShape, i);
      auto sumOfPads = ArrayAttrIntVal(padsOpt, i) +
                       ArrayAttrIntVal(padsOpt, spatialRank + i);
      auto strideVal = ArrayAttrIntVal(stridesOpt, i);
      int64_t dilationVal = 1;
      if (dilationsOpt.hasValue())
        dilationVal = ArrayAttrIntVal(dilationsOpt, i);
      res = AffineMapIntConstant(builder, dimMap, {inputSize},
          {kernelSize, sumOfPads, strideVal, dilationVal}, 1, 4);
    }
    outputDims->emplace_back(res);
  }
}

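// For example (assuming getConvDimMap encodes the usual formula
// floor((inputSize + sumOfPads - ((kernelSize - 1) * dilationVal + 1)) /
// strideVal) + 1): inputSize = 32, kernelSize = 3, sumOfPads = 2,
// strideVal = 1, and dilationVal = 1 give an output size of 32.
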
//===----------------------------------------------------------------------===//
// Support function that infers shape for RNN operations.
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult RNNShapeInference(T *op) {
  Value X = op->X();
  Value W = op->W();
  Value R = op->R();

  if (!X.getType().isa<RankedTensorType>() ||
      !W.getType().isa<RankedTensorType>() ||
      !R.getType().isa<RankedTensorType>()) {
    return op->emitError("Input tensor not ranked");
  }

  auto xTy = X.getType().cast<RankedTensorType>();
  auto elementType = xTy.getElementType();

  // xShape :: [seq_length, batch_size, input_size]
  auto xShape = xTy.getShape();
  // wShape :: [num_directions, 4*hidden_size, input_size]
  auto wShape = W.getType().cast<RankedTensorType>().getShape();
  // rShape :: [num_directions, 4*hidden_size, hidden_size]
  auto rShape = R.getType().cast<RankedTensorType>().getShape();

  if (xShape.size() != 3) {
    return op->emitError("The first input tensor must have rank 3");
  }
  if (wShape.size() != 3) {
    return op->emitError("The second input tensor must have rank 3");
  }
  if (rShape.size() != 3) {
    return op->emitError("The third input tensor must have rank 3");
  }

  // Get sequence length, batch size and input size.
  auto sequenceLength = xShape[0];
  auto batchSize = xShape[1];
  auto inputSize = xShape[2];

  // Get hidden size from the hidden_size attribute.
  int64_t hiddenSize = -1;
  if (op->hidden_size().hasValue()) {
    hiddenSize = op->hidden_size().getValue().getSExtValue();
  } else {
    // Infer hidden_size from wShape and rShape if possible.
    if (rShape[2] != -1)
      hiddenSize = rShape[2];
    else if (rShape[1] != -1)
      hiddenSize = rShape[1] / 4;
    else if (wShape[1] != -1)
      hiddenSize = wShape[1] / 4;
    // Update the hidden_size attribute.
    if (hiddenSize != -1) {
      auto builder = mlir::Builder(op->getContext());
      op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize));
    }
  }

  // Get direction.
  int numDirection;
  if ((op->direction() == "forward") || (op->direction() == "reverse"))
    numDirection = 1;
  else if (op->direction() == "bidirectional")
    numDirection = 2;
  else
    numDirection = -1;
  if (numDirection == -1) {
    return op->emitError(
        "direction attribute must be one of the strings: forward, "
        "reverse, and bidirectional");
  }

  // Set result types.
  unsigned numOfResults = op->getNumResults();
  if (numOfResults > 0) {
    // Y :: [seq_length, num_directions, batch_size, hidden_size]
    Type yTy = op->getResults()[0].getType();
    if (!yTy.isa<NoneType>()) {
      yTy = RankedTensorType::get(
          {sequenceLength, numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[0].setType(yTy);
    }
  }
  if (numOfResults > 1) {
    // Y_h :: [num_directions, batch_size, hidden_size]
    Type yhTy = op->getResults()[1].getType();
    if (!yhTy.isa<NoneType>()) {
      yhTy = RankedTensorType::get(
          {numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[1].setType(yhTy);
    }
  }
  if (numOfResults > 2) {
    // Y_c :: [num_directions, batch_size, hidden_size]
    Type ycTy = op->getResults()[2].getType();
    if (!ycTy.isa<NoneType>()) {
      ycTy = RankedTensorType::get(
          {numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[2].setType(ycTy);
    }
  }
  return success();
}

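// For example (illustrative only): for a unidirectional LSTM with
// R :: tensor<1x256x64xf32> and no hidden_size attribute, hiddenSize is taken
// from rShape[2] = 64 (or recovered as rShape[1] / 4 = 64 when rShape[2] is
// dynamic).
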
static void insertConvTransposeSpatialDim(SmallVectorImpl<int64_t> &outputDims,
    ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
    Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> outputPadsOpt, Optional<ArrayAttr> outputShapeOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
  auto xRank = xShape.size();
  auto spatialRank = ArrayAttrSize(kernelShape);
  auto spatialOffset = xRank - spatialRank;

  int64_t dilationVal = 1;
  int64_t outputPadsVal = 0;
  // output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] +
  //     ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i]
  for (int i = 0; i < spatialRank; ++i) {
    auto inputSize = xShape[spatialOffset + i];
    auto sumOfPads =
        ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i);
    auto kernelSize = ArrayAttrIntVal(kernelShape, i);
    if (dilationsOpt.hasValue())
      dilationVal = ArrayAttrIntVal(dilationsOpt, i);
    auto strideVal = ArrayAttrIntVal(stridesOpt, i);
    if (outputPadsOpt.hasValue())
      outputPadsVal = ArrayAttrIntVal(outputPadsOpt, i);
    // Number of useful values: input plus pad - effective size of kernel (see
    // processConvTypeParams comments to see how this value is derived).
    int64_t res = strideVal * (inputSize - 1) + outputPadsVal +
                  ((kernelSize - 1) * dilationVal + 1) - sumOfPads;
    outputDims.emplace_back(res);
  }
}

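// Worked check (hypothetical values): with inputSize = 5, strideVal = 2,
// kernelSize = 3, dilationVal = 1, outputPadsVal = 0, and sumOfPads = 2:
//   res = 2 * (5 - 1) + 0 + ((3 - 1) * 1 + 1) - 2 = 9,
// i.e., the transposed convolution roughly doubles the spatial size.
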
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//

/// Dialect creation; the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addOperations<
#define GET_OP_LIST
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
      >();
  addTypes<StringType>();
}

mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
  if (parser.parseKeyword("String"))
    return Type();

  return StringType::get(getContext());
}

void ONNXOpsDialect::printType(
    mlir::Type type, mlir::DialectAsmPrinter &printer) const {
  printer << "String";
}

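// With this parser/printer pair, the custom string type round-trips in the
// textual IR as "!onnx.String" (assuming the dialect namespace is "onnx").
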
void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
    mlir::OperationState &state, mlir::FuncOp function, int numInputs,
    int numOutputs) {
  state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
      builder.getSymbolRefAttr(function));
  state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
      builder.getI32IntegerAttr(numInputs));
  state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
      builder.getI32IntegerAttr(numOutputs));
}

ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
    mlir::FuncOp &func, int numInputs, int numOutputs) {
  mlir::OperationState state(location, "onnx.EntryPoint");
  OpBuilder builder(location->getContext());
  mlir::ONNXEntryPointOp::build(builder, state, func, numInputs, numOutputs);
  Operation *op = mlir::Operation::create(state);
  auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
  return onnxEntryOp;
}

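// Usage sketch (hypothetical call site): given a FuncOp for the main graph
// with two inputs and one output, the entry point marker is created as
//   auto entryPoint = ONNXEntryPointOp::create(loc, mainFunc,
//       /*numInputs=*/2, /*numOutputs=*/1);
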
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Exp
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXExpOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Atan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAtanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAtanOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Tan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXTanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Tanh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanhOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sin
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sinh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinhOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Cosh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCoshOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Cos
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCosOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Log
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXLogOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// HardSigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXHardSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Elu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXEluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Relu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXReluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// LeakyRelu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXLeakyReluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Selu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSeluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Reciprocal
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXReciprocalOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softmax
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftmaxOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softplus
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftplusOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softsign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftsignOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sqrt
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSqrtOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSignOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Abs
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAbsOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Add
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAddOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

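// For example (illustrative only): adding tensor<5x1xf32> and tensor<1x4xf32>
// broadcasts to tensor<5x4xf32>, following the NumPy-style rules implemented
// by getBroadcastedType.
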
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Mul
|
2020-06-15 11:49:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-05-27 10:09:28 +08:00
|
|
|
LogicalResult ONNXMulOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
2020-05-27 10:09:28 +08:00
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
|
|
|
return emitError("Input tensor(s) not ranked");
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2020-05-27 10:09:28 +08:00
|
|
|
return success();
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Div
|
2020-06-15 11:49:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-05-27 10:09:28 +08:00
|
|
|
LogicalResult ONNXDivOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
2020-05-27 10:09:28 +08:00
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
|
|
|
return emitError("Input tensor(s) not ranked");
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2020-05-27 10:09:28 +08:00
|
|
|
return success();
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sub
|
2020-06-15 11:49:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-05-27 10:09:28 +08:00
|
|
|
LogicalResult ONNXSubOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
2020-05-27 10:09:28 +08:00
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
|
|
|
return emitError("Input tensor(s) not ranked");
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2020-05-27 10:09:28 +08:00
|
|
|
return success();
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// And
|
2020-06-15 11:49:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-05-27 10:09:28 +08:00
|
|
|
LogicalResult ONNXAndOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
2020-05-27 10:09:28 +08:00
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
|
|
|
return emitError("Input tensor(s) not ranked");
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2020-05-27 10:09:28 +08:00
|
|
|
return success();
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Or
|
2020-06-15 11:49:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-05-27 10:09:28 +08:00
|
|
|
LogicalResult ONNXOrOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
2020-05-27 10:09:28 +08:00
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
|
|
|
return emitError("Input tensor(s) not ranked");
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2020-05-27 10:09:28 +08:00
|
|
|
return success();
|
2019-12-20 00:28:06 +08:00
|
|
|
}

//===----------------------------------------------------------------------===//
// Xor
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXXorOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Sum
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSumOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    // Use isa<> here: cast<> asserts instead of returning null on mismatch.
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Max
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMaxOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    // Use isa<> here: cast<> asserts instead of returning null on mismatch.
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Min
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMinOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    // Use isa<> here: cast<> asserts instead of returning null on mismatch.
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Neg
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXNegOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXNegOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Identity
//===----------------------------------------------------------------------===//

/// Infer the output shape of the ONNXIdentityOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXIdentityOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// MatMul
//===----------------------------------------------------------------------===//
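
// The cases below follow the numpy.matmul shape rules. A sketch of the
// possibilities, with assumed example shapes:
//   (N) x (N)                      => (1)            (1-D dot product)
//   (N) x (s1 x ... x sK x N x P)  => (s1 x ... x sK x P)
//   (s1 x ... x sK x M x N) x (N)  => (s1 x ... x sK x M)
//   (2 x 3 x 4) x (4 x 5)          => (2 x 3 x 5)
//   (3 x 1 x 4 x 5) x (2 x 5 x 6)  => (3 x 2 x 4 x 6) (batch dims broadcast)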

LogicalResult ONNXMatMulOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!A().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");

  auto lhsTy = A().getType().cast<RankedTensorType>();
  auto rhsTy = B().getType().cast<RankedTensorType>();

  SmallVector<int64_t, 2> dims;
  auto lhsShape = lhsTy.getShape();
  auto rhsShape = rhsTy.getShape();

  if (lhsShape.size() < 1 && rhsShape.size() < 1) {
    // Multiplication by scalars is not allowed.
    return emitError("Multiplication by scalar arguments not allowed");
  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
    // Special case when both arrays are 1-dimensional and according to
    // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
    // need to be removed after the multiplication but cannot be removed if
    // all sizes are 1.
    if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");
    dims.emplace_back(1);
  } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
    // If the first argument is 1-D, it is promoted to a matrix by prepending
    // a 1 to its dimensions. After matrix multiplication the prepended 1 is
    // removed.
    //
    // N MATMUL (s1 x s2 x... x sK x N x P)
    // =>
    // (s1 x s2 x... x sK x P)

    // Check legality of matrix multiplication.
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[0] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");
    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      dims.emplace_back(rhsShape[i]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
    // If the second argument is 1-D, it is promoted to a matrix by appending
    // a 1 to its dimensions. After matrix multiplication the appended 1 is
    // removed.
    //
    // (s1 x s2 x... x sK x M x N) MATMUL N
    // =>
    // (s1 x s2 x... x sK x M)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");
    for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
      dims.emplace_back(lhsShape[i]);
    dims.emplace_back(lhsShape[lhsRank - 2]);
  } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
    // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
    // =>
    // (s1 x s2 x... x sK x M x P)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");
    for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
      dims.emplace_back(lhsShape[i]);
    dims.emplace_back(rhsShape[1]);
  } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
    // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
    // =>
    // (s1 x s2 x... x sK x M x P)

    // Check legality of matrix multiplication.
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[1] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");
    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      dims.emplace_back(rhsShape[i]);
    dims.emplace_back(lhsShape[0]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
    // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
    // =>
    // (u1 x u2 x... x uK x M x P)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");
    // Check and perform broadcasting for the shapes.
    SmallVector<int64_t, 2> lhsBcastShape;
    for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
      lhsBcastShape.emplace_back(lhsShape[i]);
    SmallVector<int64_t, 2> rhsBcastShape;
    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      rhsBcastShape.emplace_back(rhsShape[i]);
    if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
      return emitError("Broadcasted dimensions are incompatible");
    dims.emplace_back(lhsShape[lhsRank - 2]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else {
    // This case covers all remaining combinations of 1 and 2-D matrices.
    int64_t lhsDim = lhsShape[0];
    int64_t rhsDim = rhsShape[0];
    if (lhsShape.size() > 1) {
      lhsDim = lhsShape[1];
      dims.emplace_back(lhsShape[0]);
    }

    // Check legality of matrix multiplication.
    if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
      return emitError("Attempt to multiply incompatible matrices");
    if (rhsShape.size() > 1)
      dims.emplace_back(rhsShape[1]);
  }

  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// Gemm
//===----------------------------------------------------------------------===//
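
// Gemm computes Y = alpha * A' * B' + beta * C, where A' is A optionally
// transposed (transA) and B' is B optionally transposed (transB). As an
// illustrative example with assumed shapes: A (3x4), B (4x5), transA =
// transB = 0 gives M = 3, K = 4, N = 5 and a (3x5) result; C may then be
// (5), (1x5), (3x1), or (3x5) under unidirectional broadcasting.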

LogicalResult ONNXGemmOp::inferShapes() {
  bool hasBias = !C().getType().isa<NoneType>();
  // Cannot infer shape if no shape exists.
  if (!A().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>() ||
      (hasBias && !C().getType().isa<RankedTensorType>()))
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = A().getType().cast<RankedTensorType>();
  auto rhsTy = B().getType().cast<RankedTensorType>();

  int64_t M, N, K_A, K_B;
  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

  if ((K_A != -1) && (K_B != -1) && (K_A != K_B))
    return emitError("Tensor shapes mismatched");

  if (hasBias) {
    // Check whether the bias is unidirectionally broadcastable to (M x N).
    auto biasTy = C().getType().cast<RankedTensorType>();
    auto shape = biasTy.getShape();
    int rank = shape.size();
    if ((rank > 2) ||
        (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
            N != shape[rank - 1] && shape[rank - 1] != 1) ||
        (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
            M != shape[rank - 2] && shape[rank - 2] != 1))
      return emitError("Bias shape mismatched");
  }

  SmallVector<int64_t, 2> dims;
  dims.emplace_back(M);
  dims.emplace_back(N);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
  return success();
}

/// BatchNormalizationTestMode
LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !scale().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>() ||
      !mean().getType().isa<RankedTensorType>() ||
      !var().getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");

  auto inputTensorTy = X().getType().cast<RankedTensorType>();
  auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
  auto biasTensorTy = B().getType().cast<RankedTensorType>();
  auto meanTensorTy = mean().getType().cast<RankedTensorType>();
  auto varianceTensorTy = var().getType().cast<RankedTensorType>();

  // Check whether the shapes of scale, bias, mean and variance are valid.
  // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
  // In case of N, C is assumed to be 1.
  // Shapes of scale, bias, mean and variance must be C.
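  // Illustrative example (assumed shapes): X of shape (2x3x224x224) has
  // C = 3, so scale, bias, mean and variance must each be of shape (3).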
  int64_t c = -1;
  if (inputTensorTy.getShape().size() == 1) {
    c = 1;
  } else if (inputTensorTy.getShape().size() > 2) {
    c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
  } else {
    return emitError("Wrong rank for the input");
  }

  if (c != -1) {
    auto s = scaleTensorTy.getShape();
    auto b = biasTensorTy.getShape();
    auto m = meanTensorTy.getShape();
    auto v = varianceTensorTy.getShape();

    if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
      return emitError("Wrong rank for the scale");
    if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
      return emitError("Wrong rank for the bias");
    if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
      return emitError("Wrong rank for the mean");
    if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
      return emitError("Wrong rank for the variance");
  }

  // The output tensor has the same shape as the input.
  getResult().setType(X().getType());
  return success();
}

// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.

//===----------------------------------------------------------------------===//
// Reshape
//===----------------------------------------------------------------------===//

LogicalResult ONNXReshapeOp::inferShapes() {
  // Cannot infer shape if no shape tensor is specified.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input data tensor not ranked");

  if (!shape().getType().isa<RankedTensorType>())
    return emitError("Shape tensor not ranked");

  auto inputTensorTy = data().getType().cast<RankedTensorType>();
  auto shapeTensorTy = shape().getType().cast<RankedTensorType>();

  // Only rank 1 shape tensors are supported.
  if (shapeTensorTy.getShape().size() != 1)
    return emitError("Shape tensor must have rank one");
  int64_t outputRank = shapeTensorTy.getShape()[0];

  // Shape tensor must have constant shape.
  if (outputRank < 0)
    return emitError("Shape tensor must have constant shape");
  // Compute total number of elements.
  int64_t totalInputSize = 1;
  for (auto inputDim : inputTensorTy.getShape())
    totalInputSize *= inputDim;

  // Check if second argument of ReshapeOp is a constant.
  auto constantOp = getONNXConstantOp(shape());

  SmallVector<int64_t, 2> dims(outputRank, -1);
  if (constantOp) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    if (!valueAttribute)
      return emitError("DenseElementsAttr expected");
    // Get dims from valueAttribute.
    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
    for (int i = 0; i < outputRank; ++i)
      dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();

    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
      return emitError("Constant value must have same rank as output");
    int64_t numberOfDynamicInputs = 0;
    int64_t totalKnownDimsSize = 1;
    int64_t dynamicValueIndex = -1;
    for (int i = 0; i < outputRank; ++i) {
      // Set output dimension.
      if (dims[i] == 0)
        dims[i] = inputTensorTy.getShape()[i];

      if (dims[i] < 0) {
        numberOfDynamicInputs++;
        dynamicValueIndex = i;
      } else {
        totalKnownDimsSize *= dims[i];
      }
    }

    // If the number of dynamic inputs is 1 then deduce the missing value
    // based on the total input size. The total input size must be greater
    // than 0, i.e. all input dimensions are constant.
    // TODO: Support dynamic input dimensions.
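    // Illustrative example (assumed values): reshaping a (2x3x4) input with
    // constant shape values (4, -1) gives totalInputSize = 24 and
    // totalKnownDimsSize = 4, so the -1 entry is deduced as 24 / 4 = 6 and
    // the result type is (4x6).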
    if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
        totalInputSize > 0)
      dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
  }

  getResult().setType(
      RankedTensorType::get(dims, inputTensorTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// Transpose
//===----------------------------------------------------------------------===//

LogicalResult ONNXTransposeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  // Naive transposition which handles the default case of
  // reversing the shape of the tensor (similar to numpy.transpose).
  auto arrayTy = data().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  auto permutation = ONNXTransposeOp::permAttr();
  if (!permutation) {
    // Generate reverse order for the default transpose operation.
    SmallVector<int64_t, 4> defaultVals;
    auto builder = mlir::Builder(getContext());
    auto rank = arrayTy.getShape().size();
    for (int i = rank - 1; i >= 0; --i)
      defaultVals.emplace_back(i);
    // Set default attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    permAttr(builder.getI64ArrayAttr(defaultRefs));
    permutation = permAttr();
  }
  // Perform transposition according to perm attribute.
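  // Illustrative example (assumed shapes): a (1x2x3) input with
  // perm = [2, 0, 1] produces a (3x1x2) result.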
  for (auto perm : permutation.getValue())
    dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceMax
//===----------------------------------------------------------------------===//
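
// The reduce ops below share getReductionOutputType. As an illustrative
// example (assumed values): reducing a (3x4x5) tensor over axes = [1]
// yields (3x5) with keepdims = 0 and (3x1x5) with keepdims = 1.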

LogicalResult ONNXReduceMaxOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceMin
//===----------------------------------------------------------------------===//

LogicalResult ONNXReduceMinOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceProd
//===----------------------------------------------------------------------===//

LogicalResult ONNXReduceProdOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceSum
//===----------------------------------------------------------------------===//

LogicalResult ONNXReduceSumOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// Conv
//===----------------------------------------------------------------------===//

// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - kernelShape: inferred from weight matrix if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.

LogicalResult ONNXConvOp::inferShapes() {
  // Generic shape for data input X, weight tensor W, and optional bias B
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)
  // B: (M) Optional

  bool hasBias = !B().getType().isa<NoneType>();

  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !W().getType().isa<RankedTensorType>() ||
      (hasBias && !B().getType().isa<RankedTensorType>()))
    return emitError("Input tensor not ranked");

  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();
  auto weightTy = W().getType().cast<RankedTensorType>();
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3)
    return emitError("Data input shape must be at least (NxCxD1)");

  // Check that the shapes of weight and data have the same length.
  if (xShape.size() != weightShape.size())
    return emitError("Weight size not compatible with data size");

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[1] != -1 &&
      xShape[1] != (weightShape[1] * group))
    return emitError("Channel dimension mismatch");

  // Check the size of bias.
  if (hasBias) {
    auto bTx = B().getType().cast<RankedTensorType>();
    auto bShape = bTx.getShape();
    if (bShape.size() != 1)
      return emitError("bias should be one dimensional");
    if (bShape[0] != weightShape[0])
      return emitError("bias should have same dimensions "
                       "as weight's first dimension");
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank)
      return emitError(
          "kernel_shape length incompatible with spatial dimensions");
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1)
        return emitError("bad kernel_shape value");
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 2> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    auto builder = mlir::Builder(getContext());
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[0]);
  // Compute and insert spatial dims.
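  // For each spatial dimension the ONNX output size is, roughly:
  //   out = floor((in + padBegin + padEnd - dilation * (kernel - 1) - 1)
  //               / stride) + 1
  // Illustrative example (assumed values): in = 32, kernel = 3, pads = 1 + 1,
  // stride = 1, dilation = 1 gives out = 32.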
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvTranspose
//===----------------------------------------------------------------------===//

// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - kernelShape: inferred from weight matrix if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.

LogicalResult ONNXConvTransposeOp::inferShapes() {
  // Generic shape for data input X, weight tensor W, and optional bias B
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)
  // B: (M) Optional

  bool hasBias = !B().getType().isa<NoneType>();

  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !W().getType().isa<RankedTensorType>() ||
      (hasBias && !B().getType().isa<RankedTensorType>())) {
    return emitError("Input tensor not ranked");
  }

  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();
  auto weightTy = W().getType().cast<RankedTensorType>();
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3) {
    return emitError("Data input shape must be at least (NxCxD1)");
  }

  // Check that the shapes of weight and data have the same length.
  if (xShape.size() != weightShape.size()) {
    return emitError("Weight size not compatible with data size");
  }

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvTransposeOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[0] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[0] != -1 &&
      xShape[1] != (weightShape[0] * group)) {
    return emitError("Channel dimension mismatch");
  }

  // Check the size of bias.
  if (hasBias) {
    auto bTx = B().getType().cast<RankedTensorType>();
    auto bShape = bTx.getShape();
    if (bShape.size() != 1) {
      return emitError("bias should be one dimensional");
    }
    if (bShape[0] != weightShape[1]) {
      return emitError(
          "bias should have same dimensions as weight's second dimension");
    }
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank) {
      return emitError(
          "kernel_shape length incompatible with spatial dimensions");
    }
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1) {
        return emitError("bad kernel_shape value");
      }
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 2> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    auto builder = mlir::Builder(getContext());
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();
  auto outputPads = output_padding();
  auto outputShape = output_shape();
  // TODO: handle the spatial dimension computation if output shape is
  // specified.
  assert(!outputShape.hasValue() && "unhandled option in ConvTranspose");

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[1]);
  // Compute and insert spatial dims.
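  // For a transposed convolution each spatial output size is, roughly:
  //   out = stride * (in - 1) + outputPadding + dilation * (kernel - 1) + 1
  //         - padBegin - padEnd
  // Illustrative example (assumed values): in = 4, stride = 2, kernel = 3,
  // pads = 1 + 1, outputPadding = 0, dilation = 1 gives out = 7.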
  insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt,
      stridesOpt, outputPads, outputShape, dilationsOpt);

  // Set the output shape if it's not already set.
  if (!outputShape.hasValue()) {
    output_shapeAttr(builder.getI64ArrayAttr(outputDims));
  }

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// AveragePool
//===----------------------------------------------------------------------===//

// Infer shape attributes output:
//   - auto_pad set to NOTSET;
//   - strides: set to 1 if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.

LogicalResult ONNXAveragePoolOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto builder = mlir::Builder(getContext());

  // Get shape of input.
  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();

  // Kernel shape.
  auto kernelShape = kernel_shape();
  if (!kernelShape)
    return emitError(
        "kernel_shape is a mandatory attribute for which there is no default");

  // Ceil mode.
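  // With ceil_mode = 1 the spatial output size uses ceil instead of floor.
  // Illustrative example (assumed values): in = 5, kernel = 2, stride = 2 and
  // no padding gives floor(3 / 2) + 1 = 2, but ceil(3 / 2) + 1 = 3.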
  auto ceilMode = ceil_mode().getSExtValue();

  // Process strides and pads.
  LogicalResult res =
      processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
  if (failed(res))
    return res;
  auto stridesOpt = strides();
  res = processConvPadParam<ONNXAveragePoolOp>(
      this, xShape, kernelShape, stridesOpt, llvm::None);
  if (failed(res))
    return res;
  auto padsOpt = pads();

  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  outputDims.emplace_back(xShape[1]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, llvm::None, ceilMode);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// MaxPoolSingleOut
//===----------------------------------------------------------------------===//

// Infer shape attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.

LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto builder = mlir::Builder(getContext());

  // Get shape of input.
  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();

  // Kernel shape.
  auto kernelShape = kernel_shape();
  if (!kernelShape)
    return emitError(
        "kernel_shape is a mandatory attribute for which there is no default");

  // Storage order.
  auto storageOrder = storage_order().getSExtValue();
  if (storageOrder != 0)
    return emitError("column major storage order not supported at this time");

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // Ceil mode.
  auto ceilMode = ceil_mode().getSExtValue();

  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  outputDims.emplace_back(xShape[1]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt, ceilMode);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// Pad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Pad: unknown input shape");

  // Cannot infer the output shape if pads is not constant.
  DenseElementsAttr padsAttributes =
      getAttr("pads").dyn_cast_or_null<mlir::DenseElementsAttr>();
  if (!padsAttributes)
    return emitError("Pad: unknown pads");

  auto dataTy = data().getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  auto dataRank = dataTy.getRank();
  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());

  // Get pads from padsAttributes.
  SmallVector<int64_t, 2> pads(dataRank * 2, -1);
  auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
  for (int64_t i = 0; i < dataRank * 2; ++i)
    pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();

  // Pads consists of two values for each axis of data.
  // The two values specify the number of elements padded before and after
  // respectively.
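  // Illustrative example (assumed values): data of shape (3x4) with
  // pads = [1, 0, 1, 2] adds 1 + 1 elements on axis 0 and 0 + 2 elements on
  // axis 1, producing a (5x6) result.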
  for (int64_t i = 0; i < dataRank; ++i) {
    int64_t p1 = pads[i];
    int64_t p2 = pads[i + dataRank];
    // Pads must be non-negative constants.
    if (p1 < 0 || p2 < 0)
      return emitError("padding value can not be negative");
    if (outputShape[i] != -1)
      outputShape[i] += p1 + p2;
  }

  auto outputType = RankedTensorType::get(outputShape, dataTy.getElementType());
  getResult().setType(outputType);
  return success();
}

static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
  // Cannot infer shape if no shape exists.
  if (!data.getType().isa<RankedTensorType>())
    return (Type)NULL;
  auto dataTy = data.getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  auto dataRank = dataShape.size();
  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
  if (padsOpt) {
    auto padsArray = padsOpt.getValue();
    // Pads consists of two values for each axis of data.
    // The two values specify the number of elements padded before and after
    // respectively.
    for (int i = 0; i < dataRank; ++i) {
      int64_t p1 = (padsArray[i]).cast<IntegerAttr>().getInt();
      int64_t p2 = (padsArray[i + dataRank]).cast<IntegerAttr>().getInt();
      // Pads must be non-negative constants.
      if (p1 < 0 || p2 < 0)
        return (Type)NULL;
      if (outputShape[i] != -1)
        outputShape[i] += p1 + p2;
    }

    return (RankedTensorType::get(outputShape, dataTy.getElementType()));
  } else {
    return (Type)NULL;
  }
}

//===----------------------------------------------------------------------===//
// PadConstantPad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadConstantPadOp::inferShapes() {
  auto outputType = padShapeInferenceHelper(data(), pads());
  if (!outputType)
    return emitError("missing output");
  getResult().setType(outputType);
  return success();
}

//===----------------------------------------------------------------------===//
// PadConstantValuePad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
  auto outputType = padShapeInferenceHelper(data(), pads());
  if (!outputType)
    return emitError("unable to infer shape of output");
  getResult().setType(outputType);
  return success();
}

void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
    Value data, ArrayAttr pads, FloatAttr constant_value, StringAttr mode) {
  Type outputType = padShapeInferenceHelper(data, pads);
  if (!outputType) {
    auto elementType = data.getType().cast<TensorType>().getElementType();
    outputType = UnrankedTensorType::get(elementType);
  }
  build(builder, state, outputType, data, pads, constant_value, mode);
}
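
// Sketch of how the custom builder above can be used from a rewrite
// pattern (the names `rewriter`, `loc`, and `input` are hypothetical):
//   auto padsAttr = rewriter.getI64ArrayAttr({0, 1, 2, 3});
//   auto padOp = rewriter.create<ONNXPadConstantValuePadOp>(loc, input,
//       padsAttr, rewriter.getF32FloatAttr(0.0),
//       rewriter.getStringAttr("constant"));
// When the pads are constant the result type is inferred eagerly;
// otherwise the builder falls back to an unranked tensor type.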

//===----------------------------------------------------------------------===//
// Unsqueeze
//===----------------------------------------------------------------------===//

LogicalResult ONNXUnsqueezeOp::inferShapes() {
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = data().getType().cast<RankedTensorType>();
  int inRank = operandTy.getRank();

  ArrayAttr axisAttrs = axesAttr();
  SmallVector<int, 4> axes;
  int outRank = 0;
  if (axisAttrs) {
    outRank = inRank + axisAttrs.getValue().size();
    for (auto axisAttr : axisAttrs.getValue()) {
      int axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (outRank + axis);
      // Valid range
      assert(axis >= -outRank && axis <= outRank - 1);
      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
        axes.emplace_back(axis);
      else
        return emitError("Duplicated axes");
    }
  } else
    return emitError("Axes attribute is required");

  SmallVector<int64_t, 4> dims;
  for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
      dims.emplace_back(1);
    } else {
      dims.emplace_back(operandTy.getShape()[j++]);
    }
  }
  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
  return success();
}
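
// Worked example (illustrative): unsqueezing a tensor<3x4xf32> with
// axes = [0, 2] produces an output of rank 2 + 2 = 4; positions 0 and 2
// receive size-1 dimensions, yielding tensor<1x3x1x4xf32>.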

//===----------------------------------------------------------------------===//
// Squeeze
//===----------------------------------------------------------------------===//

LogicalResult ONNXSqueezeOp::inferShapes() {
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = data().getType().cast<RankedTensorType>();
  int64_t inRank = operandTy.getRank();

  ArrayAttr axisAttrs = axesAttr();
  if (!axisAttrs)
    return emitError("Axes attribute is required");

  SmallVector<int64_t, 4> axes;
  bool hasNegativeAxis = false;
  for (auto axisAttr : axisAttrs.getValue()) {
    int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
    if (axis < -inRank || axis >= inRank)
      return emitError("Invalid axis value");
    if (axis < 0) {
      axis = inRank + axis;
      hasNegativeAxis = true;
    }
    if (std::find(axes.begin(), axes.end(), axis) != axes.end())
      return emitError("Duplicated axes");
    axes.emplace_back(axis);
  }
  if (hasNegativeAxis) {
    // Update axes attribute so that it contains only positive values.
    auto builder = mlir::Builder(getContext());
    ArrayRef<int64_t> defaultRefs(axes);
    axesAttr(builder.getI64ArrayAttr(defaultRefs));
  }

  SmallVector<int64_t, 4> dims;
  for (int i = 0; i < inRank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
      dims.emplace_back(operandTy.getShape()[i]);
    }
  }
  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
  return success();
}
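
// Worked example (illustrative): squeezing a tensor<1x3x1x4xf32> with
// axes = [0, 2] removes both size-1 dimensions and yields tensor<3x4xf32>.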

//===----------------------------------------------------------------------===//
// Cast
//===----------------------------------------------------------------------===//

LogicalResult ONNXCastOp::inferShapes() {
  ShapedType inputType = input().getType().dyn_cast<ShapedType>();
  if (!inputType) {
    return emitError("Non-shaped input type");
  }

  auto getOutputType = [&inputType](Type elementType) -> Type {
    if (inputType.hasRank()) {
      return RankedTensorType::get(inputType.getShape(), elementType);
    }
    return UnrankedTensorType::get(elementType);
  };

  int64_t targetType = toAttr().getInt();
  OpBuilder builder(getContext());
  if (auto elementType = convertONNXTypeToMLIRType(
          builder, static_cast<onnx::TensorProto_DataType>(targetType))) {
    getResult().setType(getOutputType(elementType));
  } else {
    return emitOpError("Unable to get the element type for to = " +
                       std::to_string(targetType));
  }
  return success();
}
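
// Worked example (illustrative): casting a tensor<2x3xf32> input with
// to = 7 (INT64 in the ONNX TensorProto encoding) keeps the shape and
// swaps the element type, producing tensor<2x3xi64>.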

//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//

LogicalResult ONNXConstantOp::inferShapes() {
  if ((sparse_value().hasValue() && value().hasValue()) ||
      (!sparse_value().hasValue() && !value().hasValue()))
    return emitError("Require exactly one of the two attributes, "
                     "either value or sparse_value");
  ElementsAttr valAttr;
  if (sparse_value().hasValue())
    valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
  else
    valAttr = valueAttr().cast<DenseElementsAttr>();
  getResult().setType(valAttr.getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Concat
//===----------------------------------------------------------------------===//

LogicalResult ONNXConcatOp::inferShapes() {
  int inputNum = getNumOperands();
  for (int i = 0; i < inputNum; ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  // Checking value of axis parameter.
  auto commonType = getOperand(0).getType().cast<RankedTensorType>();
  auto commonShape = commonType.getShape();
  auto commonRank = commonShape.size();
  auto axisIndex = axis().getSExtValue();
  // Negative axis means values are counted from the opposite side.
  if (axisIndex < 0) {
    axisIndex = commonRank + axisIndex;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }
  if (axisIndex >= commonRank)
    return emitError("Concat axis value out of bound");
  // Initial cumulative size is that of the first operand.
  int cumulativeAxisSize = commonShape[axisIndex];

  // Compute the cumulative size with all of the other ones, and make sure
  // that the other sizes are all alike.
  for (int i = 1; i < inputNum; ++i) {
    auto currShape =
        getOperand(i).getType().cast<RankedTensorType>().getShape();
    if (currShape.size() != commonRank)
      return emitError("Concat input must all have the same rank");
    for (int j = 0; j < commonRank; ++j) {
      if (j == axisIndex) {
        // Check that the value is positive.
        if (currShape[j] <= 0)
          return emitError("Concat axis being concatenated is "
                           "expected to be known at compile time for now");
      } else if (currShape[j] != commonShape[j]) {
        return emitError(
            "Concat input dimensions must be all identical, "
            "except for dimension on the axis of the concatenation");
      }
    }
    cumulativeAxisSize += currShape[axisIndex];
  }

  // Set output size and type
  SmallVector<int64_t, 4> outputDims;
  for (int j = 0; j < commonRank; ++j)
    outputDims.emplace_back(
        j == axisIndex ? cumulativeAxisSize : commonShape[j]);
  getResult().setType(
      RankedTensorType::get(outputDims, commonType.getElementType()));
  return success();
}
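
// Worked example (illustrative): concatenating tensor<2x3xf32> and
// tensor<4x3xf32> along axis 0 accumulates 2 + 4 on that axis and keeps
// the remaining dimensions, giving tensor<6x3xf32>.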

//===----------------------------------------------------------------------===//
// RNN
//===----------------------------------------------------------------------===//

LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// LSTM
//===----------------------------------------------------------------------===//

LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// GRU
//===----------------------------------------------------------------------===//

LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// Split
//===----------------------------------------------------------------------===//

LogicalResult ONNXSplitOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  int numOfResults = getNumResults();
  auto inputType = getOperand().getType().cast<RankedTensorType>();
  auto inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();

  // Checking value of axis parameter.
  auto axisIndex = axis().getSExtValue();
  if (axisIndex < -inputRank || axisIndex >= inputRank)
    return emitError("Split axis value out of bound");
  // Negative axis means values are counted from the opposite side.
  if (axisIndex < 0) {
    axisIndex = inputRank + axisIndex;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }

  // Checking value of split parameter.
  auto splitAttribute = split();
  SmallVector<int64_t, 4> splitLengths;
  if (splitAttribute.hasValue()) {
    if (ArrayAttrSize(splitAttribute) != numOfResults)
      return emitError("Split size not equal to the number of results");
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));

  } else {
    if (inputShape[axisIndex] <= 0)
      return emitError("The dimension at the split axis is "
                       "expected to be known at compile time");
    if (inputShape[axisIndex] % numOfResults != 0)
      return emitError("The dimension at the split axis is "
                       "expected to be divisible by the number of results");
    // If split parameter is not specified, the dimension is split to
    // equal-sized parts.
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
    // Build attribute and store attribute.
    auto builder = mlir::Builder(getContext());
    splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
  }

  // Build result types.
  for (int i = 0; i < numOfResults; ++i) {
    SmallVector<int64_t, 3> resultShape;
    for (int j = 0; j < inputRank; ++j) {
      if (j == axisIndex) {
        resultShape.emplace_back(splitLengths[i]);
      } else {
        resultShape.emplace_back(inputShape[j]);
      }
    }
    getResults()[i].setType(
        RankedTensorType::get(resultShape, inputType.getElementType()));
  }
  return success();
}
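
// Worked example (illustrative): splitting a tensor<6x4xf32> into three
// results along axis 0 with no split attribute divides the axis evenly,
// producing three values of type tensor<2x4xf32> and recording
// split = [2, 2, 2] on the op.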

//===----------------------------------------------------------------------===//
// Flatten
//===----------------------------------------------------------------------===//

LogicalResult ONNXFlattenOp::inferShapes() {
  assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
  auto inTy = input().getType().dyn_cast<ShapedType>();
  if (!inTy) {
    return emitOpError("Input is a non-shaped type");
  }
  auto outTy = output().getType().dyn_cast<ShapedType>();
  if (!outTy) {
    return emitOpError("Output is a non-shaped type");
  }

  // TODO(tjingrant): Seems like we can also fairly easily support the case
  // where the batch dimension is dynamic.
  if (!outTy.hasStaticShape()) {
    auto inShape = inTy.getShape();
    assert(inShape.size() >= 1 &&
           "ONNXFlattenOp inShape.size() should be >= 1");
    uint64_t outDim = 1;
    for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
      outDim *= *it;
    }

    SmallVector<int64_t, 2> dims;
    // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
    dims.emplace_back(inShape[0]);
    dims.emplace_back(outDim);
    getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
  }

  return success();
}
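
// Worked example (illustrative): with axis = 1, a tensor<2x3x4xf32> input
// keeps its batch dimension and collapses the rest to 3 * 4 = 12,
// yielding tensor<2x12xf32>.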

//===----------------------------------------------------------------------===//
// DynamicQuantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();
  auto yScaleTy = y_scale().getType().cast<ShapedType>();
  auto yZPTy = y_zero_point().getType().cast<ShapedType>();

  IntegerType ui8Type =
      IntegerType::get(8, IntegerType::Unsigned, getContext());
  FloatType f32Type = FloatType::getF32(getContext());

  RankedTensorType scalarType = RankedTensorType::get({}, f32Type);
  RankedTensorType y_zero_point_type = RankedTensorType::get({}, ui8Type);

  // Set the types for the scalars
  if (!yScaleTy.hasStaticShape()) {
    y_scale().setType(scalarType);
  }

  if (!yZPTy.hasStaticShape()) {
    y_zero_point().setType(y_zero_point_type);
  }

  if (!yTy.hasStaticShape()) {
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), ui8Type);
    y().setType(outType);
  }

  return success();
}
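
// Worked example (illustrative): for x of type tensor<3x4xf32>, the three
// results are typed tensor<3x4xui8> (y), tensor<f32> (y_scale), and
// tensor<ui8> (y_zero_point); the scale and zero point are 0-D scalars.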

//===----------------------------------------------------------------------===//
// QuantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXQuantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();

  if (!yTy.hasStaticShape()) {
    // TODO: Unfortunately, we can't tell if this should be signed or unsigned
    // here...
    IntegerType i8Type = IntegerType::get(8, getContext());
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
    y().setType(outType);
  }

  return success();
}
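
// Worked example (illustrative): quantizing a tensor<2x3xf32> input with
// an unshaped output yields tensor<2x3xi8>; the signless i8 reflects the
// TODO above about not knowing the intended signedness.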

//===----------------------------------------------------------------------===//
// DequantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXDequantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();

  if (!yTy.hasStaticShape()) {
    FloatType f32 = FloatType::getF32(getContext());
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), f32);
    y().setType(outType);
  }

  return success();
}

//===----------------------------------------------------------------------===//
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
//===----------------------------------------------------------------------===//

LogicalResult ONNXConvIntegerOp::inferShapes() {
  // Generic shape for data input X, weight tensor W
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)

  // Cannot infer shape if no shape exists.
  if (!x().getType().isa<RankedTensorType>() ||
      !w().getType().isa<RankedTensorType>()) {
    return emitOpError("Input tensor not ranked");
  }

  auto xTy = x().getType().cast<RankedTensorType>();
  if (!xTy.getElementType().isInteger(8)) {
    return emitOpError("Invalid input type");
  }
  auto xShape = xTy.getShape();
  auto weightTy = w().getType().cast<RankedTensorType>();
  if (!weightTy.getElementType().isInteger(8)) {
    return emitOpError("Invalid weight type");
  }
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3) {
    return emitOpError("Data input shape must be at least (NxCxD1)");
  }

  // Check that shape of weight and data have same length.
  if (xShape.size() != weightShape.size()) {
    return emitError("Weight size not compatible with data size");
  }

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvIntegerOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[1] != -1 &&
      xShape[1] != (weightShape[1] * group)) {
    return emitOpError("Channel dimension mismatch");
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank) {
      return emitOpError(
          "kernel_shape length incompatible with spatial dimensions");
    }
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1) {
        return emitError("bad kernel_shape value");
      }
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 2> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, x());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[0]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt);

  // ONNX spec specifies the output type as an int32.
  Type outputType = IntegerType::get(32, getContext());
  getResult().setType(RankedTensorType::get(outputDims, outputType));
  return success();
}
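
// Worked example (illustrative): for x of type tensor<1x1x28x28xi8> and w
// of type tensor<6x1x5x5xi8> with unit strides/dilations and no padding,
// each spatial output is (28 - 5) / 1 + 1 = 24, so the result is typed
// tensor<1x6x24x24xi32>.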

//===----------------------------------------------------------------------===//
// Shape
//===----------------------------------------------------------------------===//

LogicalResult ONNXShapeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  // Output is a 1D int64 tensor containing the shape of the input tensor.
  int64_t rank = data().getType().cast<RankedTensorType>().getRank();
  SmallVector<int64_t, 1> outDims(1, rank);
  getResult().setType(
      RankedTensorType::get(outDims, IntegerType::get(64, getContext())));
  return success();
}
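
// Worked example (illustrative): for data of type tensor<2x3x4xf32>, the
// rank is 3, so the result is typed tensor<3xi64> (holding [2, 3, 4] at
// runtime).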

//===----------------------------------------------------------------------===//
// Tile
//===----------------------------------------------------------------------===//

LogicalResult ONNXTileOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!input().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  // Read 'repeats' value.
  if (!repeats().getType().isa<RankedTensorType>())
    return emitError("Repeats tensor not ranked");

  auto inputTensorTy = input().getType().cast<RankedTensorType>();
  auto repeatsTensorTy = repeats().getType().cast<RankedTensorType>();

  // 'repeats' tensor is a 1D tensor.
  if (repeatsTensorTy.getShape().size() != 1)
    return emitError("Repeats tensor must have rank one");

  // 'repeats' tensor must have constant shape.
  int64_t repeatsLength = repeatsTensorTy.getShape()[0];
  if (repeatsLength < 0)
    return emitError("Repeats tensor must have constant shape");

  // Check the 1D repeats tensor length.
  int64_t inputRank = inputTensorTy.getShape().size();
  if (inputRank != repeatsLength)
    return emitError("Repeats tensor must have the same length as the "
                     "input's rank.");

  // Check if second argument of TileOp is a constant.
  auto constantOp = getONNXConstantOp(repeats());

  // Compute output's dimensions: output_dim[i] = input_dim[i] * repeats[i]
  SmallVector<int64_t, 2> dims(inputRank, -1);
  if (constantOp) {
    // 1. Initialize output_dim with values from 'input'.
    //    output_dim[i] = input[i]
    for (decltype(inputRank) i = 0; i < inputRank; ++i)
      dims[i] = inputTensorTy.getShape()[i];

    // 2. Update output_dim using values from 'repeats'.
    //    Do this only for static 'input_dim[i]'.
    //    if (output_dim[i] != -1) output_dim[i] *= repeats[i]
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    if (!valueAttribute)
      return emitError("DenseElementsAttr expected");
    // Get repeat values from valueAttribute.
    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
    for (int i = 0; i < inputRank; ++i)
      if (dims[i] != -1)
        dims[i] *= (*valueIt++).cast<IntegerAttr>().getInt();

    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
      return emitError("Constant value must have same length as output's rank");
  }

  getResult().setType(
      RankedTensorType::get(dims, inputTensorTy.getElementType()));

  return success();
}
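
// Worked example (illustrative): tiling a tensor<2x3xf32> with constant
// repeats = [2, 3] multiplies each dimension, giving tensor<4x9xf32>.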

//===----------------------------------------------------------------------===//
// Gather
//===----------------------------------------------------------------------===//

LogicalResult ONNXGatherOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");
  if (!indices().getType().isa<RankedTensorType>())
    return emitError("Indices tensor not ranked");

  auto inputShape = data().getType().cast<RankedTensorType>().getShape();
  auto indicesShape = indices().getType().cast<RankedTensorType>().getShape();
  int64_t inputRank = inputShape.size();
  int64_t indicesRank = indicesShape.size();

  if (inputRank < 1)
    return emitError("Input tensor must have rank >= 1");

  // Read 'axis' attribute.
  auto axisIndex = axis().getSExtValue();
  // 'axis' must be in [-rank, rank-1].
  if (axisIndex < -inputRank || axisIndex >= inputRank)
    return emitError("Gather axis value out of bound");
  // Convert a negative axis to a positive axis.
  if (axisIndex < 0) {
    axisIndex += inputRank;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }

  // If 'indices' is a constant, check whether its values are valid or not.
  auto constantOp = getONNXConstantOp(indices());
  if (constantOp && inputShape[axisIndex] != -1) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    if (!valueAttribute)
      return emitError("DenseElementsAttr expected");
    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
      auto index = value.cast<IntegerAttr>().getInt();
      if (index < -inputShape[axisIndex] || index >= inputShape[axisIndex])
        return emitError("Indices tensor contains an out-of-bound index");
    }
  }

  // Output has rank of 'indicesRank + (inputRank - 1)'.
  // Output shape is constructed from 'input' by:
  //   replacing the dimension at 'axis' in 'input' by the shape of 'indices'.
  SmallVector<int64_t, 1> outDims;
  for (decltype(inputRank) i = 0; i < inputRank; ++i) {
    if (i == axisIndex)
      for (decltype(indicesRank) j = 0; j < indicesRank; ++j)
        outDims.emplace_back(indicesShape[j]);
    else
      outDims.emplace_back(inputShape[i]);
  }

  getResult().setType(RankedTensorType::get(
      outDims, data().getType().cast<RankedTensorType>().getElementType()));
  return success();
}
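
// Worked example (illustrative): gathering from data of type
// tensor<3x4xf32> with indices of type tensor<2x5xi64> along axis 0
// replaces the size-3 dimension by the indices shape, yielding
// tensor<2x5x4xf32> (rank 2 + (2 - 1) = 3).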

//===----------------------------------------------------------------------===//
// ConstantOfShape
//===----------------------------------------------------------------------===//

LogicalResult ONNXConstantOfShapeOp::inferShapes() {
  Type elementType;

  // 'value' attribute is a one-element tensor whose value and datatype are
  // used to set the output tensor's value and datatype.
  if (value().hasValue()) {
    elementType =
        valueAttr().cast<DenseElementsAttr>().getType().getElementType();
  } else {
    // If 'value' attribute is not specified, it defaults to a tensor of
    // value 0 and datatype float32.
    elementType = FloatType::getF32(getContext());

    llvm::SmallVector<int64_t, 2> dims(1, 1);
    auto tensorType = mlir::RankedTensorType::get(dims, elementType);

    llvm::SmallVector<float, 1> values(1, 0.);
    valueAttr(
        mlir::DenseElementsAttr::get(tensorType, llvm::makeArrayRef(values)));
  }

  // 'input' must be a 1D tensor.
  auto inputShape = input().getType().cast<RankedTensorType>().getShape();
  if (inputShape.size() != 1)
    return emitError("Input tensor must be a 1D tensor");
  if (inputShape[0] == -1)
    return emitError("Input tensor must have static shape");
  if (inputShape[0] == 0) {
    // If 'input' is an empty tensor, the output would be a scalar.
    getResult().setType(RankedTensorType::get({}, elementType));
    return success();
  }

  // Calculate output dimensions.
  SmallVector<int64_t, 4> outputDims(inputShape[0], -1);
  // If 'input' is a constant, check whether its values are valid or not.
  // If the values are valid, it is possible to infer shape.
  if (auto constantOp = getONNXConstantOp(input())) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    // Get dimension values from valueAttribute.
    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
    for (int i = 0; i < inputShape[0]; ++i) {
      auto dim = (*valueIt++).cast<IntegerAttr>().getInt();
      if (dim < 0)
        return emitError("All values of the input tensor must be >=0");
      outputDims[i] = dim;
    }

    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
      return emitError("Constant value must have same length as output's rank");
  }

  getResult().setType(RankedTensorType::get(outputDims, elementType));
  return success();
}
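
// Worked example (illustrative): if 'input' is a constant 1D tensor
// holding [2, 3] and no 'value' attribute is given, the result is typed
// tensor<2x3xf32> and is filled with 0.0f at runtime.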

//===----------------------------------------------------------------------===//
// Slice
//===----------------------------------------------------------------------===//

LogicalResult ONNXSliceOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto elementType = data().getType().cast<ShapedType>().getElementType();
  auto dataShape = data().getType().cast<ShapedType>().getShape();
  int64_t numDims = dataShape.size();

  SmallVector<int64_t, 2> outputDims(numDims, -1);
  // If 'starts', 'ends', 'axes', and 'steps' are constants, check whether
  // their values are valid or not. If the values are valid, it is possible
  // to infer shape.
  //
  // 'starts', 'ends', and 'steps' are for each axis in the list of axes, so
  // process 'axes' first.

  // Check and get 'axes' tensor.
  SmallVector<int64_t, 2> axesValue;
  if (axes().getType().isa<NoneType>()) {
    // If 'axes' are omitted, they are set to [0, ..., ndim-1].
    for (int i = 0; i < numDims; ++i)
      axesValue.emplace_back(i);
  } else if (auto constantOp = getONNXConstantOp(axes())) {
    // If 'axes' are constants, read them.
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
      int64_t axis = value.cast<IntegerAttr>().getInt();
      if (axis < -numDims || axis >= numDims)
        return emitError("Axes contains an out-of-bound index");
      if (axis < 0)
        axis += numDims;
      if (dataShape[axis] == -1) {
        // It is unsafe to infer shape for an axis with an unknown dimension,
        // since we cannot validate 'start' and 'end' values from this
        // dimension.
        getResult().setType(RankedTensorType::get(outputDims, elementType));
        return success();
      }
      axesValue.emplace_back(axis);
    }
  } else {
    // Cannot infer a static shape.
    getResult().setType(RankedTensorType::get(outputDims, elementType));
    return success();
  }

  // Check 'starts' tensor.
  SmallVector<int64_t, 2> startsValue;
  if (auto constantOp = getONNXConstantOp(starts())) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    int i = 0;
    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
      int64_t axis = axesValue[i];
      int64_t index = value.cast<IntegerAttr>().getInt();
      if (index < -dataShape[axis])
        index = 0;
      else if (index > dataShape[axis])
        index = dataShape[axis];
      else if (index < 0)
        index += dataShape[axis];
      startsValue.emplace_back(index);
      i++;
    }
    if (i != axesValue.size())
      return emitError("starts and axes tensors must have the same length");
  } else {
    // Cannot infer a static shape.
    getResult().setType(RankedTensorType::get(outputDims, elementType));
    return success();
  }

  // Check 'ends' tensor.
  SmallVector<int64_t, 2> endsValue;
  if (auto constantOp = getONNXConstantOp(ends())) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    int i = 0;
    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
      int64_t axis = axesValue[i];
      int64_t index = value.cast<IntegerAttr>().getInt();
      if (index < -dataShape[axis])
        index = 0;
      else if (index > dataShape[axis])
        index = dataShape[axis];
      else if (index < 0)
        index += dataShape[axis];
      endsValue.emplace_back(index);
      i++;
    }
    if (i != axesValue.size())
      return emitError("ends and axes tensors must have the same length");
  } else {
    // Cannot infer a static shape.
    getResult().setType(RankedTensorType::get(outputDims, elementType));
    return success();
  }

  // Check and get 'steps' tensor.
  SmallVector<int64_t, 2> stepsValue;
  if (steps().getType().isa<NoneType>()) {
    // If 'steps' are omitted, they are set to [1, ..., 1] of len(starts).
    for (int i = 0; i < startsValue.size(); ++i)
      stepsValue.emplace_back(1);
  } else if (auto constantOp = getONNXConstantOp(steps())) {
    // If 'steps' are constants, read them.
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
    int i = 0;
    for (auto value : valueAttribute.getValues<IntegerAttr>()) {
      int64_t index = value.cast<IntegerAttr>().getInt();
      if (index == 0)
        return emitError("step cannot be zero");
      stepsValue.emplace_back(index);
      i++;
    }
    if (i != axesValue.size())
      return emitError("steps and axes tensors must have the same length");
  } else {
    // Cannot infer a static shape.
    getResult().setType(RankedTensorType::get(outputDims, elementType));
    return success();
  }

  // All 'starts', 'ends', 'steps' values are valid. Now calculate output
  // dimensions for axes in 'axes'.
  for (int i = 0; i < axesValue.size(); i++) {
    int64_t axis = axesValue[i];
    int64_t start = startsValue[i];
    int64_t end = endsValue[i];
    int64_t step = stepsValue[i];
    if (step < 0)
      step = -step;

    // Number of elements is ceil((end - start) / step).
    int64_t q = (end - start) / step;
    int64_t r = (end - start) % step;
    if (r != 0)
      q += 1;
    outputDims[axis] = q;
  }

  getResult().setType(RankedTensorType::get(outputDims, elementType));
  return success();
}
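
// Worked example (illustrative): slicing data of type tensor<10xf32> with
// constant starts = [2], ends = [8], axes = [0], steps = [2] selects
// elements 2, 4, 6; q = (8 - 2) / 2 = 3 with no remainder, so the result
// is typed tensor<3xf32>.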

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"