PR #49970: [MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49970

1. Add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim.

2. Add a flag `convert-to-lmhlo-only` to separate the following two cases:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts without applying any optimizations (e.g. eliding
     buffer copies). Buffer optimization is not easy in the dynamic
     shape world, especially when control flow is involved, so we
     leave it to a dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) directly to the memref dialect and lowers the
     others to their lmhlo counterparts (see the sketch after this list).
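
To make the contrast concrete, here is a minimal sketch of how a single
dynamic reshape bufferizes under each mode. The input function is
hypothetical, and the ops shown for the default mode are an assumption
based on the pass description, not verified output:

// Illustrative input (hypothetical, not taken from this PR):
func @example(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
  %0 = "mhlo.dynamic_reshape"(%arg, %shape)
      : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}

// convert-to-lmhlo-only=true: the op maps 1:1 to its lmhlo counterpart and
// writes into a freshly allocated buffer; no copy elision is attempted:
//   %out = memref.alloc(%d0, %d1) : memref<?x?xf32>
//   "lmhlo.dynamic_reshape"(%arg, %shape, %out)
//       : (memref<?xf32>, memref<2xindex>, memref<?x?xf32>) -> ()
//
// Default mode: a metadata-only reshape may instead become a memref-level
// shape cast that reuses the input buffer, e.g. something like:
//   %out = memref.reshape %arg(%shape)
//       : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>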
Copybara import of the project:

--
562bd65a368f6194405c4ae6900e3b4388a5ec03 by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] bufferize DynamicReshape and DynamicBroadcastInDim

1. Add hlo-to-lhlo support for DynamicReshape and DynamicBroadcastInDim.

2. Add a flag `convert-to-lmhlo-only` to separate the following two cases:
   - hlo-to-lhlo only. Simply lowers all mhlo ops to their lmhlo
     counterparts without applying any optimizations (e.g. eliding
     buffer copies). Buffer optimization is not easy in the dynamic
     shape world, especially when control flow is involved, so we
     leave it to a dedicated pass.

   - hlo-to-lhlo-or-memref-directly. Lowers some metadata-only mhlo
     ops (e.g. reshape) directly to the memref dialect and lowers the
     others to their lmhlo counterparts.

PiperOrigin-RevId: 377603395
Wenyi Zhao, 2021-06-04 15:35:08 -07:00; committed by TensorFlow MLIR Team
parent 8b3a75ea25
commit ade873a5e0
10 changed files with 132 additions and 43 deletions

BUILD

@@ -880,6 +880,7 @@ cc_library(
         ":hlo",
         ":lhlo",
         ":map_hlo_to_lhlo_op",
+        ":pass_details",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MemRefDialect",


@@ -35,6 +35,11 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
   let verifier = [{ return Verify(*this); }];
 }
 
+class HLO_ShapedInterfaceOp<string mnemonic, list<OpTrait> traits> :
+    HLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+        ["reifyReturnTypeShapes"]>]> {
+}
+
 def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">;
 def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">;
 def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;

@@ -1277,9 +1282,8 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
-    NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-      ["inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> {
+def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp<
+    "dynamic_broadcast_in_dim", [NoSideEffect]> {
   let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
   let description = [{
     This is a generalization of the BroadcastInDimOp which accepts its output

@@ -1671,7 +1675,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
+def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> {
   let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
   let description = [{
     Reshapes `operand` to `output_shape`.

@@ -2201,9 +2205,9 @@ def HLO_RealDynamicSliceOp: HLO_Op<
     AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
   let summary = "Real Dynamic Slice operator";
   let description = [{
     The dynamic shape version of SliceOp. Extracts a sub-array from the input
     array according to start_indices, limit_indices and strides. Expect
     start_indices/limit_indices/strides to be statically shaped and matching
     the rank of the input.
   }];
   let arguments = (ins

@@ -2221,11 +2225,11 @@ def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
     AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
   let summary = "Dynamic Pad operator";
   let description = [{
     The dynamic shape version of PadOp. Pads the edges of `operand` with the
     `padding_value` and according to the passed configuration. Expect
     edge_padding_low/edge_padding_high/interior_padding to be statically shaped
     and matching the rank of the input.
     See
     https://www.tensorflow.org/xla/operation_semantics#pad
   }];
   let arguments = (ins

@@ -2247,10 +2251,10 @@ def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
 def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
   string summary = "Dynamic Gather operator";
   string description = [{
     The dynamic shape version of GatherOp. Stitches together several slices of an input
     array. slice_sizes is a compile-time variable.
   }];
   let arguments = (ins
     HLO_Tensor:$operand,
     HLO_IntTensor:$start_indices,

@@ -2259,7 +2263,7 @@ def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
     DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
   );
 
   let results = (outs HLO_Tensor);
 
   let hasCustomHLOConverter = 1;
 }


@@ -1391,7 +1391,7 @@ def LHLO_RealDynamicSliceOp: LHLO_Op<
     [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
   let summary = "LHLO Real Dynamic Slice operator";
   let description = [{
     The dynamic shape version of DynamicSliceOp. Extracts a sub-array from the
     input array according to dynamic start_indices, limit_indices and strides.
   }];
   let arguments = (ins

@@ -1407,10 +1407,10 @@ def LHLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
     []> {
   let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
   let description = [{
     The dynamic shape version of BroadcastInDimOp. This is a generalization of the
     BroadcastInDimOp which accepts its output dimensions as an argument. It should
     eventually supercede the statically shaped original, but is being phased as a
     separate op in order to support compatibility with lowerings and translations that
     precede dynamic shapes.
   }];
   let arguments = (ins

@@ -1441,7 +1441,7 @@ def LHLO_DotGeneralOp: LHLO_Op<"dot_general", []> {
 def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
   string summary = "LHLO Dynamic Gather operator";
   string description = [{
     The dynamic shape version of GatherOp. Stitches together several slices of an input
     array. slice_sizes is not a const.
   }];
   let arguments = (ins

@@ -1454,7 +1454,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
 }
 
 def LHLO_DynamicPadOp: LHLO_Op<
   "dynamic_pad",
   [AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
   let summary = "LHLO Dynamic Pad operator";
   let description = [{

@@ -1492,7 +1492,7 @@ def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
 def LHLO_DynamicBitcastOp: LHLO_Op<"dynamic_bitcast", []> {
   let summary = "LHLO Dynamic Bitcast operator";
   let description = [{
     The dynamic shape version of BitcastOp. This op changes the shape of the
     input in the way that the physical arrangement of elements are unchanged.
     However, the op needs layout information to make sense of "physical

@@ -1509,7 +1509,7 @@ def LHLO_DynamicBitcastOp: LHLO_Op<"dynamic_bitcast", []> {
 def LHLO_DynamicIotaOp : LHLO_Op<"dynamic_iota", []> {
   let summary = "Create linear increasing values from 0 to length -1.";
   let description = [{
     The dynamic shape version of IotaOp. Produces an output of the specified shape,
     with an incremental set of values along the specified dimension starting at 0.
     See
     https://www.tensorflow.org/xla/operation_semantics#iota

@@ -1535,7 +1535,7 @@ def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> {
   }];
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_IntBuffer, "", [MemRead]>:$shape,
+    Arg<LHLO_DimensionBuffer, "", [MemRead]>:$shape,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output
   );
 }


@@ -54,6 +54,8 @@ MAP_HLO_TO_LHLO(CosOp);
 MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
+MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicReshapeOp);
 MAP_HLO_TO_LHLO(ExpOp);
 MAP_HLO_TO_LHLO(Expm1Op);
 MAP_HLO_TO_LHLO(FloorOp);
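
With these two mappings registered, the generic HloToLhloOpConverter handles
both new ops. Consistent with the CHECK lines of the test added at the bottom
of this change, the broadcast case comes out roughly as:

//   "lmhlo.dynamic_broadcast_in_dim"(%operand, %shape, %out)
//       {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
//       : (memref<?x?xf32>, memref<3xindex>, memref<?x?x?xf32>) -> ()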


@@ -29,6 +29,12 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
 def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
   let summary = "Legalize from HLO dialect to LHLO dialect.";
   let constructor = "createLegalizeToLhloPass()";
+  let options = [
+    Option<"convert_to_lmhlo_only_", "convert-to-lmhlo-only", "bool",
+           /*default=*/"false",
+           "If enabled, simply lower all mhlo ops to their lmhlo counterparts; "
+           "otherwise, some metadata-only ops (e.g. reshape) may be lowered "
+           "to the memref dialect to elide some buffer copies.">,
+  ];
 }
 
 def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
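
Since convert-to-lmhlo-only is a regular pass option, it can be exercised from
the command line the same way the new regression test does; the input file
name below is a placeholder:

// Lower all mhlo ops to their lmhlo counterparts:
//   mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true input.mlir
// Default behavior (metadata-only ops may be lowered directly to memref):
//   mlir-hlo-opt -hlo-legalize-to-lhlo input.mlir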


@@ -50,7 +50,8 @@ std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
 
 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 /// buffers if necessary.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool convert_to_lmhlo_only = false);
 
 // Lowers from HLO dialect to Linalg dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();


@@ -45,15 +45,24 @@ void PopulateGatherToTorchIndexSelectPatterns(
 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                                MLIRContext *ctx);
 
-// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
-void populateDynamicHLOToLHLOConversionPattern(
+// Collection of rewrite patterns for lowering dynamic HLOs to the LHLO or
+// memref dialect.
+void populateDynamicHLOToLHLOOrMemRefConversionPattern(
     MLIRContext *context, BufferizeTypeConverter *converter,
     OwningRewritePatternList *patterns, bool insert_copy = true);
 
+// Collection of rewrite patterns that simply lower all mhlo ops to their
+// lmhlo counterparts without applying any optimization (e.g. eliding buffer
+// copies).
+void populateDynamicHLOToLHLOOnlyConversionPattern(
+    MLIRContext *context, BufferizeTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 void populateHLOToLHLOConversionPattern(MLIRContext *context,
                                         BufferizeTypeConverter *converter,
-                                        OwningRewritePatternList *patterns);
+                                        OwningRewritePatternList *patterns,
+                                        bool convert_to_lmhlo_only = false);
 
 // Collection of rewrite patterns for lowering of HLO to Linalg dialect.
 void populateHLOToLinalgConversionPattern(MLIRContext *context,


@@ -1021,12 +1021,6 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
       context);
 }
 
-LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
-    MLIRContext*, llvm::Optional<mlir::Location>, ValueRange, DictionaryAttr,
-    RegionRange, llvm::SmallVectorImpl<mlir::ShapedTypeComponents>&) {
-  return failure();
-}
-
 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
     OpBuilder&, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {

@@ -1411,6 +1405,14 @@ static LogicalResult Verify(DynamicReshapeOp op) {
   return success();
 }
 
+LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
+    OpBuilder&, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicReshapeOp::Adaptor adaptor(operands);
+  reifiedReturnShapes.push_back(adaptor.output_shape());
+  return success();
+}
+
 namespace {
 class DynamicReshapeOpNotActuallyDynamic
     : public OpRewritePattern<DynamicReshapeOp> {
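
The new DynamicReshapeOp::reifyReturnTypeShapes simply forwards the
output_shape operand, which is what lets bufferization size the result
allocation from that operand. A minimal sketch of the IR this enables
(value names are illustrative, following the added test):

//   %c0 = constant 0 : index
//   %c1 = constant 1 : index
//   %d0 = memref.load %shape[%c0] : memref<2xindex>
//   %d1 = memref.load %shape[%c1] : memref<2xindex>
//   %out = memref.alloc(%d0, %d1) : memref<?x?xf32>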


@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

@@ -576,8 +577,14 @@ class HloToLhloTensorStoreOpLegacyConverter
 //   "lmhlo.terminator"() : () -> ()
 // }
 
-struct HloLegalizeToLhlo
-    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
+struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
+  using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
+  explicit HloLegalizeToLhlo(bool convert_to_lmhlo_only)
+      : HloLegalizeToLhloPassBase<
+            HloLegalizeToLhlo>::HloLegalizeToLhloPassBase() {
+    this->convert_to_lmhlo_only_ = convert_to_lmhlo_only;
+  }
+
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect,
                     shape::ShapeDialect>();

@@ -623,7 +630,8 @@ struct HloLegalizeToLhlo
                               isMemRefType);
     });
 
-    populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
+    populateHLOToLHLOConversionPattern(&context, &converter, &patterns,
+                                       convert_to_lmhlo_only_);
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);

@@ -643,7 +651,9 @@ struct HloLegalizeToLhlo
 };
 }  // namespace
 
-void populateDynamicHLOToLHLOConversionPattern(
+// Lowers some metadata-only mhlo ops (e.g. reshape) to the memref dialect
+// directly and lowers the others to their lmhlo counterparts.
+void populateDynamicHLOToLHLOOrMemRefConversionPattern(
     MLIRContext* context, BufferizeTypeConverter* converter,
     OwningRewritePatternList* patterns, bool insert_copy) {
   patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(

@@ -652,10 +662,28 @@ void populateDynamicHLOToLHLOOrMemRefConversionPattern(
                    HloToLhloReshapeUnrankedConverter>(*converter, context);
 }
 
+// Simply lowers all mhlo ops to their lmhlo counterparts without applying
+// any optimization (e.g. eliding buffer copies).
+void populateDynamicHLOToLHLOOnlyConversionPattern(
+    MLIRContext* context, BufferizeTypeConverter* converter,
+    OwningRewritePatternList* patterns) {
+  // clang-format off
+  patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>
+  >(*converter, context);
+  // clang-format on
+}
+
 void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         BufferizeTypeConverter* converter,
-                                        OwningRewritePatternList* patterns) {
-  populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
+                                        OwningRewritePatternList* patterns,
+                                        bool convert_to_lmhlo_only) {
+  if (convert_to_lmhlo_only) {
+    populateDynamicHLOToLHLOOnlyConversionPattern(context, converter, patterns);
+  } else {
+    populateDynamicHLOToLHLOOrMemRefConversionPattern(context, converter,
+                                                      patterns);
+  }
+
   // clang-format off
   patterns->insert<
       HloToLhloCustomCallOpConverter,

@@ -712,8 +740,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
   // clang-format on
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
-  return std::make_unique<HloLegalizeToLhlo>();
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool convert_to_lmhlo_only) {
+  return std::make_unique<HloLegalizeToLhlo>(convert_to_lmhlo_only);
 }
 
 }  // namespace mhlo


@@ -0,0 +1,35 @@
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true \
+// RUN: -canonicalize -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
+
+// CHECK-LABEL: func @dynamic_reshape
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
+func @dynamic_reshape(%lhs: tensor<?x?xf32>, %rhs: tensor<3xindex>) -> tensor<?x?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
+  // CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
+  // CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
+  // CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
+  // CHECK: "lmhlo.dynamic_reshape"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
+  // CHECK: return %[[OUTPUT]]
+  %result = "mhlo.dynamic_reshape"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_broadcast_in_dim
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
+func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex>) -> tensor<?x?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
+  // CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
+  // CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
+  // CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
+  // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
+  // CHECK: return %[[OUTPUT]]
+  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
+    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}