Add shape inference method.
parent 3fe0f2e735
commit ec9e023f04
@@ -469,52 +469,102 @@ void ONNXConvNoBiasOp::inferShapes() {
   auto dataShape = dataTy.getShape();
   auto weightShape = weightTy.getShape();

+  // Check that shape of weight and data have same length.
   if (dataShape.size() != weightShape.size())
-    emitError("ConvNoBias: weight size not compatible with data size.");
+    emitError("Weight size not compatible with data size.");

+  // Required attribute auto_pad defaults to NOTSET.
+  auto autoPad = getAttrOfType<StringAttr>(
+      ONNXConvOp::getAutoPadAttrName()).getValue();
   // Group is a required attribute and should have default value of 1.
   int64_t group = getAttrOfType<IntegerAttr>(
       ONNXConvOp::getGroupAttrName()).getInt();
-  if (!group)
-    emitError("ConvNoBias: group attribute missing.");

   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (dataShape[1] != (weightShape[1] * group))
-    emitError("ConvNoBias: channel dimension mismatch.");
+    emitError("Channel dimension mismatch.");

-  // Required attributes.
-  auto auto_pad = getAttrOfType<StringAttr>(
-      ONNXConvOp::getAutoPadAttrName());
-  auto pads = getAttrOfType<ArrayAttr>(
-      ONNXConvOp::getPadsAttrName());

+  // First two output dimensions consist of the number of batches and the
+  // number of kernels being applied.
+  //
   SmallVector<int64_t, 2> dims;
   // Insert batch size.
   dims.emplace_back(dataShape[0]);
   // Insert number of filters being applied (number of output channels).
   dims.emplace_back(weightShape[0]);

-  // // Compute the spatial dimensions.
-  // SmallVector<int64_t, 2> spatialDims;
-  // // Number of spatial dimensions.
-  // int32_t nDims = dataTy.size() - 2;
-  // // Initialize dimenions based on the input and weight spatial dimensions.
-  // for (int i = 2; i < dataTy.size(); ++i)
-  //   spatialDims.emplace_back(dataTy[i] - weightTy[i]);
-  // // Add padding information.
-  // if () {
-  //   for (int i = 0; i < nDims; ++i) {
-  //     // Padding for beginning of axis.
-  //     int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
-  //     spatialDims[i] += p;
-  //     // Padding for end of axis.
-  //     p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
-  //     spatialDims[i] += p;
-  //   }
-  // } else if () {
-  //   // Attribute pads has not been provided.
-  // }
+  // Spatial dimensions are computed using the formula:
+  //
+  // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
+  //
+  SmallVector<int64_t, 2> spatialDims;
+  // Number of spatial dimensions.
+  int32_t nDims = dataShape.size() - 2;

+  // Initialize dimenions based on the input spatial dimensions.
+  for (int i = 2; i < dataShape.size(); ++i)
+    spatialDims.emplace_back(dataShape[i]);
+
+  // Use kernel_shape attribute if present otherwise use size from weight
+  // argument.
+  if (auto kernel_shape = getAttrOfType<ArrayAttr>(
+          ONNXConvOp::getKernelShapeAttrName())) {
+    if (kernel_shape.getValue().size() != nDims)
+      emitError("kernel_shape length incompatible with spatial dimensions.");
+    for (int i = 0; i < nDims; ++i) {
+      int64_t kernelDim =
+          (kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
+      spatialDims[i] -= kernelDim;
+    }
+  } else {
+    for (int i = 0; i < nDims; ++i)
+      spatialDims[i] -= weightShape[i + 2];
+  }
+
+  // Add padding information.
+  if (autoPad == "NOTSET") {
+    // Use pads to to determine the padding. If attribute is not
+    // present then pads is considered to be all zeros (no padding).
+    if (auto pads = getAttrOfType<ArrayAttr>(
+            ONNXConvOp::getPadsAttrName())) {
+      // pads consists of two entries for each spatial axis.
+      if (pads.getValue().size() != 2 * nDims)
+        emitError("pads size is not twice the spatial size.");
+
+      for (int i = 0; i < nDims; ++i) {
+        // Padding for beginning of axis.
+        int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
+        spatialDims[i] += p;
+        // Padding for end of axis.
+        p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
+        spatialDims[i] += p;
+      }
+    }
+  } else if (autoPad == "VALID") {
+    // TODO
+  } else if (autoPad == "SAME_UPPER") {
+    // TODO
+  } else if (autoPad == "SAME_LOWER") {
+    // TODO
+  } else {
+    emitError("Unexpected attribute value for auto_pad.");
+  }
+
+  // Strides
+  if (auto strides = getAttrOfType<ArrayAttr>(
+          ONNXConvOp::getStridesAttrName())) {
+    if (strides.getValue().size() != nDims)
+      emitError("strides length incompatible with spatial dimensions.");
+    for (int i = 0; i < nDims; ++i) {
+      int64_t stride =
+          (strides.getValue()[i]).cast<IntegerAttr>().getInt();
+      spatialDims[i] = floor(spatialDims[i] / stride);
+    }
+  }
+
+  for (int i = 0; i < nDims; ++i)
+    spatialDims[i] += 1;
+
+  dims.append(spatialDims.begin(), spatialDims.end());
   getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
 }

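For reference, the new spatial-dimension logic above implements the usual convolution output-size formula, dim = floor((inputDim - kernelDim + startPadding + endPadding) / stride) + 1. The standalone C++ sketch below is not part of the commit; the helper name and the example numbers are illustrative only, and it simply replays that arithmetic for a 224x224 input, a 7x7 kernel, pads of 3 on both sides, and stride 2.

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the arithmetic used in inferShapes():
// the kernel size and both pads are applied first, then integer (floor)
// division by the stride, then the final +1.
static int64_t convOutDim(int64_t inputDim, int64_t kernelDim,
                          int64_t padBegin, int64_t padEnd, int64_t stride) {
  return (inputDim - kernelDim + padBegin + padEnd) / stride + 1;
}

int main() {
  // Example: 224x224 input, 7x7 kernel, pads {3, 3}, stride 2
  // -> (224 - 7 + 3 + 3) / 2 + 1 = 111 + 1 = 112.
  std::printf("%lld\n", (long long)convOutDim(224, 7, 3, 3, 2));
  return 0;
}

As in the committed code, the division is integer (floor) division applied before the final increment, which is what the floor(spatialDims[i] / stride) step followed by the += 1 loop computes.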
@@ -526,12 +576,16 @@ LogicalResult verify(ONNXConvNoBiasOp op) {
   auto autoPadAttr = op.getAttrOfType<StringAttr>(
       ONNXConvOp::getAutoPadAttrName());
   if (!autoPadAttr)
-    op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
+    op.emitError("auto_pad attribute not specified.");
+  if (autoPadAttr.getValue() != "NOTSET")
+    if (auto pads = op.getAttrOfType<ArrayAttr>(
+            ONNXConvOp::getPadsAttrName()))
+      op.emitError("auto_pad and pads are both set.");

   auto groupAttr =
       op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
   if (!groupAttr)
-    op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
+    op.emitError("group attribute not specified.");

   return success();
 }
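The channel check in inferShapes() above, X.shape[1] == W.shape[1] * group, encodes how grouped convolution partitions input channels: the weight's second dimension holds C / group channels. A minimal sketch of the same condition, with made-up NCHW shapes (none of these numbers come from the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical shapes in N x C x H x W and M x C/group x kH x kW layout.
  std::vector<int64_t> dataShape   = {1, 8, 32, 32};
  std::vector<int64_t> weightShape = {16, 4, 3, 3};
  int64_t group = 2;

  // Same condition the commit checks: X.shape[1] == W.shape[1] * group.
  if (dataShape[1] != weightShape[1] * group)
    std::printf("Channel dimension mismatch.\n");
  else
    std::printf("Channels OK: %lld == %lld * %lld\n",
                (long long)dataShape[1], (long long)weightShape[1],
                (long long)group);
  return 0;
}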