Add check for matrix size match for 1 and 2 dimenisional cases.
This commit is contained in:
parent
da0e9b01b1
commit
e091825896
|
@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
dims.emplace_back(rhsShape[rightDims - 1]);
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||||
} else {
|
} else {
|
||||||
// This case covers all remaining combinations of 1 and 2-D matrices.
|
// This case covers all remaining combinations of 1 and 2-D matrices.
|
||||||
if (lhsShape.size() != 1)
|
int64_t lhsDim = lhsShape[0];
|
||||||
|
int64_t rhsDim = rhsShape[0];
|
||||||
|
if (lhsShape.size() > 1) {
|
||||||
|
lhsDim = lhsShape[1];
|
||||||
dims.emplace_back(lhsShape[0]);
|
dims.emplace_back(lhsShape[0]);
|
||||||
|
}
|
||||||
|
|
||||||
if (rhsShape.size() != 1)
|
// Check legality of matrix multiplication.
|
||||||
|
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
|
||||||
|
if (rhsShape.size() > 1)
|
||||||
dims.emplace_back(rhsShape[1]);
|
dims.emplace_back(rhsShape[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue