put the common code into a helper function
This commit is contained in:
parent
4079ee1f26
commit
3abbf1c0e9
|
@ -1045,39 +1045,41 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data.getType().isa<RankedTensorType>())
|
||||
return (Type)NULL;
|
||||
auto dataTy = data.getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
auto dataRank = dataShape.size();
|
||||
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
||||
if (padsOpt) {
|
||||
auto padsArray = padsOpt.getValue();
|
||||
// Pads consists of two values for each axis of data.
|
||||
// The two values specify the number of elements padded before and after respectively.
|
||||
for (int i = 0; i < dataRank; ++i) {
|
||||
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
||||
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
||||
//Have to non-negative constant
|
||||
if (p1 < 0 || p2 <0)
|
||||
return (Type)NULL;
|
||||
outputShape[i] += p1+p2;
|
||||
}
|
||||
|
||||
return (RankedTensorType::get(outputShape, dataTy.getElementType()));
|
||||
} else {
|
||||
return (Type)NULL;
|
||||
}
|
||||
}
|
||||
|
||||
// PadConstantPad
|
||||
|
||||
void ONNXPadConstantPadOp::inferShapes(){
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
// 1) get shape of input "data"
|
||||
auto dataTy = data().getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
auto dataRank = dataShape.size();
|
||||
|
||||
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
||||
auto padsOpt = pads();
|
||||
if (padsOpt) {
|
||||
auto padsArray = padsOpt.getValue();
|
||||
// pads consists of two entries for each spatial axis.
|
||||
if (padsArray.size() != 2 * dataRank)
|
||||
emitError("pads rank is not twice the spatial rank.");
|
||||
// fill in the actual values
|
||||
for (int i = 0; i < dataRank; ++i) {
|
||||
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
||||
if (p1 < 0)
|
||||
emitError("pads value must be nonnegative.");
|
||||
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
||||
if (p2 < 0)
|
||||
emitError("pads value must be nonnegative.");
|
||||
outputShape[i] += p1+p2;
|
||||
}
|
||||
getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
|
||||
} else {
|
||||
emitError("pads attribute is not available.");
|
||||
}
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
if (outputType) {
|
||||
getResult().setType(outputType);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1085,36 +1087,11 @@ void ONNXPadConstantPadOp::inferShapes(){
|
|||
// PadConstantValuePad
|
||||
|
||||
void ONNXPadConstantValuePadOp::inferShapes(){
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
// 1) get shape of input "data"
|
||||
auto dataTy = data().getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
auto dataRank = dataShape.size();
|
||||
|
||||
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
||||
auto padsOpt = pads();
|
||||
if (padsOpt) {
|
||||
auto padsArray = padsOpt.getValue();
|
||||
// pads consists of two entries for each spatial axis.
|
||||
if (padsArray.size() != 2 * dataRank)
|
||||
emitError("pads rank is not twice the spatial rank.");
|
||||
// fill in the actual values
|
||||
for (int i = 0; i < dataRank; ++i) {
|
||||
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
||||
if (p1 < 0)
|
||||
emitError("pads value must be nonnegative.");
|
||||
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
||||
if (p2 < 0)
|
||||
emitError("pads value must be nonnegative.");
|
||||
outputShape[i] += p1+p2;
|
||||
}
|
||||
getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
|
||||
} else {
|
||||
emitError("pads attribute is not available.");
|
||||
}
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
if (outputType) {
|
||||
getResult().setType(outputType);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue