Extract some duplicated code into a helper function.
- Extract code to create result memref's into a ConvertResults function. - Also fix a bug when using reifyReturnTypes: use correct index for result_shape instead of always using the first element. PiperOrigin-RevId: 341852227
This commit is contained in:
parent
d4f2c767d3
commit
745c8aa0b1
|
@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result,
|
||||||
return alloc;
|
return alloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts the results of the operation `op` to memref types and append them
|
||||||
|
/// to the `results` vector.
|
||||||
|
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||||
|
ConversionPatternRewriter& rewriter) {
|
||||||
|
for (auto result : llvm::enumerate(op->getResults())) {
|
||||||
|
RankedTensorType resultType =
|
||||||
|
result.value().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!resultType) return failure();
|
||||||
|
|
||||||
|
if (resultType.hasStaticShape()) {
|
||||||
|
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
|
if (!shape_type_op) return failure();
|
||||||
|
|
||||||
|
SmallVector<Value, 1> results_shape;
|
||||||
|
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||||
|
if (failed(status)) return failure();
|
||||||
|
results.push_back(
|
||||||
|
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
||||||
|
results_shape[result.index()], &rewriter));
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename HloOpTy>
|
template <typename HloOpTy>
|
||||||
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||||
public:
|
public:
|
||||||
|
@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||||
HloOpTy hloOp, ArrayRef<Value> operands,
|
HloOpTy hloOp, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
Operation* op = hloOp.getOperation();
|
Operation* op = hloOp.getOperation();
|
||||||
const auto& original_results = op->getResults();
|
|
||||||
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||||
for (auto result : llvm::enumerate(original_results)) {
|
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||||
RankedTensorType resultType =
|
|
||||||
result.value().getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!resultType) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (resultType.hasStaticShape()) {
|
|
||||||
buffer_args.push_back(
|
|
||||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
|
||||||
} else {
|
|
||||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
|
||||||
if (!shape_type_op) return failure();
|
|
||||||
|
|
||||||
SmallVector<Value, 1> results_shape;
|
|
||||||
auto status =
|
|
||||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
|
||||||
if (failed(status)) return failure();
|
|
||||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
|
||||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||||
buffer_args, op->getAttrs());
|
buffer_args, op->getAttrs());
|
||||||
rewriter.replaceOp(
|
rewriter.replaceOp(
|
||||||
|
@ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
||||||
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
Operation* op = hloOp.getOperation();
|
Operation* op = hloOp.getOperation();
|
||||||
const auto& original_results = op->getResults();
|
|
||||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||||
for (auto result : llvm::enumerate(original_results)) {
|
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||||
RankedTensorType resultType =
|
|
||||||
result.value().getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!resultType) {
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (resultType.hasStaticShape()) {
|
|
||||||
buffer_args.push_back(
|
|
||||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
|
||||||
} else {
|
|
||||||
SmallVector<Value, 1> results_shape;
|
|
||||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
|
||||||
if (!shape_type_op) return failure();
|
|
||||||
if (failed(
|
|
||||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
|
||||||
return failure();
|
|
||||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
|
||||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(silvasean): Move this helper to MLIR core.
|
// TODO(silvasean): Move this helper to MLIR core.
|
||||||
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||||
|
|
Loading…
Reference in New Issue