Fix kernel dimensions.

This commit is contained in:
Doru Bercea 2020-01-22 10:10:06 -05:00
parent 169236a8fc
commit de77758faf
1 changed files with 4 additions and 4 deletions

View File

@ -509,15 +509,15 @@ void ONNXConvNoBiasOp::inferShapes() {
// argument. // argument.
SmallVector<int64_t, 2> kernelDims; SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = getAttrOfType<ArrayAttr>( if (auto kernelShape = getAttrOfType<ArrayAttr>(
ONNXConvOp::getKernelShapeAttrName())) { ONNXConvOp::getKernelShapeAttrName())) {
if (kernelShape.getValue().size() != nDims) if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions."); emitError("kernel_shape length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims[i] = kernelDims.emplace_back(
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt(); (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
} else { } else {
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
kernelDims[i] = weightShape[i + 2]; kernelDims.emplace_back(weightShape[i + 2]);
} }
// Subtract kernel dimensions from input data dimensions. // Subtract kernel dimensions from input data dimensions.