Fix reference error.
This commit is contained in:
parent
0bc07ef661
commit
1784ec2314
|
@ -328,8 +328,8 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
SmallVector<int64_t, 2> dims;
|
SmallVector<int64_t, 2> dims;
|
||||||
auto lhsShape = lhsTy.getShape();
|
auto lhsShape = lhsTy.getShape();
|
||||||
|
@ -413,7 +413,7 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
dims.emplace_back(rhsShape[1]);
|
dims.emplace_back(rhsShape[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue