Updates LLVM usage to match
[4964d75d7078](https://github.com/llvm/llvm-project/commit/4964d75d7078)

PiperOrigin-RevId: 330713009
This commit is contained in:
A. Unique TensorFlower 2020-09-09 06:49:15 -07:00 committed by TensorFlow MLIR Team
parent 81d51d810b
commit f46ba09653
3 changed files with 14 additions and 28 deletions

View File

@ -1,2 +1,2 @@
8d9c13f37d2081c11186718ae8b5aef8b507d152 4964d75d7078b932ac6b17c1990adaa6eada75c1

View File

@ -51,9 +51,9 @@ void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
// Collection of rewrite patterns for lowering of HLO to LHLO dialect. // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern( void populateHLOToLHLOConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, MLIRContext *context, BufferAssignmentTypeConverter *converter,
BufferAssignmentTypeConverter *converter,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect. // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context, void populateHLOToLinalgConversionPattern(MLIRContext *context,
OwningRewritePatternList *patterns); OwningRewritePatternList *patterns);

View File

@ -78,7 +78,6 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
} }
Value InsertAlloc(Location loc, OpResult result, Value InsertAlloc(Location loc, OpResult result,
BufferAssignmentPlacer* bufferAssignment,
ConversionPatternRewriter* rewriter) { ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>(); auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) { if (!result_type || !result_type.hasStaticShape()) {
@ -88,8 +87,7 @@ Value InsertAlloc(Location loc, OpResult result,
auto memref_type = auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType()); MemRefType::get(result_type.getShape(), result_type.getElementType());
OpBuilder::InsertionGuard guard(*rewriter); OpBuilder::InsertionGuard guard(*rewriter);
rewriter->restoreInsertionPoint( rewriter->setInsertionPoint(result.getDefiningOp());
bufferAssignment->computeAllocPosition(result));
auto alloc = rewriter->create<AllocOp>(loc, memref_type); auto alloc = rewriter->create<AllocOp>(loc, memref_type);
return alloc; return alloc;
} }
@ -111,8 +109,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
return failure(); return failure();
} }
if (resultType.hasStaticShape()) { if (resultType.hasStaticShape()) {
buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), buffer_args.push_back(
this->bufferAssignment, &rewriter)); InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else { } else {
SmallVector<Value, 1> results_shape; SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
@ -259,8 +257,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
const auto& original_results = op.getResults(); const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) { for (auto result : original_results) {
buffer_args.push_back( buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
} }
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args, auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op.getAttrs()); op.getAttrs());
@ -432,23 +429,13 @@ struct HloLegalizeToLhlo
isMemRefType); isMemRefType);
}); });
auto module = getOperation(); populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
BufferAssignmentPlacer bufferAssignment(func);
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
// FIXME: we likely need to call converter.setResultConversionKind() to
// respect results_escape_function.
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>( mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
&context, &bufferAssignment, &converter, &patterns); &patterns);
return applyPartialConversion(func, target, patterns); if (failed(applyPartialConversion(getOperation(), target, patterns)))
});
if (result.wasInterrupted()) {
signalPassFailure(); signalPassFailure();
} }
}
private: private:
Option<bool> results_escape_function{ Option<bool> results_escape_function{
@ -460,8 +447,7 @@ struct HloLegalizeToLhlo
} // namespace } // namespace
void populateHLOToLHLOConversionPattern( void populateHLOToLHLOConversionPattern(
MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, MLIRContext* context, BufferAssignmentTypeConverter* converter,
BufferAssignmentTypeConverter* converter,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
// clang-format off // clang-format off
patterns->insert< patterns->insert<
@ -506,7 +492,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloReturnOpConverter, HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter, HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter); >(context, converter);
// clang-format on // clang-format on
} }