diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 3666318..ed18086 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -406,12 +406,12 @@ void ONNXIdentityOp::inferShapes() { void ONNXMatMulOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + if (!A().getType().isa() || + !B().getType().isa()) return; - auto lhsTy = getOperand(0).getType().cast(); - auto rhsTy = getOperand(1).getType().cast(); + auto lhsTy = A().getType().cast(); + auto rhsTy = B().getType().cast(); SmallVector dims; auto lhsShape = lhsTy.getShape(); @@ -419,14 +419,14 @@ void ONNXMatMulOp::inferShapes() { if (lhsShape.size() < 1 && rhsShape.size() < 1) { // Multiplication by scalars is not allowed. - emitError("Multiplication by scalar arguments not allowed."); + emitError("Multiplication by scalar arguments not allowed"); } else if (lhsShape.size() == 1 && rhsShape.size() == 1) { // Special case when both arrays are 1-dimensional and according to // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes // need to be removed after the multiplication but cannot be removed if all // sizes are 1. if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); dims.emplace_back(1); } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) { // If the first argument is 1-D, it is promoted to a matrix by prepending a @@ -441,7 +441,7 @@ void ONNXMatMulOp::inferShapes() { unsigned rhsRank = rhsShape.size(); if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 && lhsShape[0] != rhsShape[rhsRank - 2]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) dims.emplace_back(rhsShape[i]); @@ -459,7 +459,7 @@ void ONNXMatMulOp::inferShapes() { unsigned lhsRank = lhsShape.size(); if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && lhsShape[lhsRank - 1] != rhsShape[0]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i) dims.emplace_back(lhsShape[i]); @@ -473,7 +473,7 @@ void ONNXMatMulOp::inferShapes() { unsigned lhsRank = lhsShape.size(); if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 && lhsShape[lhsRank - 1] != rhsShape[0]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i) dims.emplace_back(lhsShape[i]); @@ -487,7 +487,7 @@ void ONNXMatMulOp::inferShapes() { unsigned rhsRank = rhsShape.size(); if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 && lhsShape[1] != rhsShape[rhsRank - 2]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) dims.emplace_back(rhsShape[i]); @@ -503,7 +503,7 @@ void ONNXMatMulOp::inferShapes() { unsigned rhsRank = rhsShape.size(); if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 && lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2]) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); // Check and perform broadcasting for the shapes. SmallVector lhsBcastShape; @@ -513,7 +513,7 @@ void ONNXMatMulOp::inferShapes() { for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i) rhsBcastShape.emplace_back(rhsShape[i]); if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims)) - emitError("Broadcasted dimensions are incompatible."); + emitError("Broadcasted dimensions are incompatible"); dims.emplace_back(lhsShape[lhsRank - 2]); dims.emplace_back(rhsShape[rhsRank - 1]); @@ -528,7 +528,7 @@ void ONNXMatMulOp::inferShapes() { // Check legality of matrix multiplication. if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim) - emitError("Attempt to multiply incompatible matrices."); + emitError("Attempt to multiply incompatible matrices"); if (rhsShape.size() > 1) dims.emplace_back(rhsShape[1]); @@ -542,14 +542,14 @@ void ONNXMatMulOp::inferShapes() { // Gemm void ONNXGemmOp::inferShapes() { - bool hasBias = !getOperand(2).getType().isa(); + bool hasBias = !C().getType().isa(); // Cannot infer shape if no shape exists. - if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa() || - (hasBias && !getOperand(2).getType().isa())) + if (!A().getType().isa() || + !B().getType().isa() || + (hasBias && !C().getType().isa())) return; - auto lhsTy = getOperand(0).getType().cast(); - auto rhsTy = getOperand(1).getType().cast(); + auto lhsTy = A().getType().cast(); + auto rhsTy = B().getType().cast(); int64_t M, N, K_A, K_B; M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1]; @@ -558,12 +558,12 @@ void ONNXGemmOp::inferShapes() { K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1]; if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) { - emitError("Tensor shapes mismatched."); + emitError("Tensor shapes mismatched"); } if (hasBias) { // Check whether bias is unidirectional broadcasting or not. - auto biasTy = getOperand(2).getType().cast(); + auto biasTy = C().getType().cast(); auto shape = biasTy.getShape(); int rank = shape.size(); if ((rank > 2) || @@ -571,7 +571,7 @@ void ONNXGemmOp::inferShapes() { N != shape[rank - 1] && shape[rank - 1] != 1) || (rank == 2 && shape[rank - 2] != -1 && M != -1 && M != shape[rank - 2] && shape[rank - 2] != 1)) { - emitError("Bias shape mismatched."); + emitError("Bias shape mismatched"); } } @@ -584,50 +584,50 @@ void ONNXGemmOp::inferShapes() { /// BatchNormalizationTestMode void ONNXBatchNormalizationTestModeOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa() || - !getOperand(2).getType().isa() || - !getOperand(3).getType().isa() || - !getOperand(4).getType().isa()) + if (!X().getType().isa() || + !scale().getType().isa() || + !B().getType().isa() || + !mean().getType().isa() || + !var().getType().isa()) return; - auto input = getOperand(0).getType().cast(); - auto scale = getOperand(1).getType().cast(); - auto bias = getOperand(2).getType().cast(); - auto mean = getOperand(3).getType().cast(); - auto variance = getOperand(4).getType().cast(); + auto inputTensorTy = X().getType().cast(); + auto scaleTensorTy = scale().getType().cast(); + auto biasTensorTy = B().getType().cast(); + auto meanTensorTy = mean().getType().cast(); + auto varianceTensorTy = var().getType().cast(); // Check whether the shapes of scale, bias, mean and variance are valid. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. // In case of N, C is assumed to be 1. // Shapes of scale, bias, mean and variance must be C. int64_t c = -1; - if (input.getShape().size() == 1) { + if (inputTensorTy.getShape().size() == 1) { c = 1; - } else if (input.getShape().size() > 2) { - c = (input.getShape()[1] != -1) ? input.getShape()[1] : -1; + } else if (inputTensorTy.getShape().size() > 2) { + c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1; } else { - emitError("Wrong rank for the input."); + emitError("Wrong rank for the input"); } if (c != -1) { - auto s = scale.getShape(); - auto b = bias.getShape(); - auto m = mean.getShape(); - auto v = variance.getShape(); + auto s = scaleTensorTy.getShape(); + auto b = biasTensorTy.getShape(); + auto m = meanTensorTy.getShape(); + auto v = varianceTensorTy.getShape(); if ((s.size() != 1) || (s[0] != -1 && s[0] != c)) - emitError("Wrong rank for the scale."); + emitError("Wrong rank for the scale"); if ((b.size() != 1) || (b[0] != -1 && b[0] != c)) - emitError("Wrong rank for the bias."); + emitError("Wrong rank for the bias"); if ((m.size() != 1) || (m[0] != -1 && m[0] != c)) - emitError("Wrong rank for the mean."); + emitError("Wrong rank for the mean"); if ((v.size() != 1) || (v[0] != -1 && v[0] != c)) - emitError("Wrong rank for the variance."); + emitError("Wrong rank for the variance"); } // The output tensor of the same shape as the input. - getResult().setType(getOperand(0).getType()); + getResult().setType(X().getType()); } // TODO: @@ -640,21 +640,21 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() { void ONNXReshapeOp::inferShapes() { // Cannot infer shape if no shape tensor is specified. - if (!getOperand(1).getType().isa()) - emitError("Shape tensor not ranked."); + if (!shape().getType().isa()) + emitError("Shape tensor not ranked"); - auto inputTensorTy = getOperand(0).getType().cast(); - auto shapeTensorTy = getOperand(1).getType().cast(); + auto inputTensorTy = data().getType().cast(); + auto shapeTensorTy = shape().getType().cast(); // Only rank 1 shape tensors are supported. if (shapeTensorTy.getShape().size() != 1) - emitError("Shape tensor must have rank one."); + emitError("Shape tensor must have rank one"); int64_t outputRank = shapeTensorTy.getShape()[0]; // Shape tensor must have constant shape. if (outputRank < 0) - emitError("Shape tensor must have constant shape."); + emitError("Shape tensor must have constant shape"); SmallVector dims; for (int i = 0; i < outputRank; ++i) @@ -670,12 +670,12 @@ void ONNXReshapeOp::inferShapes() { void ONNXTransposeOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!getOperand().getType().isa()) + if (!data().getType().isa()) return; // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - auto arrayTy = getOperand().getType().cast(); + auto arrayTy = data().getType().cast(); SmallVector dims; auto permutation = ONNXTransposeOp::permAttr(); if (permutation) { @@ -697,7 +697,7 @@ void ONNXTransposeOp::inferShapes() { void ONNXReduceMaxOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked."); + emitError("Shape tensor not ranked"); return; } @@ -711,7 +711,7 @@ void ONNXReduceMaxOp::inferShapes() { void ONNXReduceMinOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked."); + emitError("Shape tensor not ranked"); return; } @@ -725,7 +725,7 @@ void ONNXReduceMinOp::inferShapes() { void ONNXReduceProdOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked."); + emitError("Shape tensor not ranked"); return; } @@ -739,7 +739,7 @@ void ONNXReduceProdOp::inferShapes() { void ONNXReduceSumOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked."); + emitError("Shape tensor not ranked"); return; } @@ -758,22 +758,22 @@ void ONNXConvNoBiasOp::inferShapes() { // W: (M x C/group x k1 x k2 x ... x kn) // Cannot infer shape if no shape exists. - if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + if (!X().getType().isa() || + !W().getType().isa()) return; - auto dataTy = getOperand(0).getType().cast(); - auto weightTy = getOperand(1).getType().cast(); + auto dataTy = X().getType().cast(); + auto weightTy = W().getType().cast(); auto dataShape = dataTy.getShape(); auto weightShape = weightTy.getShape(); // Lowest supported convolution is a one dimensional convolution. if (dataShape.size() < 3) - emitError("Data input shape must be at least (NxCxD1)."); + emitError("Data input shape must be at least (NxCxD1)"); // Check that shape of weight and data have same length. if (dataShape.size() != weightShape.size()) - emitError("Weight size not compatible with data size."); + emitError("Weight size not compatible with data size"); // Required attribute auto_pad defaults to NOTSET. auto autoPad = auto_pad(); @@ -782,7 +782,7 @@ void ONNXConvNoBiasOp::inferShapes() { ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue(); // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. if (dataShape[1] != (weightShape[1] * group)) - emitError("Channel dimension mismatch."); + emitError("Channel dimension mismatch"); // Note: the value of the group attribut only impacts the way the // computation is carried out and not the actual output size. @@ -812,11 +812,10 @@ void ONNXConvNoBiasOp::inferShapes() { // argument. SmallVector kernelDims; if (auto kernelShape = kernel_shapeAttr()) { - if (kernelShape.getValue().size() != nDims) - emitError("kernel_shape length incompatible with spatial dimensions."); + if (ArrayAttrSize(kernelShape) != nDims) + emitError("kernel_shape length incompatible with spatial dimensions"); for (int i = 0; i < nDims; ++i) - kernelDims.emplace_back( - (kernelShape.getValue()[i]).cast().getInt()); + kernelDims.emplace_back(ArrayAttrIntVal(kernelShape, i)); } else { for (int i = 0; i < nDims; ++i) kernelDims.emplace_back(weightShape[i + 2]); @@ -834,13 +833,11 @@ void ONNXConvNoBiasOp::inferShapes() { // From a dimensionality perspective the kernel size becomes the dilated // kernel size. if (auto dilations = dilationsAttr()) { - if (dilations.getValue().size() != nDims) - emitError("dilations length incompatible with spatial dimensions."); + if (ArrayAttrSize(dilations) != nDims) + emitError("dilations length incompatible with spatial dimensions"); for (int i = 0; i < nDims; ++i) kernelDims[i] = - (kernelDims[i] + 1) * - (dilations.getValue()[i]).cast().getInt() - - 1; + (kernelDims[i] + 1) * ArrayAttrIntVal(dilations, i) - 1; } // Subtract kernel dimensions from input data dimensions. @@ -853,16 +850,14 @@ void ONNXConvNoBiasOp::inferShapes() { // present then pads is considered to be all zeros (no padding). if (auto pads = padsAttr()) { // pads consists of two entries for each spatial axis. - if (pads.getValue().size() != 2 * nDims) - emitError("pads size is not twice the spatial size."); + if (ArrayAttrSize(pads) != 2 * nDims) + emitError("pads size is not twice the spatial size"); for (int i = 0; i < nDims; ++i) { // Padding for beginning of axis. - int32_t p = (pads.getValue()[i]).cast().getInt(); - outSpatialDims[i] += p; + outSpatialDims[i] += ArrayAttrIntVal(pads, i); // Padding for end of axis. - p = (pads.getValue()[i + nDims]).cast().getInt(); - outSpatialDims[i] += p; + outSpatialDims[i] += ArrayAttrIntVal(pads, i + nDims); } } } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") { @@ -878,15 +873,15 @@ void ONNXConvNoBiasOp::inferShapes() { } else if (autoPad == "VALID") { // No padding } else { - emitError("Unexpected attribute value for auto_pad."); + emitError("Unexpected attribute value for auto_pad"); } // Strides if (auto strides = ONNXConvNoBiasOp::stridesAttr()) { - if (strides.getValue().size() != nDims) - emitError("strides length incompatible with spatial dimensions."); + if (ArrayAttrSize(strides) != nDims) + emitError("strides length incompatible with spatial dimensions"); for (int i = 0; i < nDims; ++i) { - int64_t stride = strides.getValue()[i].cast().getInt(); + int64_t stride = ArrayAttrIntVal(strides, i); outSpatialDims[i] = floor(outSpatialDims[i] / stride); } } @@ -1013,7 +1008,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { } } } else if (autoPad != "VALID") { - emitError("auto_pad of unknown / unsupported value."); + emitError("auto_pad of unknown / unsupported value"); } // Set pads values in attributes. { @@ -1044,7 +1039,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { } yShape[kernelOffset + i] = res; } - auto arrayTy = getOperand().getType().cast(); + auto arrayTy = X().getType().cast(); getResult().setType(RankedTensorType::get(yShape, arrayTy.getElementType())); } @@ -1053,10 +1048,10 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { // Unsqueeze void ONNXUnsqueezeOp::inferShapes() { - if (!getOperand().getType().isa()) + if (!data().getType().isa()) return; - auto operandTy = getOperand().getType().cast(); + auto operandTy = data().getType().cast(); int inRank = operandTy.getRank(); ArrayAttr axisAttrs = axesAttr(); @@ -1072,10 +1067,10 @@ void ONNXUnsqueezeOp::inferShapes() { if (std::find(axes.begin(), axes.end(), axis) == axes.end()) axes.emplace_back(axis); else - emitError("Duplicated axes."); + emitError("Duplicated axes"); } } else { - emitError("Axes attribute is required."); + emitError("Axes attribute is required"); } SmallVector dims;