put the common code into a helper function
This commit is contained in:
parent
4079ee1f26
commit
3abbf1c0e9
|
@ -1045,39 +1045,41 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
||||||
|
// Cannot infer shape if no shape exists.
|
||||||
|
if (!data.getType().isa<RankedTensorType>())
|
||||||
|
return (Type)NULL;
|
||||||
|
auto dataTy = data.getType().cast<RankedTensorType>();
|
||||||
|
auto dataShape = dataTy.getShape();
|
||||||
|
auto dataRank = dataShape.size();
|
||||||
|
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
||||||
|
if (padsOpt) {
|
||||||
|
auto padsArray = padsOpt.getValue();
|
||||||
|
// Pads consists of two values for each axis of data.
|
||||||
|
// The two values specify the number of elements padded before and after respectively.
|
||||||
|
for (int i = 0; i < dataRank; ++i) {
|
||||||
|
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
||||||
|
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
||||||
|
//Have to non-negative constant
|
||||||
|
if (p1 < 0 || p2 <0)
|
||||||
|
return (Type)NULL;
|
||||||
|
outputShape[i] += p1+p2;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (RankedTensorType::get(outputShape, dataTy.getElementType()));
|
||||||
|
} else {
|
||||||
|
return (Type)NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// PadConstantPad
|
// PadConstantPad
|
||||||
|
|
||||||
void ONNXPadConstantPadOp::inferShapes(){
|
void ONNXPadConstantPadOp::inferShapes(){
|
||||||
// Cannot infer shape if no shape exists.
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
if (!data().getType().isa<RankedTensorType>())
|
if (outputType) {
|
||||||
return;
|
getResult().setType(outputType);
|
||||||
|
}
|
||||||
// 1) get shape of input "data"
|
return;
|
||||||
auto dataTy = data().getType().cast<RankedTensorType>();
|
|
||||||
auto dataShape = dataTy.getShape();
|
|
||||||
auto dataRank = dataShape.size();
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
|
||||||
auto padsOpt = pads();
|
|
||||||
if (padsOpt) {
|
|
||||||
auto padsArray = padsOpt.getValue();
|
|
||||||
// pads consists of two entries for each spatial axis.
|
|
||||||
if (padsArray.size() != 2 * dataRank)
|
|
||||||
emitError("pads rank is not twice the spatial rank.");
|
|
||||||
// fill in the actual values
|
|
||||||
for (int i = 0; i < dataRank; ++i) {
|
|
||||||
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
|
||||||
if (p1 < 0)
|
|
||||||
emitError("pads value must be nonnegative.");
|
|
||||||
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
|
||||||
if (p2 < 0)
|
|
||||||
emitError("pads value must be nonnegative.");
|
|
||||||
outputShape[i] += p1+p2;
|
|
||||||
}
|
|
||||||
getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
|
|
||||||
} else {
|
|
||||||
emitError("pads attribute is not available.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1085,36 +1087,11 @@ void ONNXPadConstantPadOp::inferShapes(){
|
||||||
// PadConstantValuePad
|
// PadConstantValuePad
|
||||||
|
|
||||||
void ONNXPadConstantValuePadOp::inferShapes(){
|
void ONNXPadConstantValuePadOp::inferShapes(){
|
||||||
// Cannot infer shape if no shape exists.
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
if (!data().getType().isa<RankedTensorType>())
|
if (outputType) {
|
||||||
return;
|
getResult().setType(outputType);
|
||||||
|
}
|
||||||
// 1) get shape of input "data"
|
return;
|
||||||
auto dataTy = data().getType().cast<RankedTensorType>();
|
|
||||||
auto dataShape = dataTy.getShape();
|
|
||||||
auto dataRank = dataShape.size();
|
|
||||||
|
|
||||||
SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
|
|
||||||
auto padsOpt = pads();
|
|
||||||
if (padsOpt) {
|
|
||||||
auto padsArray = padsOpt.getValue();
|
|
||||||
// pads consists of two entries for each spatial axis.
|
|
||||||
if (padsArray.size() != 2 * dataRank)
|
|
||||||
emitError("pads rank is not twice the spatial rank.");
|
|
||||||
// fill in the actual values
|
|
||||||
for (int i = 0; i < dataRank; ++i) {
|
|
||||||
int64_t p1 = (padsArray[2*i]).cast<IntegerAttr>().getInt();
|
|
||||||
if (p1 < 0)
|
|
||||||
emitError("pads value must be nonnegative.");
|
|
||||||
int64_t p2 = (padsArray[2*i+1]).cast<IntegerAttr>().getInt();
|
|
||||||
if (p2 < 0)
|
|
||||||
emitError("pads value must be nonnegative.");
|
|
||||||
outputShape[i] += p1+p2;
|
|
||||||
}
|
|
||||||
getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
|
|
||||||
} else {
|
|
||||||
emitError("pads attribute is not available.");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue