Fix kernel dimensions.

This commit is contained in:
Doru Bercea 2020-01-22 10:10:06 -05:00
parent 169236a8fc
commit de77758faf
1 changed files with 4 additions and 4 deletions

View File

@ -509,15 +509,15 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument.
SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = getAttrOfType<ArrayAttr>(
ONNXConvOp::getKernelShapeAttrName())) {
ONNXConvOp::getKernelShapeAttrName())) {
if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i)
kernelDims[i] =
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
kernelDims.emplace_back(
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
} else {
for (int i = 0; i < nDims; ++i)
kernelDims[i] = weightShape[i + 2];
kernelDims.emplace_back(weightShape[i + 2]);
}
// Subtract kernel dimensions from input data dimensions.