Add special case for 1-D matrix multiplication.

This commit is contained in:
Doru Bercea 2020-01-09 15:30:57 -05:00
parent deef363309
commit 170296b7c6
1 changed files with 21 additions and 5 deletions

View File

@ -318,12 +318,28 @@ void ONNXMatMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
SmallVector<int64_t, 2> dims; SmallVector<int64_t, 2> dims;
dims.emplace_back(lhsTy.getShape()[0]); auto lhsShape = lhsTy.getShape();
dims.emplace_back(rhsTy.getShape()[1]); auto rhsShape = rhsTy.getShape();
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); if (lhsShape.size() == 1 && rhsShape.size() == 1) {
// Special case when both arrays are 1-dimensional and according to
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
// need to be removed after the multiplication but cannot be removed if all
// remaining sizes are 1.
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
lhsShape[0] != rhsShape[0])
emitError("Attempt to multiply incompatible matrices.");
dims.emplace_back(1);
} else {
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
}
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
} }
// TODO: // TODO: