Updates LLVM usage to match
[4964d75d7078](https://github.com/llvm/llvm-project/commit/4964d75d7078)

PiperOrigin-RevId: 330713009
This commit is contained in:
A. Unique TensorFlower 2020-09-09 06:49:15 -07:00 committed by TensorFlow MLIR Team
parent 81d51d810b
commit f46ba09653
3 changed files with 14 additions and 28 deletions

View File

@ -1,2 +1,2 @@
8d9c13f37d2081c11186718ae8b5aef8b507d152
4964d75d7078b932ac6b17c1990adaa6eada75c1

View File

@ -51,9 +51,9 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
BufferAssignmentTypeConverter *converter,
MLIRContext *context, BufferAssignmentTypeConverter *converter,
OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
OwningRewritePatternList *patterns);

View File

@ -78,7 +78,6 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
}
Value InsertAlloc(Location loc, OpResult result,
BufferAssignmentPlacer* bufferAssignment,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
@ -88,8 +87,7 @@ Value InsertAlloc(Location loc, OpResult result,
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
OpBuilder::InsertionGuard guard(*rewriter);
rewriter->restoreInsertionPoint(
bufferAssignment->computeAllocPosition(result));
rewriter->setInsertionPoint(result.getDefiningOp());
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
return alloc;
}
@ -111,8 +109,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
this->bufferAssignment, &rewriter));
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@ -259,8 +257,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
}
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op.getAttrs());
@ -432,22 +429,12 @@ struct HloLegalizeToLhlo
isMemRefType);
});
auto module = getOperation();
WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
BufferAssignmentPlacer bufferAssignment(func);
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
// FIXME: we likely need to call converter.setResultConversionKind() to
// respect results_escape_function.
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(
&context, &bufferAssignment, &converter, &patterns);
return applyPartialConversion(func, target, patterns);
});
if (result.wasInterrupted()) {
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
&patterns);
if (failed(applyPartialConversion(getOperation(), target, patterns)))
signalPassFailure();
}
}
private:
@ -460,8 +447,7 @@ struct HloLegalizeToLhlo
} // namespace
void populateHLOToLHLOConversionPattern(
MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
BufferAssignmentTypeConverter* converter,
MLIRContext* context, BufferAssignmentTypeConverter* converter,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
@ -506,7 +492,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter);
>(context, converter);
// clang-format on
}