Extract some duplicated code into a helper function.

- Extract the code that creates result memrefs into a ConvertResults function.
- Also fix a bug when using reifyReturnTypeShapes: use the correct index into
  results_shape instead of always using the first element (see the sketch below).
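
For context, a minimal sketch of the indexing fix. It mirrors the core of the
new ConvertResults helper in the diff below and reuses that file's names:
shape_type_op is the op cast to InferShapedTypeOpInterface, and
InsertDynamicAllocAndDealloc is the existing allocation helper.

    for (auto result : llvm::enumerate(op->getResults())) {
      SmallVector<Value, 1> results_shape;
      // reifyReturnTypeShapes fills results_shape with one shape value per
      // op result.
      if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
        return failure();
      // Previously this used results_shape.front(), i.e. the first result's
      // shape for every result of a multi-result op. Index by the result's
      // own position instead.
      results.push_back(InsertDynamicAllocAndDealloc(
          op->getLoc(), result.value(), results_shape[result.index()],
          &rewriter));
    }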

PiperOrigin-RevId: 341852227
Rahul Joshi 2020-11-11 10:00:22 -08:00, committed by TensorFlow MLIR Team
parent d4f2c767d3
commit 745c8aa0b1
1 changed file with 28 additions and 43 deletions

@@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result,
   return alloc;
 }
 
+/// Converts the results of the operation `op` to memref types and appends
+/// them to the `results` vector.
+LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
+                             ConversionPatternRewriter& rewriter) {
+  for (auto result : llvm::enumerate(op->getResults())) {
+    RankedTensorType resultType =
+        result.value().getType().dyn_cast<RankedTensorType>();
+    if (!resultType) return failure();
+    if (resultType.hasStaticShape()) {
+      results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
+      continue;
+    }
+    auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
+    if (!shape_type_op) return failure();
+    SmallVector<Value, 1> results_shape;
+    auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
+    if (failed(status)) return failure();
+    results.push_back(
+        InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
+                                     results_shape[result.index()], &rewriter));
+  }
+  return success();
+}
+
 template <typename HloOpTy>
 class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
  public:
@@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
       HloOpTy hloOp, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     Operation* op = hloOp.getOperation();
-    const auto& original_results = op->getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : llvm::enumerate(original_results)) {
-      RankedTensorType resultType =
-          result.value().getType().dyn_cast<RankedTensorType>();
-      if (!resultType) {
-        return failure();
-      }
-      if (resultType.hasStaticShape()) {
-        buffer_args.push_back(
-            InsertAlloc(op->getLoc(), result.value(), &rewriter));
-      } else {
-        auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
-        if (!shape_type_op) return failure();
-        SmallVector<Value, 1> results_shape;
-        auto status =
-            shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
-        if (failed(status)) return failure();
-        buffer_args.push_back(InsertDynamicAllocAndDealloc(
-            op->getLoc(), result.value(), results_shape.front(), &rewriter));
-      }
-    }
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                 buffer_args, op->getAttrs());
     rewriter.replaceOp(
@@ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
       mhlo::DotOp hloOp, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
     Operation* op = hloOp.getOperation();
-    const auto& original_results = op->getResults();
     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
-    for (auto result : llvm::enumerate(original_results)) {
-      RankedTensorType resultType =
-          result.value().getType().dyn_cast<RankedTensorType>();
-      if (!resultType) {
-        return failure();
-      }
-      if (resultType.hasStaticShape()) {
-        buffer_args.push_back(
-            InsertAlloc(op->getLoc(), result.value(), &rewriter));
-      } else {
-        SmallVector<Value, 1> results_shape;
-        auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
-        if (!shape_type_op) return failure();
-        if (failed(
-                shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
-          return failure();
-        buffer_args.push_back(InsertDynamicAllocAndDealloc(
-            op->getLoc(), result.value(), results_shape.front(), &rewriter));
-      }
-    }
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
     // TODO(silvasean): Move this helper to MLIR core.
     auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {