put the common code into a helper function

chentong 2020-02-25 17:43:49 -05:00
parent 4079ee1f26
commit 3abbf1c0e9
1 changed file with 37 additions and 60 deletions


@@ -1045,39 +1045,41 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
 //===----------------------------------------------------------------------===//
+static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
+  // Cannot infer shape if no shape exists.
+  if (!data.getType().isa<RankedTensorType>())
+    return (Type)NULL;
+  auto dataTy = data.getType().cast<RankedTensorType>();
+  auto dataShape = dataTy.getShape();
+  auto dataRank = dataShape.size();
+  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
+  if (padsOpt) {
+    auto padsArray = padsOpt.getValue();
+    // Pads consists of two values for each axis of data.
+    // The two values specify the number of elements padded before and after,
+    // respectively.
+    for (int i = 0; i < dataRank; ++i) {
+      int64_t p1 = (padsArray[2 * i]).cast<IntegerAttr>().getInt();
+      int64_t p2 = (padsArray[2 * i + 1]).cast<IntegerAttr>().getInt();
+      // Pad values must be non-negative constants.
+      if (p1 < 0 || p2 < 0)
+        return (Type)NULL;
+      outputShape[i] += p1 + p2;
+    }
+    return (RankedTensorType::get(outputShape, dataTy.getElementType()));
+  } else {
+    return (Type)NULL;
+  }
+}
 // PadConstantPad
 void ONNXPadConstantPadOp::inferShapes(){
-  // Cannot infer shape if no shape exists.
-  if (!data().getType().isa<RankedTensorType>())
-    return;
-  // 1) get shape of input "data"
-  auto dataTy = data().getType().cast<RankedTensorType>();
-  auto dataShape = dataTy.getShape();
-  auto dataRank = dataShape.size();
-  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
-  auto padsOpt = pads();
-  if (padsOpt) {
-    auto padsArray = padsOpt.getValue();
-    // pads consists of two entries for each spatial axis.
-    if (padsArray.size() != 2 * dataRank)
-      emitError("pads rank is not twice the spatial rank.");
-    // fill in the actual values
-    for (int i = 0; i < dataRank; ++i) {
-      int64_t p1 = (padsArray[2 * i]).cast<IntegerAttr>().getInt();
-      if (p1 < 0)
-        emitError("pads value must be nonnegative.");
-      int64_t p2 = (padsArray[2 * i + 1]).cast<IntegerAttr>().getInt();
-      if (p2 < 0)
-        emitError("pads value must be nonnegative.");
-      outputShape[i] += p1 + p2;
-    }
-    getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
-  } else {
-    emitError("pads attribute is not available.");
-  }
+  auto outputType = padShapeInferenceHelper(data(), pads());
+  if (outputType) {
+    getResult().setType(outputType);
+  }
   return;
 }
 //===----------------------------------------------------------------------===//
@@ -1085,36 +1087,11 @@ void ONNXPadConstantPadOp::inferShapes(){
 // PadConstantValuePad
 void ONNXPadConstantValuePadOp::inferShapes(){
-  // Cannot infer shape if no shape exists.
-  if (!data().getType().isa<RankedTensorType>())
-    return;
-  // 1) get shape of input "data"
-  auto dataTy = data().getType().cast<RankedTensorType>();
-  auto dataShape = dataTy.getShape();
-  auto dataRank = dataShape.size();
-  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
-  auto padsOpt = pads();
-  if (padsOpt) {
-    auto padsArray = padsOpt.getValue();
-    // pads consists of two entries for each spatial axis.
-    if (padsArray.size() != 2 * dataRank)
-      emitError("pads rank is not twice the spatial rank.");
-    // fill in the actual values
-    for (int i = 0; i < dataRank; ++i) {
-      int64_t p1 = (padsArray[2 * i]).cast<IntegerAttr>().getInt();
-      if (p1 < 0)
-        emitError("pads value must be nonnegative.");
-      int64_t p2 = (padsArray[2 * i + 1]).cast<IntegerAttr>().getInt();
-      if (p2 < 0)
-        emitError("pads value must be nonnegative.");
-      outputShape[i] += p1 + p2;
-    }
-    getResult().setType(RankedTensorType::get(outputShape, dataTy.getElementType()));
-  } else {
-    emitError("pads attribute is not available.");
-  }
+  auto outputType = padShapeInferenceHelper(data(), pads());
+  if (outputType) {
+    getResult().setType(outputType);
+  }
   return;
 }
 //===----------------------------------------------------------------------===//
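
Note (not part of the commit): a minimal standalone sketch of the shape arithmetic that padShapeInferenceHelper performs, using plain std::vector in place of the MLIR RankedTensorType/ArrayAttr machinery. The padShape function name and the example values are made up for illustration; only the indexing convention (pads[2*i] = before, pads[2*i+1] = after, per axis) is taken from the helper.

#include <cstdint>
#include <iostream>
#include <vector>

// Returns the padded shape, or an empty vector when any pad value is
// negative (the helper returns a NULL Type in that case).
std::vector<int64_t> padShape(const std::vector<int64_t> &dataShape,
                              const std::vector<int64_t> &pads) {
  std::vector<int64_t> outputShape(dataShape);
  for (size_t i = 0; i < dataShape.size(); ++i) {
    int64_t p1 = pads[2 * i];     // elements padded before axis i
    int64_t p2 = pads[2 * i + 1]; // elements padded after axis i
    if (p1 < 0 || p2 < 0)
      return {};
    outputShape[i] += p1 + p2;
  }
  return outputShape;
}

int main() {
  // A 3x4 tensor with pads {1, 2, 0, 1}: axis 0 grows by 1+2, axis 1 by 0+1.
  for (int64_t d : padShape({3, 4}, {1, 2, 0, 1}))
    std::cout << d << " "; // prints: 6 5
  std::cout << "\n";
  return 0;
}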