use input/output operation names, use helper for attribute function and int values (#106)

This commit is contained in:
Alexandre Eichenberger 2020-02-25 15:46:11 -05:00 committed by GitHub
parent 3b1c29c078
commit 3a88361b17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 84 additions and 89 deletions

View File

@ -406,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
void ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto lhsShape = lhsTy.getShape();
@ -419,14 +419,14 @@ void ONNXMatMulOp::inferShapes() {
if (lhsShape.size() < 1 && rhsShape.size() < 1) {
// Multiplication by scalars is not allowed.
emitError("Multiplication by scalar arguments not allowed.");
emitError("Multiplication by scalar arguments not allowed");
} else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
dims.emplace_back(1);
} else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
// If the first argument is 1-D, it is promoted to a matrix by prepending a
@ -441,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[0] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
@ -459,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
dims.emplace_back(lhsShape[i]);
@ -473,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned lhsRank = lhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
dims.emplace_back(lhsShape[i]);
@ -487,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
dims.emplace_back(rhsShape[i]);
@ -503,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
unsigned rhsRank = rhsShape.size();
if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
// Check and perform broadcasting for the shapes.
SmallVector<int64_t, 2> lhsBcastShape;
@ -513,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
rhsBcastShape.emplace_back(rhsShape[i]);
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
emitError("Broadcasted dimensions are incompatible.");
emitError("Broadcasted dimensions are incompatible");
dims.emplace_back(lhsShape[lhsRank - 2]);
dims.emplace_back(rhsShape[rhsRank - 1]);
@ -528,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {
// Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices.");
emitError("Attempt to multiply incompatible matrices");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]);
@ -542,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
// Gemm
void ONNXGemmOp::inferShapes() {
bool hasBias = !getOperand(2).getType().isa<NoneType>();
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
(hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>()))
return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
int64_t M, N, K_A, K_B;
M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@ -558,12 +558,12 @@ void ONNXGemmOp::inferShapes() {
K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
emitError("Tensor shapes mismatched.");
emitError("Tensor shapes mismatched");
}
if (hasBias) {
// Check whether bias is unidirectional broadcasting or not.
auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
auto biasTy = C().getType().cast<RankedTensorType>();
auto shape = biasTy.getShape();
int rank = shape.size();
if ((rank > 2) ||
@ -571,7 +571,7 @@ void ONNXGemmOp::inferShapes() {
N != shape[rank - 1] && shape[rank - 1] != 1) ||
(rank == 2 && shape[rank - 2] != -1 && M != -1 &&
M != shape[rank - 2] && shape[rank - 2] != 1)) {
emitError("Bias shape mismatched.");
emitError("Bias shape mismatched");
}
}
@ -584,50 +584,50 @@ void ONNXGemmOp::inferShapes() {
/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>() ||
!getOperand(2).getType().isa<RankedTensorType>() ||
!getOperand(3).getType().isa<RankedTensorType>() ||
!getOperand(4).getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>() ||
!scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>())
return;
auto input = getOperand(0).getType().cast<RankedTensorType>();
auto scale = getOperand(1).getType().cast<RankedTensorType>();
auto bias = getOperand(2).getType().cast<RankedTensorType>();
auto mean = getOperand(3).getType().cast<RankedTensorType>();
auto variance = getOperand(4).getType().cast<RankedTensorType>();
auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
auto biasTensorTy = B().getType().cast<RankedTensorType>();
auto meanTensorTy = mean().getType().cast<RankedTensorType>();
auto varianceTensorTy = var().getType().cast<RankedTensorType>();
// Check whether the shapes of scale, bias, mean and variance are valid.
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
// Shapes of scale, bias, mean and variance must be C.
int64_t c = -1;
if (input.getShape().size() == 1) {
if (inputTensorTy.getShape().size() == 1) {
c = 1;
} else if (input.getShape().size() > 2) {
c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
} else if (inputTensorTy.getShape().size() > 2) {
c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
} else {
emitError("Wrong rank for the input.");
emitError("Wrong rank for the input");
}
if (c != -1) {
auto s = scale.getShape();
auto b = bias.getShape();
auto m = mean.getShape();
auto v = variance.getShape();
auto s = scaleTensorTy.getShape();
auto b = biasTensorTy.getShape();
auto m = meanTensorTy.getShape();
auto v = varianceTensorTy.getShape();
if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
emitError("Wrong rank for the scale.");
emitError("Wrong rank for the scale");
if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
emitError("Wrong rank for the bias.");
emitError("Wrong rank for the bias");
if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
emitError("Wrong rank for the mean.");
emitError("Wrong rank for the mean");
if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
emitError("Wrong rank for the variance.");
emitError("Wrong rank for the variance");
}
// The output tensor of the same shape as the input.
getResult().setType(getOperand(0).getType());
getResult().setType(X().getType());
}
// TODO:
@ -640,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
void ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!getOperand(1).getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked.");
if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked");
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
// Only rank 1 shape tensors are supported.
if (shapeTensorTy.getShape().size() != 1)
emitError("Shape tensor must have rank one.");
emitError("Shape tensor must have rank one");
int64_t outputRank = shapeTensorTy.getShape()[0];
// Shape tensor must have constant shape.
if (outputRank < 0)
emitError("Shape tensor must have constant shape.");
emitError("Shape tensor must have constant shape");
SmallVector<int64_t, 2> dims;
for (int i = 0; i < outputRank; ++i)
@ -670,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>())
return;
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = data().getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims;
auto permutation = ONNXTransposeOp::permAttr();
if (permutation) {
@ -697,7 +697,7 @@ void ONNXTransposeOp::inferShapes() {
void ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -711,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {
void ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -725,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {
void ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -739,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {
void ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked.");
emitError("Shape tensor not ranked");
return;
}
@ -758,22 +758,22 @@ void ONNXConvNoBiasOp::inferShapes() {
// W: (M x C/group x k1 x k2 x ... x kn)
// Cannot infer shape if no shape exists.
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>())
return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
auto dataTy = X().getType().cast<RankedTensorType>();
auto weightTy = W().getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape();
// Lowest supported convolution is a one dimensional convolution.
if (dataShape.size() < 3)
emitError("Data input shape must be at least (NxCxD1).");
emitError("Data input shape must be at least (NxCxD1)");
// Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size())
emitError("Weight size not compatible with data size.");
emitError("Weight size not compatible with data size");
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = auto_pad();
@ -782,7 +782,7 @@ void ONNXConvNoBiasOp::inferShapes() {
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch.");
emitError("Channel dimension mismatch");
// Note: the value of the group attribut only impacts the way the
// computation is carried out and not the actual output size.
@ -812,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument.
SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = kernel_shapeAttr()) {
if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions.");
if (ArrayAttrSize(kernelShape) != nDims)
emitError("kernel_shape length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
} else {
for (int i = 0; i < nDims; ++i)
kernelDims.emplace_back(weightShape[i + 2]);
@ -834,13 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
// From a dimensionality perspective the kernel size becomes the dilated
// kernel size.
if (auto dilations = dilationsAttr()) {
if (dilations.getValue().size() != nDims)
emitError("dilations length incompatible with spatial dimensions.");
if (ArrayAttrSize(dilations) != nDims)
emitError("dilations length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i)
kernelDims[i] =
(kernelDims[i] + 1) *
(dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
1;
(kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
}
// Subtract kernel dimensions from input data dimensions.
@ -853,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
// present then pads is considered to be all zeros (no padding).
if (auto pads = padsAttr()) {
// pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims)
emitError("pads size is not twice the spatial size.");
if (ArrayAttrSize(pads) != 2 * nDims)
emitError("pads size is not twice the spatial size");
for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
outSpatialDims[i] += ArrayAttrIntVal(pads, i);
// Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
outSpatialDims[i] += p;
outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
}
}
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
@ -878,15 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
} else if (autoPad == "VALID") {
// No padding
} else {
emitError("Unexpected attribute value for auto_pad.");
emitError("Unexpected attribute value for auto_pad");
}
// Strides
if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
if (ArrayAttrSize(strides) != nDims)
emitError("strides length incompatible with spatial dimensions");
for (int i = 0; i < nDims; ++i) {
int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
int64_t stride = ArrayAttrIntVal(strides, i);
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
}
}
@ -1013,7 +1008,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
}
} else if (autoPad != "VALID") {
emitError("auto_pad of unknown / unsupported value.");
emitError("auto_pad of unknown / unsupported value");
}
// Set pads values in attributes.
{
@ -1044,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
}
yShape[kernelOffset + i] = res;
}
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
auto arrayTy = X().getType().cast<RankedTensorType>();
getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
}
@ -1053,10 +1048,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
// Unsqueeze
void ONNXUnsqueezeOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>())
return;
auto operandTy = getOperand().getType().cast<RankedTensorType>();
auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();
ArrayAttr axisAttrs = axesAttr();
@ -1072,10 +1067,10 @@ void ONNXUnsqueezeOp::inferShapes() {
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
else
emitError("Duplicated axes.");
emitError("Duplicated axes");
}
} else {
emitError("Axes attribute is required.");
emitError("Axes attribute is required");
}
SmallVector<int64_t, 4> dims;