PR #49970: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49970

1. Add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim

2. Add a flag `convert-to-lmhlo-only` to separate the following two cases:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts without applying any optimization (e.g. eliding
     buffer copies). Buffer optimization is not easy in the dynamic
     shape world, especially when control flow is involved, so we
     leave it to another dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) directly to the memref dialect and lowers the
     others to their lmhlo counterparts, as sketched below.
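
   To illustrate, a rough sketch of how `mhlo.dynamic_reshape` (operand
   memref<?x?xf32>, target shape memref<3xindex>) bufferizes in the two
   modes. The lmhlo-only form follows the test added in this PR; the
   memref form shows the kind of output the existing direct lowering
   produces, so exact ops and constants are illustrative only
   (%c0..%c2 are index constants):

     // convert-to-lmhlo-only=true: allocate the result buffer and emit a
     // plain lmhlo op; no copy elision is attempted here.
     %d0 = memref.load %shape[%c0] : memref<3xindex>
     %d1 = memref.load %shape[%c1] : memref<3xindex>
     %d2 = memref.load %shape[%c2] : memref<3xindex>
     %out = memref.alloc(%d0, %d1, %d2) : memref<?x?x?xf32>
     "lmhlo.dynamic_reshape"(%arg, %shape, %out)
         : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()

     // Default mode: the metadata-only reshape can instead become a
     // memref-dialect view, avoiding the extra buffer.
     %view = memref.reshape %arg(%shape)
         : (memref<?x?xf32>, memref<3xindex>) -> memref<?x?x?xf32>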
Copybara import of the project:

--
562bd65a368f6194405c4ae6900e3b4388a5ec03 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

1. Add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim

2. Add a flag `convert-to-lmhlo-only` to separate the following two cases:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts without applying any optimization (e.g. eliding
     buffer copies). Buffer optimization is not easy in the dynamic
     shape world, especially when control flow is involved, so we
     leave it to another dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) directly to the memref dialect and lowers the
     others to their lmhlo counterparts.

PiperOrigin-RevId: 377603395
Wenyi Zhao 2021-06-04 15:35:08 -07:00 committed by TensorFlow MLIR Team
parent 8b3a75ea25
commit ade873a5e0
10 changed files with 132 additions and 43 deletions

BUILD

@@ -880,6 +880,7 @@ cc_library(
":hlo",
":lhlo",
":map_hlo_to_lhlo_op",
":pass_details",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",


@@ -35,6 +35,11 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
let verifier = [{ return Verify(*this); }];
}
class HLO_ShapedInterfaceOp<string mnemonic, list<OpTrait> traits> :
HLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapes"]>]> {
}
def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">;
def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">;
def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
@@ -1277,9 +1282,8 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
let hasCustomHLOConverter = 1;
}
def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp<
"dynamic_broadcast_in_dim", [NoSideEffect]> {
let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
let description = [{
This is a generalization of the BroadcastInDimOp which accepts its output
@@ -1671,7 +1675,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
let hasCustomHLOConverter = 1;
}
def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> {
let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
let description = [{
Reshapes `operand` to `output_shape`.


@@ -1535,7 +1535,7 @@ def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> {
}];
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
Arg<LHLO_DimensionBuffer, "", [MemRead]>:$shape,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}


@@ -54,6 +54,8 @@ MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
MAP_HLO_TO_LHLO(DynamicReshapeOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(Expm1Op);
MAP_HLO_TO_LHLO(FloorOp);


@@ -29,6 +29,12 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
let summary = "Legalize from HLO dialect to LHLO dialect.";
let constructor = "createLegalizeToLhloPass()";
let options = [
Option<"convert_to_lmhlo_only_", "convert-to-lmhlo-only", "bool",
/*default=*/"false", "If enabled, simply lower all mhlo ops to their lmhlo counterparts, "
"otherwise, some metadata-only ops (e.g. reshape) may be lowerred "
"to memref dialect to elide some buffer copy.">,
];
}
def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
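
For reference, the new option is driven from mlir-hlo-opt the same way the
test added in this PR drives it (a usage sketch; input.mlir is a placeholder
file name):

  mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true input.mlir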


@@ -50,7 +50,8 @@ std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool convert_to_lmhlo_only = false);
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
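
A minimal sketch of requesting the lmhlo-only behavior from a C++ pipeline;
the helper name addDiscBufferizePasses and the PassManager wiring are
illustrative assumptions, only the createLegalizeToLhloPass overload comes
from this change:

#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical helper: bufferize to lmhlo without copy elision, leaving
// buffer optimization to a later dedicated pass.
void addDiscBufferizePasses(mlir::PassManager &pm) {
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass(/*convert_to_lmhlo_only=*/true));
}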


@@ -45,15 +45,24 @@ void PopulateGatherToTorchIndexSelectPatterns(
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
void populateDynamicHLOToLHLOConversionPattern(
// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO or memref
// dialect.
void populateDynamicHLOToLHLOOrMemRefConversionPattern(
MLIRContext *context, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns, bool insert_copy = true);
// Collection of rewrite patterns for simply lowering all mhlo ops to their
// lmhlo counterparts without applying any optimization (e.g. eliding any
// buffer copy).
void populateDynamicHLOToLHLOOnlyConversionPattern(
MLIRContext *context, BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(MLIRContext *context,
BufferizeTypeConverter *converter,
OwningRewritePatternList *patterns);
OwningRewritePatternList *patterns,
bool convert_to_lmhlo_only = false);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,


@@ -1021,12 +1021,6 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
context);
}
LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
MLIRContext*, llvm::Optional<mlir::Location>, ValueRange, DictionaryAttr,
RegionRange, llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
return failure();
}
LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
OpBuilder&, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
@@ -1411,6 +1405,14 @@ static LogicalResult Verify(DynamicReshapeOp op) {
return success();
}
LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
OpBuilder&, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
DynamicReshapeOp::Adaptor adaptor(operands);
reifiedReturnShapes.push_back(adaptor.output_shape());
return success();
}
namespace {
class DynamicReshapeOpNotActuallyDynamic
: public OpRewritePattern<DynamicReshapeOp> {


@@ -17,6 +17,7 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@@ -576,8 +577,14 @@ class HloToLhloTensorStoreOpLegacyConverter
// "lmhlo.terminator"() : () -> ()
// }
struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
explicit HloLegalizeToLhlo(bool convert_to_lmhlo_only)
: HloLegalizeToLhloPassBase<
HloLegalizeToLhlo>::HloLegalizeToLhloPassBase() {
this->convert_to_lmhlo_only_ = convert_to_lmhlo_only;
}
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
shape::ShapeDialect>();
@@ -623,7 +630,8 @@ struct HloLegalizeToLhlo
isMemRefType);
});
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns,
convert_to_lmhlo_only_);
populateFuncOpTypeConversionPattern(patterns, converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
@@ -643,7 +651,9 @@ struct HloLegalizeToLhlo
};
} // namespace
void populateDynamicHLOToLHLOConversionPattern(
// Lowers some metadata-only mhlo ops (e.g. reshape) to memref dialect
// directly and lowers others to their lmhlo counterparts.
void populateDynamicHLOToLHLOOrMemRefConversionPattern(
MLIRContext* context, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns, bool insert_copy) {
patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
@@ -652,10 +662,28 @@ void populateDynamicHLOToLHLOConversionPattern(
HloToLhloReshapeUnrankedConverter>(*converter, context);
}
// Simply lowers all mhlo ops to their lmhlo counterparts without applying
// any optimization (e.g. eliding any buffer copy).
void populateDynamicHLOToLHLOOnlyConversionPattern(
MLIRContext* context, BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
HloToLhloOpConverter<mhlo::DynamicReshapeOp>
>(*converter, context);
// clang-format on
}
void populateHLOToLHLOConversionPattern(MLIRContext* context,
BufferizeTypeConverter* converter,
OwningRewritePatternList* patterns) {
populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
OwningRewritePatternList* patterns,
bool convert_to_lmhlo_only) {
if (convert_to_lmhlo_only) {
populateDynamicHLOToLHLOOnlyConversionPattern(context, converter, patterns);
} else {
populateDynamicHLOToLHLOOrMemRefConversionPattern(context, converter,
patterns);
}
// clang-format off
patterns->insert<
HloToLhloCustomCallOpConverter,
@@ -712,8 +740,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
// clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return std::make_unique<HloLegalizeToLhlo>();
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool convert_to_lmhlo_only) {
return std::make_unique<HloLegalizeToLhlo>(convert_to_lmhlo_only);
}
} // namespace mhlo


@@ -0,0 +1,35 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true \
// RUN: -canonicalize -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
// CHECK-LABEL: func @dynamic_reshape
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
func @dynamic_reshape(%lhs: tensor<?x?xf32>, %rhs: tensor<3xindex>) -> tensor<?x?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
// CHECK: "lmhlo.dynamic_reshape"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
// CHECK: return %[[OUTPUT]]
%result = "mhlo.dynamic_reshape"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}
// -----
// CHECK-LABEL: func @dynamic_broadcast_in_dim
// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex>) -> tensor<?x?x?xf32> {
// CHECK-NOT: tensor_load
// CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
// CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
// CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
// CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
// CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
// CHECK: return %[[OUTPUT]]
%result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
return %result : tensor<?x?x?xf32>
}