diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 1b89893..5474856 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -1045,39 +1045,41 @@ void ONNXMaxPoolSingleOutOp::inferShapes() { //===----------------------------------------------------------------------===// +static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) { + // Cannot infer shape if no shape exists. + if (!data.getType().isa()) + return (Type)NULL; + auto dataTy = data.getType().cast(); + auto dataShape = dataTy.getShape(); + auto dataRank = dataShape.size(); + SmallVector outputShape(dataShape.begin(), dataShape.end()); + if (padsOpt) { + auto padsArray = padsOpt.getValue(); + // Pads consists of two values for each axis of data. + // The two values specify the number of elements padded before and after respectively. + for (int i = 0; i < dataRank; ++i) { + int64_t p1 = (padsArray[2*i]).cast().getInt(); + int64_t p2 = (padsArray[2*i+1]).cast().getInt(); + //Have to non-negative constant + if (p1 < 0 || p2 <0) + return (Type)NULL; + outputShape[i] += p1+p2; + } + + return (RankedTensorType::get(outputShape, dataTy.getElementType())); + } else { + return (Type)NULL; + } +} + // PadConstantPad void ONNXPadConstantPadOp::inferShapes(){ - // Cannot infer shape if no shape exists. - if (!data().getType().isa()) - return; - - // 1) get shape of input "data" - auto dataTy = data().getType().cast(); - auto dataShape = dataTy.getShape(); - auto dataRank = dataShape.size(); - - SmallVector outputShape(dataShape.begin(), dataShape.end()); - auto padsOpt = pads(); - if (padsOpt) { - auto padsArray = padsOpt.getValue(); - // pads consists of two entries for each spatial axis. - if (padsArray.size() != 2 * dataRank) - emitError("pads rank is not twice the spatial rank."); - // fill in the actual values - for (int i = 0; i < dataRank; ++i) { - int64_t p1 = (padsArray[2*i]).cast().getInt(); - if (p1 < 0) - emitError("pads value must be nonnegative."); - int64_t p2 = (padsArray[2*i+1]).cast().getInt(); - if (p2 < 0) - emitError("pads value must be nonnegative."); - outputShape[i] += p1+p2; - } - getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType())); - } else { - emitError("pads attribute is not available."); - } + auto outputType = padShapeInferenceHelper(data(), pads()); + if (outputType) { + getResult().setType(outputType); + } + return; } //===----------------------------------------------------------------------===// @@ -1085,36 +1087,11 @@ void ONNXPadConstantPadOp::inferShapes(){ // PadConstantValuePad void ONNXPadConstantValuePadOp::inferShapes(){ - // Cannot infer shape if no shape exists. - if (!data().getType().isa()) - return; - - // 1) get shape of input "data" - auto dataTy = data().getType().cast(); - auto dataShape = dataTy.getShape(); - auto dataRank = dataShape.size(); - - SmallVector outputShape(dataShape.begin(), dataShape.end()); - auto padsOpt = pads(); - if (padsOpt) { - auto padsArray = padsOpt.getValue(); - // pads consists of two entries for each spatial axis. - if (padsArray.size() != 2 * dataRank) - emitError("pads rank is not twice the spatial rank."); - // fill in the actual values - for (int i = 0; i < dataRank; ++i) { - int64_t p1 = (padsArray[2*i]).cast().getInt(); - if (p1 < 0) - emitError("pads value must be nonnegative."); - int64_t p2 = (padsArray[2*i+1]).cast().getInt(); - if (p2 < 0) - emitError("pads value must be nonnegative."); - outputShape[i] += p1+p2; - } - getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType())); - } else { - emitError("pads attribute is not available."); - } + auto outputType = padShapeInferenceHelper(data(), pads()); + if (outputType) { + getResult().setType(outputType); + } + return; } //===----------------------------------------------------------------------===//