Extract some duplicated code into a helper function.
- Extract code to create result memref's into a ConvertResults function. - Also fix a bug when using reifyReturnTypes: use correct index for result_shape instead of always using the first element. PiperOrigin-RevId: 341852227
This commit is contained in:
parent
d4f2c767d3
commit
745c8aa0b1
|
@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result,
|
|||
return alloc;
|
||||
}
|
||||
|
||||
/// Converts the results of the operation `op` to memref types and append them
|
||||
/// to the `results` vector.
|
||||
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
|
||||
ConversionPatternRewriter& rewriter) {
|
||||
for (auto result : llvm::enumerate(op->getResults())) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) return failure();
|
||||
|
||||
if (resultType.hasStaticShape()) {
|
||||
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
continue;
|
||||
}
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
results.push_back(
|
||||
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
|
||||
results_shape[result.index()], &rewriter));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename HloOpTy>
|
||||
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||
public:
|
||||
|
@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
|||
HloOpTy hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status =
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(
|
||||
|
@ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
|
|||
mhlo::DotOp hloOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Operation* op = hloOp.getOperation();
|
||||
const auto& original_results = op->getResults();
|
||||
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
|
||||
for (auto result : llvm::enumerate(original_results)) {
|
||||
RankedTensorType resultType =
|
||||
result.value().getType().dyn_cast<RankedTensorType>();
|
||||
if (!resultType) {
|
||||
return failure();
|
||||
}
|
||||
if (resultType.hasStaticShape()) {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
|
||||
|
||||
// TODO(silvasean): Move this helper to MLIR core.
|
||||
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
|
||||
|
|
Loading…
Reference in New Issue