use input/output operation names, use helper for attribute function and int values (#106)

This commit is contained in:
Alexandre Eichenberger 2020-02-25 15:46:11 -05:00 committed by GitHub
parent 3b1c29c078
commit 3a88361b17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 84 additions and 89 deletions

View File

@ -406,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
void ONNXMatMulOp::inferShapes() { void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !B().getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape(); auto lhsShape = lhsTy.getShape();
@ -419,14 +419,14 @@ void ONNXMatMulOp::inferShapes() {
if (lhsShape.size() < 1 && rhsShape.size() < 1) { if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed. // Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed."); emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) { } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to // Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all // need to be removed after the multiplication but cannot be removed if all
// sizes are 1. // sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0]) if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
dims.emplace_back(1); dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending a // If the first argument is 1-D, it is promoted to a matrix by prepending a
@ -441,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2]) lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]); dims.emplace_back(rhsShape[i]);
@ -459,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size(); unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0]) lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i) for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
@ -473,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size(); unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0]) lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i) for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]); dims.emplace_back(lhsShape[i]);
@ -487,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2]) lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]); dims.emplace_back(rhsShape[i]);
@ -503,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size(); unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes. // Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape; SmallVector<int64_t, 2> lhsBcastShape;
@ -513,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]); rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible."); emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]); dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]); dims.emplace_back(rhsShape[rhsRank - 1]);
@ -528,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {
// Check legality of matrix multiplication. // Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices."); emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1) if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]); dims.emplace_back(rhsShape[1]);
@ -542,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm // Gemm
void ONNXGemmOp::inferShapes() { void ONNXGemmOp::inferShapes() {
bool hasBias = !getOperand(2).getType().isa<NoneType>(); bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
(hasBias && !getOperand(2).getType().isa<RankedTensorType>())) (hasBias && !C().getType().isa<RankedTensorType>()))
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B; int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -558,12 +558,12 @@ void ONNXGemmOp::inferShapes() {
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1]; K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) { if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched."); emitError("Tensor shapes mismatched");
} }
if (hasBias) { if (hasBias) {
// Check whether bias is unidirectional broadcasting or not. // Check whether bias is unidirectional broadcasting or not.
auto biasTy = getOperand(2).getType().cast<RankedTensorType>(); auto biasTy = C().getType().cast<RankedTensorType>();
auto shape = biasTy.getShape(); auto shape = biasTy.getShape();
int rank = shape.size(); int rank = shape.size();
if ((rank > 2) || if ((rank > 2) ||
@ -571,7 +571,7 @@ void ONNXGemmOp::inferShapes() {
N != shape[rank - 1] && shape[rank - 1] != 1) || N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 && (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) { M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched."); emitError("Bias shape mismatched");
} }
} }
@ -584,50 +584,50 @@ void ONNXGemmOp::inferShapes() {
/// BatchNormalizationTestMode /// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() { void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() || !scale().getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() || !mean().getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>()) !var().getType().isa<RankedTensorType>())
return; return;
auto input = getOperand(0).getType().cast<RankedTensorType>(); auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>(); auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>(); auto biasTensorTy = B().getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>(); auto meanTensorTy = mean().getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>(); auto varianceTensorTy = var().getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid. // Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1. // In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C. // Shapes of scale, bias, mean and variance must be C.
int64_t c = -1; int64_t c = -1;
if (input.getShape().size() == 1) { if (inputTensorTy.getShape().size() == 1) {
c = 1; c = 1;
} else if (input.getShape().size() > 2) { } else if (inputTensorTy.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1; c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else { } else {
emitError("Wrong rank for the input."); emitError("Wrong rank for the input");
} }
if (c != -1) { if (c != -1) {
auto s = scale.getShape(); auto s = scaleTensorTy.getShape();
auto b = bias.getShape(); auto b = biasTensorTy.getShape();
auto m = mean.getShape(); auto m = meanTensorTy.getShape();
auto v = variance.getShape(); auto v = varianceTensorTy.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c)) if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale."); emitError("Wrong rank for the scale");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c)) if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias."); emitError("Wrong rank for the bias");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c)) if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean."); emitError("Wrong rank for the mean");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c)) if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance."); emitError("Wrong rank for the variance");
} }
// The output tensor of the same shape as the input. // The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType()); getResult().setType(X().getType());
} }
// TODO: // TODO:
@ -640,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
void ONNXReshapeOp::inferShapes() { void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
if (!getOperand(1).getType().isa<RankedTensorType>()) if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>(); auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>(); auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported. // Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1) if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one."); emitError("Shape tensor must have rank one");
int64_t outputRank = shapeTensorTy.getShape()[0]; int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape. // Shape tensor must have constant shape.
if (outputRank < 0) if (outputRank < 0)
emitError("Shape tensor must have constant shape."); emitError("Shape tensor must have constant shape");
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i) for (int i = 0; i < outputRank; ++i)
@ -670,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() { void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return;
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr(); auto permutation = ONNXTransposeOp::permAttr();
if (permutation) { if (permutation) {
@ -697,7 +697,7 @@ void ONNXTransposeOp::inferShapes() {
void ONNXReduceMaxOp::inferShapes() { void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -711,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
void ONNXReduceMinOp::inferShapes() { void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -725,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
void ONNXReduceProdOp::inferShapes() { void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -739,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
void ONNXReduceSumOp::inferShapes() { void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked."); emitError("Shape tensor not ranked");
return; return;
} }
@ -758,22 +758,22 @@ void ONNXConvNoBiasOp::inferShapes() {
// W: (M x C/group x k1 x k2 x ... x kn) // W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !W().getType().isa<RankedTensorType>())
return; return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>(); auto dataTy = X().getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>(); auto weightTy = W().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape(); auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape(); auto weightShape = weightTy.getShape();
// Lowest supported convolution is a one dimensional convolution. // Lowest supported convolution is a one dimensional convolution.
if (dataShape.size() < 3) if (dataShape.size() < 3)
emitError("Data input shape must be at least (NxCxD1)."); emitError("Data input shape must be at least (NxCxD1)");
// Check that shape of weight and data have same length. // Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size()) if (dataShape.size() != weightShape.size())
emitError("Weight size not compatible with data size."); emitError("Weight size not compatible with data size");
// Required attribute auto_pad defaults to NOTSET. // Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad(); auto autoPad = auto_pad();
@ -782,7 +782,7 @@ void ONNXConvNoBiasOp::inferShapes() {
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group)) if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch."); emitError("Channel dimension mismatch");
// Note: the value of the group attribut only impacts the way the // Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size. // computation is carried out and not the actual output size.
@ -812,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument. // argument.
SmallVector<int64_t, 2> kernelDims; SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = kernel_shapeAttr()) { if (auto kernelShape = kernel_shapeAttr()) {
if (kernelShape.getValue().size() != nDims) if (ArrayAttrSize(kernelShape) != nDims)
emitError("kernel_shape length incompatible with spatial dimensions."); emitError("kernel_shape length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back( kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
} else { } else {
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(weightShape[i + 2]); kernelDims.emplace_back(weightShape[i + 2]);
@ -834,13 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
// From a dimensionality perspective the kernel size becomes the dilated // From a dimensionality perspective the kernel size becomes the dilated
// kernel size. // kernel size.
if (auto dilations = dilationsAttr()) { if (auto dilations = dilationsAttr()) {
if (dilations.getValue().size() != nDims) if (ArrayAttrSize(dilations) != nDims)
emitError("dilations length incompatible with spatial dimensions."); emitError("dilations length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims[i] = kernelDims[i] =
(kernelDims[i] + 1) * (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
1;
} }
// Subtract kernel dimensions from input data dimensions. // Subtract kernel dimensions from input data dimensions.
@ -853,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
// present then pads is considered to be all zeros (no padding). // present then pads is considered to be all zeros (no padding).
if (auto pads = padsAttr()) { if (auto pads = padsAttr()) {
// pads consists of two entries for each spatial axis. // pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims) if (ArrayAttrSize(pads) != 2 * nDims)
emitError("pads size is not twice the spatial size."); emitError("pads size is not twice the spatial size");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis. // Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt(); outSpatialDims[i] += ArrayAttrIntVal(pads, i);
outSpatialDims[i] += p;
// Padding for end of axis. // Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt(); outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
outSpatialDims[i] += p;
} }
} }
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
@ -878,15 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
} else if (autoPad == "VALID") { } else if (autoPad == "VALID") {
// No padding // No padding
} else { } else {
emitError("Unexpected attribute value for auto_pad."); emitError("Unexpected attribute value for auto_pad");
} }
// Strides // Strides
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) { if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
if (strides.getValue().size() != nDims) if (ArrayAttrSize(strides) != nDims)
emitError("strides length incompatible with spatial dimensions."); emitError("strides length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt(); int64_t stride = ArrayAttrIntVal(strides, i);
outSpatialDims[i] = floor(outSpatialDims[i] / stride); outSpatialDims[i] = floor(outSpatialDims[i] / stride);
} }
} }
@ -1013,7 +1008,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
} }
} }
} else if (autoPad != "VALID") { } else if (autoPad != "VALID") {
emitError("auto_pad of unknown / unsupported value."); emitError("auto_pad of unknown / unsupported value");
} }
// Set pads values in attributes. // Set pads values in attributes.
{ {
@ -1044,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
} }
yShape[kernelOffset + i] = res; yShape[kernelOffset + i] = res;
} }
auto arrayTy = getOperand().getType().cast<RankedTensorType>(); auto arrayTy = X().getType().cast<RankedTensorType>();
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
} }
@ -1053,10 +1048,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
// Unsqueeze // Unsqueeze
void ONNXUnsqueezeOp::inferShapes() { void ONNXUnsqueezeOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return;
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank(); int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr(); ArrayAttr axisAttrs = axesAttr();
@ -1072,10 +1067,10 @@ void ONNXUnsqueezeOp::inferShapes() {
if (std::find(axes.begin(), axes.end(), axis) == axes.end()) if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis); axes.emplace_back(axis);
else else
emitError("Duplicated axes."); emitError("Duplicated axes");
} }
} else { } else {
emitError("Axes attribute is required."); emitError("Axes attribute is required");
} }
SmallVector<int64_t, 4> dims; SmallVector<int64_t, 4> dims;