Add shape inference method.

This commit is contained in:
Doru Bercea 2020-01-20 18:50:21 -05:00
parent 3fe0f2e735
commit ec9e023f04
1 changed files with 87 additions and 33 deletions

View File

@ -469,52 +469,102 @@ void ONNXConvNoBiasOp::inferShapes() {
auto dataShape = dataTy.getShape();
auto weightShape = weightTy.getShape();
// Check that shape of weight and data have same length.
if (dataShape.size() != weightShape.size())
emitError("ConvNoBias: weight size not compatible with data size.");
emitError("Weight size not compatible with data size.");
// Required attribute auto_pad defaults to NOTSET.
auto autoPad = getAttrOfType<StringAttr>(
ONNXConvOp::getAutoPadAttrName()).getValue();
// Group is a required attribute and should have default value of 1.
int64_t group = getAttrOfType<IntegerAttr>(
ONNXConvOp::getGroupAttrName()).getInt();
if (!group)
emitError("ConvNoBias: group attribute missing.");
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
emitError("ConvNoBias: channel dimension mismatch.");
// Required attributes.
auto auto_pad = getAttrOfType<StringAttr>(
ONNXConvOp::getAutoPadAttrName());
auto pads = getAttrOfType<ArrayAttr>(
ONNXConvOp::getPadsAttrName());
emitError("Channel dimension mismatch.");
// First two output dimensions consist of the number of batches and the
// number of kernels being applied.
//
SmallVector<int64_t, 2> dims;
// Insert batch size.
dims.emplace_back(dataShape[0]);
// Insert number of filters being applied (number of output channels).
dims.emplace_back(weightShape[0]);
// // Compute the spatial dimensions.
// SmallVector<int64_t, 2> spatialDims;
// // Number of spatial dimensions.
// int32_t nDims = dataTy.size() - 2;
// // Initialize dimenions based on the input and weight spatial dimensions.
// for (int i = 2; i < dataTy.size(); ++i)
// spatialDims.emplace_back(dataTy[i] - weightTy[i]);
// // Add padding information.
// if () {
// for (int i = 0; i < nDims; ++i) {
// // Padding for beginning of axis.
// int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
// spatialDims[i] += p;
// // Padding for end of axis.
// p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
// spatialDims[i] += p;
// }
// } else if () {
// // Attribute pads has not been provided.
// }
// Spatial dimensions are computed using the formula:
//
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
//
SmallVector<int64_t, 2> spatialDims;
// Number of spatial dimensions.
int32_t nDims = dataShape.size() - 2;
// Initialize dimenions based on the input spatial dimensions.
for (int i = 2; i < dataShape.size(); ++i)
spatialDims.emplace_back(dataShape[i]);
// Use kernel_shape attribute if present otherwise use size from weight
// argument.
if (auto kernel_shape = getAttrOfType<ArrayAttr>(
ONNXConvOp::getKernelShapeAttrName())) {
if (kernel_shape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) {
int64_t kernelDim =
(kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] -= kernelDim;
}
} else {
for (int i = 0; i < nDims; ++i)
spatialDims[i] -= weightShape[i + 2];
}
// Add padding information.
if (autoPad == "NOTSET") {
// Use pads to to determine the padding. If attribute is not
// present then pads is considered to be all zeros (no padding).
if (auto pads = getAttrOfType<ArrayAttr>(
ONNXConvOp::getPadsAttrName())) {
// pads consists of two entries for each spatial axis.
if (pads.getValue().size() != 2 * nDims)
emitError("pads size is not twice the spatial size.");
for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] += p;
// Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
spatialDims[i] += p;
}
}
} else if (autoPad == "VALID") {
// TODO
} else if (autoPad == "SAME_UPPER") {
// TODO
} else if (autoPad == "SAME_LOWER") {
// TODO
} else {
emitError("Unexpected attribute value for auto_pad.");
}
// Strides
if (auto strides = getAttrOfType<ArrayAttr>(
ONNXConvOp::getStridesAttrName())) {
if (strides.getValue().size() != nDims)
emitError("strides length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) {
int64_t stride =
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] = floor(spatialDims[i] / stride);
}
}
for (int i = 0; i < nDims; ++i)
spatialDims[i] += 1;
dims.append(spatialDims.begin(), spatialDims.end());
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
}
@ -526,12 +576,16 @@ LogicalResult verify(ONNXConvNoBiasOp op) {
auto autoPadAttr = op.getAttrOfType<StringAttr>(
ONNXConvOp::getAutoPadAttrName());
if (!autoPadAttr)
op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
op.emitError("auto_pad attribute not specified.");
if (autoPadAttr.getValue() != "NOTSET")
if (auto pads = op.getAttrOfType<ArrayAttr>(
ONNXConvOp::getPadsAttrName()))
op.emitError("auto_pad and pads are both set.");
auto groupAttr =
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
if (!groupAttr)
op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
op.emitError("group attribute not specified.");
return success();
}