Add check for matrix size match for 1 and 2 dimenisional cases.

This commit is contained in:
Doru Bercea 2020-01-10 15:26:29 -05:00
parent da0e9b01b1
commit e091825896
1 changed files with 10 additions and 2 deletions

View File

@ -389,10 +389,18 @@ void ONNXMatMulOp::inferShapes() {
dims.emplace_back(rhsShape[rightDims - 1]); dims.emplace_back(rhsShape[rightDims - 1]);
} else { } else {
// This case covers all remaining combinations of 1 and 2-D matrices. // This case covers all remaining combinations of 1 and 2-D matrices.
if (lhsShape.size() != 1) int64_t lhsDim = lhsShape[0];
int64_t rhsDim = rhsShape[0];
if (lhsShape.size() > 1) {
lhsDim = lhsShape[1];
dims.emplace_back(lhsShape[0]); dims.emplace_back(lhsShape[0]);
}
if (rhsShape.size() != 1) // Check legality of matrix multiplication.
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
emitError("Attempt to multiply incompatible matrices.");
if (rhsShape.size() > 1)
dims.emplace_back(rhsShape[1]); dims.emplace_back(rhsShape[1]);
} }