use input/output operation names, use helper for attribute function and int values (#106)
This commit is contained in:
parent 3b1c29c078
commit 3a88361b17
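The helpers named in the title can be read off the substitutions below: every attrbute-size expression of the form someAttr.getValue().size() becomes ArrayAttrSize(someAttr), and every (someAttr.getValue()[i]).cast<IntegerAttr>().getInt() becomes ArrayAttrIntVal(someAttr, i). A minimal sketch of what the two helpers presumably look like (their definitions are not part of this diff, so names and signedness are assumed):

    // Assumed from the substitutions in this commit; real definitions may differ.
    static size_t ArrayAttrSize(ArrayAttr a) { return a.getValue().size(); }

    static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
      return (a.getValue()[i]).cast<IntegerAttr>().getInt();
    }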
@@ -406,12 +406,12 @@ void ONNXIdentityOp::inferShapes() {
 void ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+  if (!A().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>())
     return;

-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto lhsTy = A().getType().cast<RankedTensorType>();
+  auto rhsTy = B().getType().cast<RankedTensorType>();

   SmallVector<int64_t, 2> dims;
   auto lhsShape = lhsTy.getShape();
   auto rhsShape = rhsTy.getShape();
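The A()/B() calls are the named accessors that MLIR's ODS generates from the operand names declared in the ONNX op definitions; replacing positional getOperand(N) with them is what the title means by "use input/output operation names". Roughly, the generated op class exposes (a sketch only, not the verbatim generated code):

    // One accessor per named operand, generated by ODS from the op definition.
    class ONNXMatMulOp /* ... */ {
      Value A();   // operand 0
      Value B();   // operand 1
    };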
@@ -419,14 +419,14 @@ void ONNXMatMulOp::inferShapes() {

   if (lhsShape.size() < 1 && rhsShape.size() < 1) {
     // Multiplication by scalars is not allowed.
-    emitError("Multiplication by scalar arguments not allowed.");
+    emitError("Multiplication by scalar arguments not allowed");
   } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
     // Special case when both arrays are 1-dimensional and according to
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
     // sizes are 1.
     if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");
     dims.emplace_back(1);
   } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
     // If the first argument is 1-D, it is promoted to a matrix by prepending a
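For the 1-D by 1-D case above, the numpy rule the comment cites treats a pair of length-N vectors as 1xN times Nx1, yielding 1x1, after which the helper dimensions are dropped; that is why a single dimension of 1 is emitted. Shapes only, for illustration:

    // lhs [4], rhs [4]  ->  (1x4)(4x1) = (1x1)  ->  result shape [1]
    // lhs [4], rhs [3]  ->  error, unless either size is dynamic (-1)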
@@ -441,7 +441,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[0] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
@@ -459,7 +459,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
       dims.emplace_back(lhsShape[i]);
@@ -473,7 +473,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
       dims.emplace_back(lhsShape[i]);
@@ -487,7 +487,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
@@ -503,7 +503,7 @@ void ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     // Check and perform broadcasting for the shapes.
     SmallVector<int64_t, 2> lhsBcastShape;
@@ -513,7 +513,7 @@ void ONNXMatMulOp::inferShapes() {
     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       rhsBcastShape.emplace_back(rhsShape[i]);
     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
-      emitError("Broadcasted dimensions are incompatible.");
+      emitError("Broadcasted dimensions are incompatible");

     dims.emplace_back(lhsShape[lhsRank - 2]);
     dims.emplace_back(rhsShape[rhsRank - 1]);
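getBroadcastedShape here applies numpy-style broadcasting to the stacked (batch) dimensions only; the two trailing matrix dimensions are matched separately as in the earlier branches. An illustrative case:

    // lhs [5, 1, 4, 6], rhs [5, 3, 6, 7]
    // batch dims: [5, 1] vs [5, 3]  ->  broadcast to [5, 3]
    // matrix dims: (4x6)(6x7)       ->  4x7
    // result: [5, 3, 4, 7]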
@@ -528,7 +528,7 @@ void ONNXMatMulOp::inferShapes() {

     // Check legality of matrix multiplication.
     if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
-      emitError("Attempt to multiply incompatible matrices.");
+      emitError("Attempt to multiply incompatible matrices");

     if (rhsShape.size() > 1)
       dims.emplace_back(rhsShape[1]);
@@ -542,14 +542,14 @@ void ONNXMatMulOp::inferShapes() {
 // Gemm

 void ONNXGemmOp::inferShapes() {
-  bool hasBias = !getOperand(2).getType().isa<NoneType>();
+  bool hasBias = !C().getType().isa<NoneType>();
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>() ||
-      (hasBias && !getOperand(2).getType().isa<RankedTensorType>()))
+  if (!A().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>() ||
+      (hasBias && !C().getType().isa<RankedTensorType>()))
     return;
-  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto lhsTy = A().getType().cast<RankedTensorType>();
+  auto rhsTy = B().getType().cast<RankedTensorType>();

   int64_t M, N, K_A, K_B;
   M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
@@ -558,12 +558,12 @@ void ONNXGemmOp::inferShapes() {
   K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

   if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
-    emitError("Tensor shapes mismatched.");
+    emitError("Tensor shapes mismatched");
   }

   if (hasBias) {
     // Check whether bias is unidirectional broadcasting or not.
-    auto biasTy = getOperand(2).getType().cast<RankedTensorType>();
+    auto biasTy = C().getType().cast<RankedTensorType>();
     auto shape = biasTy.getShape();
     int rank = shape.size();
     if ((rank > 2) ||
@@ -571,7 +571,7 @@ void ONNXGemmOp::inferShapes() {
          N != shape[rank - 1] && shape[rank - 1] != 1) ||
         (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
          M != shape[rank - 2] && shape[rank - 2] != 1)) {
-      emitError("Bias shape mismatched.");
+      emitError("Bias shape mismatched");
     }
   }

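For reference, ONNX Gemm computes Y = alpha * A' * B' + beta * C, where A' is A transposed when transA is nonzero (and likewise B'), so the M/K_A and K_B/N extractions above simply pick the pre- or post-transpose axis, and the bias check encodes unidirectional broadcasting of C to MxN:

    // A [K, M], transA = 1  ->  M = shape[1], K_A = shape[0]
    // B [K, N], transB = 0  ->  K_B = shape[0], N = shape[1]
    // Y is MxN; C of shape [N], [1, N], [M, 1], or [M, N] is accepted.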
@@ -584,50 +584,50 @@ void ONNXGemmOp::inferShapes() {
 /// BatchNormalizationTestMode
 void ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>() ||
-      !getOperand(2).getType().isa<RankedTensorType>() ||
-      !getOperand(3).getType().isa<RankedTensorType>() ||
-      !getOperand(4).getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>() ||
+      !scale().getType().isa<RankedTensorType>() ||
+      !B().getType().isa<RankedTensorType>() ||
+      !mean().getType().isa<RankedTensorType>() ||
+      !var().getType().isa<RankedTensorType>())
     return;

-  auto input = getOperand(0).getType().cast<RankedTensorType>();
-  auto scale = getOperand(1).getType().cast<RankedTensorType>();
-  auto bias = getOperand(2).getType().cast<RankedTensorType>();
-  auto mean = getOperand(3).getType().cast<RankedTensorType>();
-  auto variance = getOperand(4).getType().cast<RankedTensorType>();
+  auto inputTensorTy = X().getType().cast<RankedTensorType>();
+  auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
+  auto biasTensorTy = B().getType().cast<RankedTensorType>();
+  auto meanTensorTy = mean().getType().cast<RankedTensorType>();
+  auto varianceTensorTy = var().getType().cast<RankedTensorType>();

   // Check whether the shapes of scale, bias, mean and variance are valid.
   // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
   // In case of N, C is assumed to be 1.
   // Shapes of scale, bias, mean and variance must be C.
   int64_t c = -1;
-  if (input.getShape().size() == 1) {
+  if (inputTensorTy.getShape().size() == 1) {
     c = 1;
-  } else if (input.getShape().size() > 2) {
-    c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1;
+  } else if (inputTensorTy.getShape().size() > 2) {
+    c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
   } else {
-    emitError("Wrong rank for the input.");
+    emitError("Wrong rank for the input");
   }

   if (c != -1) {
-    auto s = scale.getShape();
-    auto b = bias.getShape();
-    auto m = mean.getShape();
-    auto v = variance.getShape();
+    auto s = scaleTensorTy.getShape();
+    auto b = biasTensorTy.getShape();
+    auto m = meanTensorTy.getShape();
+    auto v = varianceTensorTy.getShape();

     if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
-      emitError("Wrong rank for the scale.");
+      emitError("Wrong rank for the scale");
     if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
-      emitError("Wrong rank for the bias.");
+      emitError("Wrong rank for the bias");
     if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
-      emitError("Wrong rank for the mean.");
+      emitError("Wrong rank for the mean");
     if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
-      emitError("Wrong rank for the variance.");
+      emitError("Wrong rank for the variance");
   }

   // The output tensor of the same shape as the input.
-  getResult().setType(getOperand(0).getType());
+  getResult().setType(X().getType());
 }

 // TODO:
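A shapes-only example of the constraint enforced above: for an input X of shape NxCxD1x...xDn, each of scale, B, mean, and var must be 1-D of length C (a dynamic size of -1 also passes):

    // X [2, 64, 28, 28]  ->  c = 64
    // scale, B, mean, var: each must have shape [64] (or [-1])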
@@ -640,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {

 void ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
-  if (!getOperand(1).getType().isa<RankedTensorType>())
-    emitError("Shape tensor not ranked.");
+  if (!shape().getType().isa<RankedTensorType>())
+    emitError("Shape tensor not ranked");

-  auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto inputTensorTy = data().getType().cast<RankedTensorType>();
+  auto shapeTensorTy = shape().getType().cast<RankedTensorType>();

   // Only rank 1 shape tensors are supported.
   if (shapeTensorTy.getShape().size() != 1)
-    emitError("Shape tensor must have rank one.");
+    emitError("Shape tensor must have rank one");

   int64_t outputRank = shapeTensorTy.getShape()[0];

   // Shape tensor must have constant shape.
   if (outputRank < 0)
-    emitError("Shape tensor must have constant shape.");
+    emitError("Shape tensor must have constant shape");

   SmallVector<int64_t, 2> dims;
   for (int i = 0; i < outputRank; ++i)
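Only the output rank is fixed at this point: a shape operand of static type tensor<3xi64> gives outputRank = 3, and the dims vector built in the remainder of the function (not shown in this hunk) presumably stays dynamic for any entry not known to be a constant:

    // shape operand: tensor<3xi64>  ->  rank-3 result,
    // e.g. tensor<?x?x?xf32> until constant dims can be filled in.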
@@ -670,12 +670,12 @@ void ONNXReshapeOp::inferShapes() {

 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!getOperand().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>())
     return;

   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = data().getType().cast<RankedTensorType>();
   SmallVector<int64_t, 2> dims;
   auto permutation = ONNXTransposeOp::permAttr();
   if (permutation) {
@@ -697,7 +697,7 @@ void ONNXTransposeOp::inferShapes() {

 void ONNXReduceMaxOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
@@ -711,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() {

 void ONNXReduceMinOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
@@ -725,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() {

 void ONNXReduceProdOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
@@ -739,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() {

 void ONNXReduceSumOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked.");
+    emitError("Shape tensor not ranked");
     return;
   }
@@ -758,22 +758,22 @@ void ONNXConvNoBiasOp::inferShapes() {
   // W: (M x C/group x k1 x k2 x ... x kn)

   // Cannot infer shape if no shape exists.
-  if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>() ||
+      !W().getType().isa<RankedTensorType>())
     return;

-  auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
-  auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto dataTy = X().getType().cast<RankedTensorType>();
+  auto weightTy = W().getType().cast<RankedTensorType>();
   auto dataShape = dataTy.getShape();
   auto weightShape = weightTy.getShape();

   // Lowest supported convolution is a one dimensional convolution.
   if (dataShape.size() < 3)
-    emitError("Data input shape must be at least (NxCxD1).");
+    emitError("Data input shape must be at least (NxCxD1)");

   // Check that shape of weight and data have same length.
   if (dataShape.size() != weightShape.size())
-    emitError("Weight size not compatible with data size.");
+    emitError("Weight size not compatible with data size");

   // Required attribute auto_pad defaults to NOTSET.
   auto autoPad = auto_pad();
@@ -782,7 +782,7 @@ void ONNXConvNoBiasOp::inferShapes() {
       ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (dataShape[1] != (weightShape[1] * group))
-    emitError("Channel dimension mismatch.");
+    emitError("Channel dimension mismatch");

   // Note: the value of the group attribute only impacts the way the
   // computation is carried out and not the actual output size.
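A numeric illustration of the grouped-convolution check, given the ONNX weight layout (M x C/group x k1 x ... x kn) quoted at the top of the function:

    // X [1, 64, 56, 56], group = 4  ->  W.shape[1] must be 64 / 4 = 16
    // W [128, 16, 3, 3] passes; W [128, 64, 3, 3] only passes with group = 1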
@@ -812,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() {
   // argument.
   SmallVector<int64_t, 2> kernelDims;
   if (auto kernelShape = kernel_shapeAttr()) {
-    if (kernelShape.getValue().size() != nDims)
-      emitError("kernel_shape length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(kernelShape) != nDims)
+      emitError("kernel_shape length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i)
-      kernelDims.emplace_back(
-          (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
+      kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i));
   } else {
     for (int i = 0; i < nDims; ++i)
       kernelDims.emplace_back(weightShape[i + 2]);
@@ -834,13 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() {
   // From a dimensionality perspective the kernel size becomes the dilated
   // kernel size.
   if (auto dilations = dilationsAttr()) {
-    if (dilations.getValue().size() != nDims)
-      emitError("dilations length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(dilations) != nDims)
+      emitError("dilations length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i)
       kernelDims[i] =
-          (kernelDims[i] + 1) *
-          (dilations.getValue()[i]).cast<IntegerAttr>().getInt() -
-          1;
+          (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1;
   }

   // Subtract kernel dimensions from input data dimensions.
@@ -853,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() {
     // present then pads is considered to be all zeros (no padding).
     if (auto pads = padsAttr()) {
       // pads consists of two entries for each spatial axis.
-      if (pads.getValue().size() != 2 * nDims)
-        emitError("pads size is not twice the spatial size.");
+      if (ArrayAttrSize(pads) != 2 * nDims)
+        emitError("pads size is not twice the spatial size");

       for (int i = 0; i < nDims; ++i) {
         // Padding for beginning of axis.
-        int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
-        outSpatialDims[i] += p;
+        outSpatialDims[i] += ArrayAttrIntVal(pads, i);
         // Padding for end of axis.
-        p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
-        outSpatialDims[i] += p;
+        outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims);
       }
     }
   } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
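In the SAME_UPPER / SAME_LOWER branch (elided from this hunk), the ONNX convention is that each output spatial size equals ceil(input / stride), and the padding needed to achieve that is split between the two sides, with the odd cell going to the end for SAME_UPPER and to the beginning for SAME_LOWER:

    // total_pad = (ceil(in / stride) - 1) * stride + dilated_kernel - in
    // SAME_UPPER: pad_begin = total_pad / 2, pad_end = total_pad - pad_begin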
@@ -878,15 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() {
   } else if (autoPad == "VALID") {
     // No padding
   } else {
-    emitError("Unexpected attribute value for auto_pad.");
+    emitError("Unexpected attribute value for auto_pad");
   }

   // Strides
   if (auto strides = ONNXConvNoBiasOp::stridesAttr()) {
-    if (strides.getValue().size() != nDims)
-      emitError("strides length incompatible with spatial dimensions.");
+    if (ArrayAttrSize(strides) != nDims)
+      emitError("strides length incompatible with spatial dimensions");
     for (int i = 0; i < nDims; ++i) {
-      int64_t stride = strides.getValue()[i].cast<IntegerAttr>().getInt();
+      int64_t stride = ArrayAttrIntVal(strides, i);
       outSpatialDims[i] = floor(outSpatialDims[i] / stride);
     }
   }
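The floor division by the stride is the tail end of the standard ONNX output-size formula, which for a single spatial axis reads:

    // out = floor((in + pad_begin + pad_end
    //              - dilation * (kernel - 1) - 1) / stride) + 1
    // e.g. in = 224, kernel = 7, pads = 3 + 3, dilation = 1, stride = 2:
    //      floor((224 + 6 - 6 - 1) / 2) + 1 = 112

The code reaches the same value by mutating outSpatialDims in steps (add pads, subtract the dilated kernel size, divide by the stride), with the remaining off-by-one bookkeeping apparently handled in lines this diff does not touch.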
@@ -1013,7 +1008,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
       }
     }
   } else if (autoPad != "VALID") {
-    emitError("auto_pad of unknown / unsupported value.");
+    emitError("auto_pad of unknown / unsupported value");
   }
   // Set pads values in attributes.
   {
@@ -1044,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
     }
     yShape[kernelOffset + i] = res;
   }
-  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
+  auto arrayTy = X().getType().cast<RankedTensorType>();
   getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType()));
 }

@@ -1053,10 +1048,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
 // Unsqueeze

 void ONNXUnsqueezeOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>())
     return;

-  auto operandTy = getOperand().getType().cast<RankedTensorType>();
+  auto operandTy = data().getType().cast<RankedTensorType>();
   int inRank = operandTy.getRank();

   ArrayAttr axisAttrs = axesAttr();
@@ -1072,10 +1067,10 @@ void ONNXUnsqueezeOp::inferShapes() {
       if (std::find(axes.begin(), axes.end(), axis) == axes.end())
         axes.emplace_back(axis);
       else
-        emitError("Duplicated axes.");
+        emitError("Duplicated axes");
     }
   } else {
-    emitError("Axes attribute is required.");
+    emitError("Axes attribute is required");
   }

   SmallVector<int64_t, 4> dims;
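To illustrate the axes handling above: Unsqueeze inserts a size-1 dimension at each listed axis, so the result rank is inRank + axes.size(), duplicates are rejected, and the axes attribute is mandatory:

    // data [3, 4], axes = [0, 3]  ->  result [1, 3, 4, 1]
    // axes = [1, 1]               ->  "Duplicated axes"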