Check channel dimension mismatch only for known dimensions (#2)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-04 14:34:08 -05:00 committed by GitHub
parent e4c23da4fd
commit 8e1b30e133
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -781,7 +781,8 @@ void ONNXConvNoBiasOp::inferShapes() {
int64_t group =
ONNXConvNoBiasOp::group().getSExtValue(); //.getLimitedValue();
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (dataShape[1] != (weightShape[1] * group))
if (dataShape[1] != -1 && weightShape[1] != -1 &&
dataShape[1] != (weightShape[1] * group))
emitError("Channel dimension mismatch");
// Note: the value of the group attribut only impacts the way the