Add shape inference method.
This commit is contained in:
parent
3fe0f2e735
commit
ec9e023f04
|
@ -469,52 +469,102 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
auto dataShape = dataTy.getShape();
|
||||
auto weightShape = weightTy.getShape();
|
||||
|
||||
// Check that shape of weight and data have same length.
|
||||
if (dataShape.size() != weightShape.size())
|
||||
emitError("ConvNoBias: weight size not compatible with data size.");
|
||||
emitError("Weight size not compatible with data size.");
|
||||
|
||||
// Required attribute auto_pad defaults to NOTSET.
|
||||
auto autoPad = getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName()).getValue();
|
||||
// Group is a required attribute and should have default value of 1.
|
||||
int64_t group = getAttrOfType<IntegerAttr>(
|
||||
ONNXConvOp::getGroupAttrName()).getInt();
|
||||
if (!group)
|
||||
emitError("ConvNoBias: group attribute missing.");
|
||||
|
||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||
if (dataShape[1] != (weightShape[1] * group))
|
||||
emitError("ConvNoBias: channel dimension mismatch.");
|
||||
|
||||
// Required attributes.
|
||||
auto auto_pad = getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName());
|
||||
auto pads = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName());
|
||||
emitError("Channel dimension mismatch.");
|
||||
|
||||
// First two output dimensions consist of the number of batches and the
|
||||
// number of kernels being applied.
|
||||
//
|
||||
SmallVector<int64_t, 2> dims;
|
||||
// Insert batch size.
|
||||
dims.emplace_back(dataShape[0]);
|
||||
// Insert number of filters being applied (number of output channels).
|
||||
dims.emplace_back(weightShape[0]);
|
||||
|
||||
// // Compute the spatial dimensions.
|
||||
// SmallVector<int64_t, 2> spatialDims;
|
||||
// // Number of spatial dimensions.
|
||||
// int32_t nDims = dataTy.size() - 2;
|
||||
// // Initialize dimenions based on the input and weight spatial dimensions.
|
||||
// for (int i = 2; i < dataTy.size(); ++i)
|
||||
// spatialDims.emplace_back(dataTy[i] - weightTy[i]);
|
||||
// // Add padding information.
|
||||
// if () {
|
||||
// for (int i = 0; i < nDims; ++i) {
|
||||
// // Padding for beginning of axis.
|
||||
// int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
// spatialDims[i] += p;
|
||||
// // Padding for end of axis.
|
||||
// p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||
// spatialDims[i] += p;
|
||||
// }
|
||||
// } else if () {
|
||||
// // Attribute pads has not been provided.
|
||||
// }
|
||||
// Spatial dimensions are computed using the formula:
|
||||
//
|
||||
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
|
||||
//
|
||||
SmallVector<int64_t, 2> spatialDims;
|
||||
// Number of spatial dimensions.
|
||||
int32_t nDims = dataShape.size() - 2;
|
||||
|
||||
// Initialize dimenions based on the input spatial dimensions.
|
||||
for (int i = 2; i < dataShape.size(); ++i)
|
||||
spatialDims.emplace_back(dataShape[i]);
|
||||
|
||||
// Use kernel_shape attribute if present otherwise use size from weight
|
||||
// argument.
|
||||
if (auto kernel_shape = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getKernelShapeAttrName())) {
|
||||
if (kernel_shape.getValue().size() != nDims)
|
||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t kernelDim =
|
||||
(kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] -= kernelDim;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
spatialDims[i] -= weightShape[i + 2];
|
||||
}
|
||||
|
||||
// Add padding information.
|
||||
if (autoPad == "NOTSET") {
|
||||
// Use pads to to determine the padding. If attribute is not
|
||||
// present then pads is considered to be all zeros (no padding).
|
||||
if (auto pads = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName())) {
|
||||
// pads consists of two entries for each spatial axis.
|
||||
if (pads.getValue().size() != 2 * nDims)
|
||||
emitError("pads size is not twice the spatial size.");
|
||||
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
// Padding for beginning of axis.
|
||||
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] += p;
|
||||
// Padding for end of axis.
|
||||
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] += p;
|
||||
}
|
||||
}
|
||||
} else if (autoPad == "VALID") {
|
||||
// TODO
|
||||
} else if (autoPad == "SAME_UPPER") {
|
||||
// TODO
|
||||
} else if (autoPad == "SAME_LOWER") {
|
||||
// TODO
|
||||
} else {
|
||||
emitError("Unexpected attribute value for auto_pad.");
|
||||
}
|
||||
|
||||
// Strides
|
||||
if (auto strides = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getStridesAttrName())) {
|
||||
if (strides.getValue().size() != nDims)
|
||||
emitError("strides length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t stride =
|
||||
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] = floor(spatialDims[i] / stride);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
spatialDims[i] += 1;
|
||||
|
||||
dims.append(spatialDims.begin(), spatialDims.end());
|
||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||
}
|
||||
|
||||
|
@ -526,12 +576,16 @@ LogicalResult verify(ONNXConvNoBiasOp op) {
|
|||
auto autoPadAttr = op.getAttrOfType<StringAttr>(
|
||||
ONNXConvOp::getAutoPadAttrName());
|
||||
if (!autoPadAttr)
|
||||
op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
|
||||
op.emitError("auto_pad attribute not specified.");
|
||||
if (autoPadAttr.getValue() != "NOTSET")
|
||||
if (auto pads = op.getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getPadsAttrName()))
|
||||
op.emitError("auto_pad and pads are both set.");
|
||||
|
||||
auto groupAttr =
|
||||
op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
|
||||
if (!groupAttr)
|
||||
op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
|
||||
op.emitError("group attribute not specified.");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue