Add check for matrix size match for 1 and 2 dimenisional cases.
This commit is contained in:
parent
da0e9b01b1
commit
e091825896
|
@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() {
|
|||
dims.emplace_back(rhsShape[rightDims - 1]);
|
||||
} else {
|
||||
// This case covers all remaining combinations of 1 and 2-D matrices.
|
||||
if (lhsShape.size() != 1)
|
||||
int64_t lhsDim = lhsShape[0];
|
||||
int64_t rhsDim = rhsShape[0];
|
||||
if (lhsShape.size() > 1) {
|
||||
lhsDim = lhsShape[1];
|
||||
dims.emplace_back(lhsShape[0]);
|
||||
}
|
||||
|
||||
if (rhsShape.size() != 1)
|
||||
// Check legality of matrix multiplication.
|
||||
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
||||
emitError("Attempt to multiply incompatible matrices.");
|
||||
|
||||
if (rhsShape.size() > 1)
|
||||
dims.emplace_back(rhsShape[1]);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue