Fix kernel dimensions.
This commit is contained in:
parent
169236a8fc
commit
de77758faf
|
@ -509,15 +509,15 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
// argument.
|
||||
SmallVector<int64_t, 2> kernelDims;
|
||||
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getKernelShapeAttrName())) {
|
||||
ONNXConvOp::getKernelShapeAttrName())) {
|
||||
if (kernelShape.getValue().size() != nDims)
|
||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims[i] =
|
||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
kernelDims.emplace_back(
|
||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
|
||||
} else {
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims[i] = weightShape[i + 2];
|
||||
kernelDims.emplace_back(weightShape[i + 2]);
|
||||
}
|
||||
|
||||
// Subtract kernel dimensions from input data dimensions.
|
||||
|
|
Loading…
Reference in New Issue