Fix kernel dimensions.
This commit is contained in:
parent
169236a8fc
commit
de77758faf
|
@ -509,15 +509,15 @@ void ONNXConvNoBiasOp::inferShapes() {
|
||||||
// argument.
|
// argument.
|
||||||
SmallVector<int64_t, 2> kernelDims;
|
SmallVector<int64_t, 2> kernelDims;
|
||||||
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
||||||
ONNXConvOp::getKernelShapeAttrName())) {
|
ONNXConvOp::getKernelShapeAttrName())) {
|
||||||
if (kernelShape.getValue().size() != nDims)
|
if (kernelShape.getValue().size() != nDims)
|
||||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
kernelDims[i] =
|
kernelDims.emplace_back(
|
||||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
|
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
|
||||||
} else {
|
} else {
|
||||||
for (int i = 0; i < nDims; ++i)
|
for (int i = 0; i < nDims; ++i)
|
||||||
kernelDims[i] = weightShape[i + 2];
|
kernelDims.emplace_back(weightShape[i + 2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subtract kernel dimensions from input data dimensions.
|
// Subtract kernel dimensions from input data dimensions.
|
||||||
|
|
Loading…
Reference in New Issue