//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines ONNX operations in the MLIR operation set.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "onnx_ops.hpp"
using namespace mlir;
using namespace mlir::OpTrait::util;
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
: mlir::Dialect(getDialectNamespace(), ctx) {
addOperations<
#define GET_OP_LIST
#include "src/onnx.cpp.inc"
>();
}
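
// A registration sketch (assuming the static dialect registration API of the
// MLIR snapshot this code targets; the variable name is illustrative):
//
//   static mlir::DialectRegistration<ONNXOpsDialect> onnxOpsDialect;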
void ONNXEntryPointOp::build(mlir::Builder *builder,
mlir::OperationState &state, mlir::FuncOp function,
int numInputs, int numOutputs) {
state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
builder->getSymbolRefAttr(function));
state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
builder->getI32IntegerAttr(numInputs));
state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
builder->getI32IntegerAttr(numOutputs));
}
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
mlir::FuncOp &func, int numInputs,
int numOutputs) {
mlir::OperationState state(location, "onnx.EntryPoint");
Builder builder(location->getContext());
mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
Operation *op = mlir::Operation::create(state);
auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
return onnxEntryOp;
}
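
// Usage sketch (hypothetical call site): wrap a FuncOp `mainFunc` that takes
// two inputs and produces one output:
//
//   auto entryPoint = mlir::ONNXEntryPointOp::create(
//       loc, mainFunc, /*numInputs=*/2, /*numOutputs=*/1);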
//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Log
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
void ONNXSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Reciprocal
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
void ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softplus
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
void ONNXSoftplusOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Softsign
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
void ONNXSoftsignOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
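
// For example, getBroadcastedType applies numpy-style broadcasting:
// tensor<3x1xf32> + tensor<1x4xf32> yields tensor<3x4xf32>, and
// tensor<5xf32> + tensor<2x5xf32> yields tensor<2x5xf32>.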
//===----------------------------------------------------------------------===//
// Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
void ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Div
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
void ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
void ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// And
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
void ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
void ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
void ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
//===----------------------------------------------------------------------===//
// Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
void ONNXSumOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
void ONNXMaxOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Min
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
void ONNXMinOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
}
//===----------------------------------------------------------------------===//
// Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface.
void ONNXIdentityOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===//
// MatMul
void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape();
auto rhsShape = rhsTy.getShape();
  if (lhsShape.size() < 1 || rhsShape.size() < 1) {
    // Multiplication by scalar arguments is not allowed; a scalar operand
    // would also make the fallback case below index an empty shape.
    emitError("Multiplication by scalar arguments not allowed.");
  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
    // Special case when both arguments are 1-dimensional: following numpy
    // rules, they are treated as 1xN and Nx1. The helper dimensions are
    // dropped after the multiplication, except that a single size-1
    // dimension is kept so the result remains a ranked tensor.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
// =>
// (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
dims.emplace_back(rhsShape[1]);
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
// =>
// (s1 x s2 x... x sK x M x P)
// Check legality of matrix multiplication.
unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
for (int i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
dims.emplace_back(lhsShape[0]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
// =>
// (u1 x u2 x... x uK x M x P)
// Check legality of matrix multiplication.
unsigned lhsRank = lhsShape.size();
unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
for (int i = 0; i < lhsRank - 2; ++i)
lhsBcastShape.emplace_back(lhsShape[i]);
SmallVector<int64_t, 2> rhsBcastShape;
for (int i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible.");
dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]);
} else {
// This case covers all remaining combinations of 1 and 2-D matrices.
int64_t lhsDim = lhsShape[0];
int64_t rhsDim = rhsShape[0];
if (lhsShape.size() > 1) {
lhsDim = lhsShape[1];
dims.emplace_back(lhsShape[0]);
}
// Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices.");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]);
}
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
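
// Worked examples for the cases above (shapes only):
//   tensor<4xf32>       x tensor<4xf32>        -> tensor<1xf32>
//   tensor<2x3x4xf32>   x tensor<4x5xf32>      -> tensor<2x3x5xf32>
//   tensor<3x4xf32>     x tensor<2x4x5xf32>    -> tensor<2x3x5xf32>
//   tensor<2x1x3x4xf32> x tensor<1x5x4x6xf32>  -> tensor<2x5x3x6xf32>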
//===----------------------------------------------------------------------===//
// Gemm
void ONNXGemmOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
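
// Example: A of type tensor<5x7xf32> and B of type tensor<7x3xf32> give a
// result of type tensor<5x3xf32>. Note that the ONNX transA/transB
// attributes are not taken into account by this derivation.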
// GemmNoBias
void ONNXGemmNoBiasOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]);
dims.emplace_back(rhsTy.getShape()[1]);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}
// TODO:
// Verify that matrix sizes are valid for multiplication and addition.
// Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// Reshape
void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
  if (!getOperand(1).getType().isa<RankedTensorType>()) {
    emitError("Shape tensor not ranked.");
    return;
  }
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one.");
int64_t outputRank = shapeTensorTy.getShape()[0];
  // Shape tensor must have a static (compile-time constant) size.
  if (outputRank < 0)
    emitError("Shape tensor must have static shape.");
SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i)
dims.emplace_back(-1);
getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType()));
}
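
// Example: reshaping a tensor<8x4xf32> with a shape operand of type
// tensor<2xi64> yields tensor<?x?xf32>; only the output rank is known
// statically, since the shape values are runtime data.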
//===----------------------------------------------------------------------===//
// Transpose
void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>())
return;
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  if (auto permutation = ONNXTransposeOp::permAttr()) {
// Perform transposition according to perm attribute.
for (auto perm : permutation.getValue())
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
} else {
// Default
for (auto dim : llvm::reverse(arrayTy.getShape()))
dims.emplace_back(dim);
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
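
// Example: with perm = [0, 2, 1], tensor<2x3x4xf32> becomes
// tensor<2x4x3xf32>; with no perm attribute the shape is reversed,
// giving tensor<4x3x2xf32>.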
//===----------------------------------------------------------------------===//
// Conv
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
void ONNXConvNoBiasOp::inferShapes() {
// Generic shape for data input X and weight tensor W:
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape();
// Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size())
emitError("Weight size not compatible with data size.");
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad();
// Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvNoBiasOp::group().getSExtValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch.");
  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.
// First two output dimensions consist of the number of batches and the
// number of kernels being applied.
//
SmallVector<int64_t, 2> dims;
// Insert batch size.
dims.emplace_back(dataShape[0]);
// Insert number of filters being applied (number of output channels).
dims.emplace_back(weightShape[0]);
// Spatial dimensions of the output are computed using the formula:
//
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
//
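  // For example, with assumed values inputDim = 32, kernelDim = 3,
  // startPadding = endPadding = 1, and stride = 2:
  //
  //   dim = (32 - 3 + 1 + 1) / 2 + 1 = 16
  //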
SmallVector<int64_t, 2> outSpatialDims;
// Number of spatial dimensions.
int32_t nDims = dataShape.size() - 2;
  // Initialize dimensions based on the input spatial dimensions.
for (int i = 2; i < dataShape.size(); ++i)
outSpatialDims.emplace_back(dataShape[i]);
  // Use the kernel_shape attribute if present; otherwise use the sizes from
  // the weight argument.
SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = kernel_shapeAttr()) {
if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
} else {
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(weightShape[i + 2]);
}
// Check if dilations attribute is present.
  // If it is then compute the new kernel size that covers the full receptive
  // field. In this calculation we assume that the receptive field pixels must
  // all be within the bounds of the image. The dilated kernel size is given
  // by:
  //
  //   ( K - 1 ) * d + 1
// where K is a kernel dimension and d is the dilation along that axis.
//
// From a dimensionality perspective the kernel size becomes the dilated
// kernel size.
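  //
  // For example, a kernel dimension K = 3 with dilation d = 2 has a dilated
  // size of (3 - 1) * 2 + 1 = 5.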
if (auto dilations = dilationsAttr()) {
if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i)
      kernelDims[i] = (kernelDims[i] - 1) *
          (dilations.getValue()[i]).cast<IntegerAttr>().getInt() + 1;
}
// Subtract kernel dimensions from input data dimensions.
for (int i = 0; i < nDims; ++i)
outSpatialDims[i] -= kernelDims[i];
// Add padding information.
if (autoPad == "NOTSET") {
    // Use pads to determine the padding. If the attribute is not
    // present then pads is considered to be all zeros (no padding).
if (auto pads = padsAttr()) {
// pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims)
emitError("pads size is not twice the spatial size.");
for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
// Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
}
}
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// Pad input so that output size matches input size.
// Each spatial dimension needs to be padded by a total of:
//
// K - 1
//
// where K is a kernel spatial dimension.
// Pad as if stride is 1.
for (int i = 0; i < nDims; ++i)
outSpatialDims[i] += kernelDims[i] - 1;
} else if (autoPad == "VALID") {
// No padding
} else {
emitError("Unexpected attribute value for auto_pad.");
}
// Strides
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) {
      int64_t stride =
          strides.getValue()[i].cast<IntegerAttr>().getInt();
      // Plain integer division already floors for the non-negative sizes
      // involved, so no explicit floor() is needed.
      outSpatialDims[i] = outSpatialDims[i] / stride;
}
}
  // Add the final "+ 1" term of the output dimension formula.
  for (int i = 0; i < nDims; ++i)
    outSpatialDims[i] += 1;
dims.append(outSpatialDims.begin(), outSpatialDims.end());
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
}
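
// End-to-end example (assumed shapes and attributes): X: tensor<1x3x32x32xf32>,
// W: tensor<16x3x3x3xf32>, group = 1, auto_pad = "NOTSET",
// pads = [1, 1, 1, 1], strides = [2, 2], and no dilations. Each spatial
// dimension becomes (32 - 3 + 1 + 1) / 2 + 1 = 16, so the inferred result
// type is tensor<1x16x16x16xf32>.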
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "src/onnx.cpp.inc"