Add special case for 1-D matrix multiplication.
This commit is contained in:
parent
deef363309
commit
170296b7c6
|
@ -318,12 +318,28 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
dims.emplace_back(lhsTy.getShape()[0]);
|
auto lhsShape = lhsTy.getShape();
|
||||||
dims.emplace_back(rhsTy.getShape()[1]);
|
auto rhsShape = rhsTy.getShape();
|
||||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
||||||
|
// Special case when both arrays are 1-dimensional and according to
|
||||||
|
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
||||||
|
// need to be removed after the multiplication but cannot be removed if all
|
||||||
|
// remaining sizes are 1.
|
||||||
|
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
|
||||||
|
lhsShape[0] != rhsShape[0])
|
||||||
|
emitError("Attempt to multiply incompatible matrices.");
|
||||||
|
dims.emplace_back(1);
|
||||||
|
} else {
|
||||||
|
dims.emplace_back(lhsShape.size() == 1 ? 1 : lhsShape[0]);
|
||||||
|
dims.emplace_back(rhsShape.size() == 1 ? 1 : rhsShape[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
|
|
Loading…
Reference in New Issue