diff --git a/BUILD b/BUILD
index 8af4602..a258d98 100644
--- a/BUILD
+++ b/BUILD
@@ -880,6 +880,7 @@ cc_library(
         ":hlo",
         ":lhlo",
         ":map_hlo_to_lhlo_op",
+        ":pass_details",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:MemRefDialect",
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
index 0f8d721..5d872d3 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td
@@ -35,6 +35,11 @@ class HLO_Op<string mnemonic, list<OpTrait> traits> :
   let verifier = [{ return Verify(*this); }];
 }
 
+class HLO_ShapedInterfaceOp<string mnemonic, list<OpTrait> traits> :
+    HLO_Op<mnemonic, traits # [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+        ["reifyReturnTypeShapes"]>]> {
+}
+
 def HLO_LOOP_FUSION : StrEnumAttrCase<"kLoop">;
 def HLO_INPUT_FUSION : StrEnumAttrCase<"kInput">;
 def HLO_OUTPUT_FUSION : StrEnumAttrCase<"kOutput">;
@@ -1277,9 +1282,8 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [
-    NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
+def HLO_DynamicBroadcastInDimOp : HLO_ShapedInterfaceOp<
+    "dynamic_broadcast_in_dim", [NoSideEffect]> {
   let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
   let description = [{
     This is a generalization of the BroadcastInDimOp which accepts its output
@@ -1671,7 +1675,7 @@ def HLO_ReshapeOp: HLO_Op<"reshape",
   let hasCustomHLOConverter = 1;
 }
 
-def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
+def HLO_DynamicReshapeOp: HLO_ShapedInterfaceOp<"dynamic_reshape", [NoSideEffect]> {
   let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
   let description = [{
     Reshapes `operand` to `output_shape`.
@@ -2201,9 +2205,9 @@ def HLO_RealDynamicSliceOp: HLO_Op<
   AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
   let summary = "Real Dynamic Slice operator";
   let description = [{
-    The dynamic shape version of SliceOp. Extracts a sub-array from the input 
-    array according to start_indices, limit_indices and strides. Expect 
-    start_indices/limit_indices/strides to be statically shaped and matching 
+    The dynamic shape version of SliceOp. Extracts a sub-array from the input
+    array according to start_indices, limit_indices and strides. Expects
+    start_indices/limit_indices/strides to be statically shaped, matching
     the rank of the input.
   }];
   let arguments = (ins
@@ -2221,11 +2225,11 @@ def HLO_DynamicPadOp: HLO_Op<"dynamic_pad",
   AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
   let summary = "Dynamic Pad operator";
   let description = [{
-    The dynamic shape version of PadOp. Pads the edges of `operand` with the 
-    `padding_value` and according to the passed configuration. Expect 
-    edge_padding_low/edge_padding_high/interior_padding to be statically shaped 
-    and matching the rank of the input. 
-    See 
+    The dynamic shape version of PadOp. Pads the edges of `operand` with
+    `padding_value` according to the passed configuration. Expects
+    edge_padding_low/edge_padding_high/interior_padding to be statically shaped,
+    matching the rank of the input.
+    See
     https://www.tensorflow.org/xla/operation_semantics#pad
   }];
   let arguments = (ins
@@ -2247,10 +2251,10 @@ def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
   string summary = "Dynamic Gather operator";
   string description = [{
-    The dynamic shape version of GatherOp. Stitches together several slices of an input 
+    The dynamic shape version of GatherOp. Stitches together several slices of an input
     array. slice_sizes is a compile-time variable.
   }];
-  
+
   let arguments = (ins
     HLO_Tensor:$operand,
     HLO_IntTensor:$start_indices,
@@ -2259,7 +2263,7 @@ def HLO_DynamicGatherOp: HLO_Op<"dynamic_gather", [NoSideEffect]> {
     DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
   );
   let results = (outs HLO_Tensor);
-  
+
   let hasCustomHLOConverter = 1;
 }
 
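The new `HLO_ShapedInterfaceOp` base attaches `InferShapedTypeOpInterface` with a declared `reifyReturnTypeShapes` method. That hook is what lets a lowering compute an op's result shape as an SSA value before any buffer exists. A minimal sketch of a generic caller, assuming the interface signature used in `hlo_ops.cc` later in this patch; the helper name is hypothetical and not part of the patch:

```cpp
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

// Hypothetical helper: returns an SSA value holding the shape of `op`'s
// result, or null if the op does not implement InferShapedTypeOpInterface.
static mlir::Value materializeResultShape(mlir::OpBuilder& b,
                                          mlir::Operation* op) {
  auto shaped_op = llvm::dyn_cast<mlir::InferShapedTypeOpInterface>(op);
  if (!shaped_op) return nullptr;
  llvm::SmallVector<mlir::Value, 1> shapes;
  if (mlir::failed(
          shaped_op.reifyReturnTypeShapes(b, op->getOperands(), shapes)))
    return nullptr;
  // For mhlo.dynamic_reshape and mhlo.dynamic_broadcast_in_dim this is
  // simply the shape operand the op already carries.
  return shapes.front();
}
```

This is presumably why both dynamic ops gain the interface here: the generic hlo-to-lhlo converter needs a reified shape to size the output allocation.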
diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
index c8ee0ad..9877774 100644
--- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
+++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td
@@ -1391,7 +1391,7 @@ def LHLO_RealDynamicSliceOp: LHLO_Op<
   [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
   let summary = "LHLO Real Dynamic Slice operator";
   let description = [{
-    The dynamic shape version of DynamicSliceOp. Extracts a sub-array from the 
+    The dynamic shape version of DynamicSliceOp. Extracts a sub-array from the
     input array according to dynamic start_indices, limit_indices and strides.
   }];
   let arguments = (ins
@@ -1407,10 +1407,10 @@ def LHLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim", []> {
   let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
   let description = [{
-    The dynamic shape version of BroadcastInDimOp. This is a generalization of the 
-    BroadcastInDimOp which accepts its output dimensions as an argument. It should 
-    eventually supercede the statically shaped original, but is being phased as a 
-    separate op in order to support compatibility with lowerings and translations that 
+    The dynamic shape version of BroadcastInDimOp. This is a generalization of the
+    BroadcastInDimOp which accepts its output dimensions as an argument. It should
+    eventually supersede the statically shaped original, but is being phased in as a
+    separate op in order to support compatibility with lowerings and translations that
     precede dynamic shapes.
   }];
   let arguments = (ins
@@ -1441,7 +1441,7 @@ def LHLO_DotGeneralOp: LHLO_Op<"dot_general", []> {
 def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
   string summary = "LHLO Dynamic Gather operator";
   string description = [{
-    The dynamic shape version of GatherOp. Stitches together several slices of an input 
+    The dynamic shape version of GatherOp. Stitches together several slices of an input
     array. slice_sizes is not a const.
   }];
   let arguments = (ins
@@ -1454,7 +1454,7 @@ def LHLO_DynamicGatherOp: LHLO_Op<"dynamic_gather", []> {
 }
 
 def LHLO_DynamicPadOp: LHLO_Op<
-  "dynamic_pad", 
+  "dynamic_pad",
   [AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
   let summary = "LHLO Dynamic Pad operator";
   let description = [{
@@ -1492,7 +1492,7 @@ def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
 def LHLO_DynamicBitcastOp: LHLO_Op<"dynamic_bitcast", []> {
   let summary = "LHLO Dynamic Bitcast operator";
   let description = [{
-    The dynamic shape version of BitcastOp. This op changes the shape of the 
+    The dynamic shape version of BitcastOp. This op changes the shape of the
     input in the way that the physical arrangement of elements are unchanged.
 
     However, the op needs layout information to make sense of "physical
@@ -1509,7 +1509,7 @@ def LHLO_DynamicBitcastOp: LHLO_Op<"dynamic_bitcast", []> {
 def LHLO_DynamicIotaOp : LHLO_Op<"dynamic_iota", []> {
   let summary = "Create linear increasing values from 0 to length -1.";
   let description = [{
-    The dynamic shape version of IotaOp. Produces an output of the specified shape, 
+    The dynamic shape version of IotaOp. Produces an output of the specified shape,
     with an incremental set of values along the specified dimension starting at 0.
     See https://www.tensorflow.org/xla/operation_semantics#iota
@@ -1535,7 +1535,7 @@ def LHLO_DynamicReshapeOp: LHLO_Op<"dynamic_reshape", []> {
   }];
   let arguments = (ins
     Arg<LHLO_Buffer, "", [MemRead]>:$operand,
-    Arg<LHLO_Buffer, "", [MemRead]>:$shape, 
+    Arg<LHLO_Buffer, "", [MemRead]>:$shape,
     Arg<LHLO_Buffer, "", [MemWrite]>:$output
   );
 }
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
index 252b15d..45256c2 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h
@@ -54,6 +54,8 @@ MAP_HLO_TO_LHLO(CosOp);
 MAP_HLO_TO_LHLO(CustomCallOp);
 MAP_HLO_TO_LHLO(DivOp);
 MAP_HLO_TO_LHLO(DotOp);
+MAP_HLO_TO_LHLO(DynamicBroadcastInDimOp);
+MAP_HLO_TO_LHLO(DynamicReshapeOp);
 MAP_HLO_TO_LHLO(ExpOp);
 MAP_HLO_TO_LHLO(Expm1Op);
 MAP_HLO_TO_LHLO(FloorOp);
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index 11804f8..edfe1c0 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -29,6 +29,12 @@ def ChloLegalizeToHloPass : FunctionPass<"chlo-legalize-to-hlo"> {
 def HloLegalizeToLhloPass : Pass<"hlo-legalize-to-lhlo", "ModuleOp"> {
   let summary = "Legalize from HLO dialect to LHLO dialect.";
   let constructor = "createLegalizeToLhloPass()";
+  let options = [
+    Option<"convert_to_lmhlo_only_", "convert-to-lmhlo-only", "bool",
+           /*default=*/"false", "If enabled, simply lower all mhlo ops to their lmhlo counterparts; "
+           "otherwise, some metadata-only ops (e.g. reshape) may be lowered "
+           "to the memref dialect to elide some buffer copies.">,
+  ];
 }
 
 def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index 6f662d1..58c3e22 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -50,7 +50,8 @@ std::unique_ptr<FunctionPass> createChloLegalizeToHloPass(
 
 /// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
 /// buffers if necessary.
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool convert_to_lmhlo_only = false);
 
 // Lowers from HLO dialect to Linalg dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
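With the declarations above, the new option can be driven either from the command line (`-hlo-legalize-to-lhlo=convert-to-lmhlo-only=true`, exactly as the new test at the end of this patch does) or programmatically. A sketch of the programmatic route; the pipeline function is illustrative and not part of the patch:

```cpp
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Pass/PassManager.h"

// Illustrative only: wires the new flag through createLegalizeToLhloPass().
void addHloToLhloLowering(mlir::PassManager& pm, bool lmhlo_only) {
  // false (the default) keeps the existing behavior: metadata-only ops such
  // as mhlo.dynamic_reshape may be lowered to memref ops so buffer copies
  // can be elided. true requests a plain 1:1 mhlo-to-lmhlo mapping.
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
      /*convert_to_lmhlo_only=*/lmhlo_only));
}
```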
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index 6b7c099..49ed725 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -45,15 +45,24 @@ void PopulateGatherToTorchIndexSelectPatterns(
 void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
                                MLIRContext *ctx);
 
-// Collection of rewrite patterns for lowering of dynamic HLOs to LHLO dialect.
-void populateDynamicHLOToLHLOConversionPattern(
+// Collection of rewrite patterns for lowering of dynamic HLOs to the LHLO or
+// memref dialect.
+void populateDynamicHLOToLHLOOrMemRefConversionPattern(
     MLIRContext *context, BufferizeTypeConverter *converter,
     OwningRewritePatternList *patterns, bool insert_copy = true);
 
+// Collection of rewrite patterns that simply lower all mhlo ops to their
+// lmhlo counterparts without applying any optimization (e.g. without
+// eliding any buffer copy).
+void populateDynamicHLOToLHLOOnlyConversionPattern(
+    MLIRContext *context, BufferizeTypeConverter *converter,
+    OwningRewritePatternList *patterns);
+
 // Collection of rewrite patterns for lowering of HLO to LHLO dialect.
 void populateHLOToLHLOConversionPattern(MLIRContext *context,
                                         BufferizeTypeConverter *converter,
-                                        OwningRewritePatternList *patterns);
+                                        OwningRewritePatternList *patterns,
+                                        bool convert_to_lmhlo_only = false);
diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index d363844..771827e 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -1021,12 +1021,6 @@ void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
       context);
 }
 
-LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents(
-    MLIRContext*, llvm::Optional<Location>, ValueRange, DictionaryAttr,
-    RegionRange, llvm::SmallVectorImpl<ShapedTypeComponents>&) {
-  return failure();
-}
-
 LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes(
     OpBuilder&, ValueRange operands,
     SmallVectorImpl<Value>& reifiedReturnShapes) {
@@ -1411,6 +1405,14 @@ static LogicalResult Verify(DynamicReshapeOp op) {
   return success();
 }
 
+LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
+    OpBuilder&, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  DynamicReshapeOp::Adaptor adaptor(operands);
+  reifiedReturnShapes.push_back(adaptor.output_shape());
+  return success();
+}
+
 namespace {
 class DynamicReshapeOpNotActuallyDynamic
     : public OpRewritePattern<DynamicReshapeOp> {
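Both `reifyReturnTypeShapes` implementations above just forward the op's shape operand. To see why that is enough for bufferization, consider how a converter can turn the reified shape into an allocation. This is a simplified illustration under MLIR APIs of this patch's vintage, with a hypothetical helper name; the actual logic lives in `HloToLhloOpConverter` in `hlo_legalize_to_lhlo.cc`:

```cpp
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

// Illustrative helper: allocates the result buffer for an op whose reified
// shape lives in `shape_buffer` (e.g. the memref<3xindex> operand of
// lmhlo.dynamic_reshape).
mlir::Value allocateResultBuffer(mlir::OpBuilder& b, mlir::Location loc,
                                 mlir::Value shape_buffer,
                                 mlir::MemRefType result_type) {
  llvm::SmallVector<mlir::Value, 4> dynamic_sizes;
  for (auto dim : llvm::enumerate(result_type.getShape())) {
    if (dim.value() != mlir::ShapedType::kDynamicSize) continue;
    // One memref.load per dynamic dimension; these become the loads at
    // %c0..%c2 that the new test file checks for.
    mlir::Value idx = b.create<mlir::ConstantIndexOp>(loc, dim.index());
    dynamic_sizes.push_back(
        b.create<mlir::memref::LoadOp>(loc, shape_buffer, idx));
  }
  // memref.alloc consumes one index value per dynamic dimension.
  return b.create<mlir::memref::AllocOp>(loc, result_type, dynamic_sizes);
}
```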
diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
index fc64bda..99031c4 100644
--- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
+++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
+#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@@ -576,8 +577,14 @@ class HloToLhloTensorStoreOpLegacyConverter
 //   "lmhlo.terminator"() : () -> ()
 // }
 
-struct HloLegalizeToLhlo
-    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
+struct HloLegalizeToLhlo : public HloLegalizeToLhloPassBase<HloLegalizeToLhlo> {
+  using HloLegalizeToLhloPassBase<HloLegalizeToLhlo>::HloLegalizeToLhloPassBase;
+  explicit HloLegalizeToLhlo(bool convert_to_lmhlo_only)
+      : HloLegalizeToLhloPassBase<
+            HloLegalizeToLhlo>::HloLegalizeToLhloPassBase() {
+    this->convert_to_lmhlo_only_ = convert_to_lmhlo_only;
+  }
+
   void getDependentDialects(DialectRegistry& registry) const override {
     registry.insert<lmhlo::LmhloDialect, memref::MemRefDialect>();
   }
@@ -623,7 +630,8 @@ struct HloLegalizeToLhlo
           isMemRefType);
     });
 
-    populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
+    populateHLOToLHLOConversionPattern(&context, &converter, &patterns,
+                                       convert_to_lmhlo_only_);
     populateFuncOpTypeConversionPattern(patterns, converter);
     populateCallOpTypeConversionPattern(patterns, converter);
     populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
@@ -643,7 +651,9 @@ struct HloLegalizeToLhlo
 };
 }  // namespace
 
-void populateDynamicHLOToLHLOConversionPattern(
+// Lowers some metadata-only mhlo ops (e.g. reshape) directly to the memref
+// dialect and lowers all others to their lmhlo counterparts.
+void populateDynamicHLOToLHLOOrMemRefConversionPattern(
     MLIRContext* context, BufferizeTypeConverter* converter,
     OwningRewritePatternList* patterns, bool insert_copy) {
   patterns->insert<HloToLhloDynamicBroadcastInDimOpConverter>(
@@ -652,10 +662,28 @@ void populateDynamicHLOToLHLOConversionPattern(
                    HloToLhloReshapeUnrankedConverter>(*converter, context);
 }
 
+// Simply lowers all mhlo ops to their lmhlo counterparts without applying
+// any optimization (e.g. without eliding any buffer copy).
+void populateDynamicHLOToLHLOOnlyConversionPattern(
+    MLIRContext* context, BufferizeTypeConverter* converter,
+    OwningRewritePatternList* patterns) {
+  // clang-format off
+  patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
+                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>
+  >(*converter, context);
+  // clang-format on
+}
+
 void populateHLOToLHLOConversionPattern(MLIRContext* context,
                                         BufferizeTypeConverter* converter,
-                                        OwningRewritePatternList* patterns) {
-  populateDynamicHLOToLHLOConversionPattern(context, converter, patterns);
+                                        OwningRewritePatternList* patterns,
+                                        bool convert_to_lmhlo_only) {
+  if (convert_to_lmhlo_only) {
+    populateDynamicHLOToLHLOOnlyConversionPattern(context, converter, patterns);
+  } else {
+    populateDynamicHLOToLHLOOrMemRefConversionPattern(context, converter,
+                                                      patterns);
+  }
   // clang-format off
   patterns->insert<
       HloToLhloCustomCallOpConverter,
@@ -712,8 +740,9 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
   // clang-format on
 }
 
-std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
-  return std::make_unique<HloLegalizeToLhlo>();
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
+    bool convert_to_lmhlo_only) {
+  return std::make_unique<HloLegalizeToLhlo>(convert_to_lmhlo_only);
 }
 
 }  // namespace mhlo
diff --git a/tests/hlo-legalize-to-lhlo-only-dynamic.mlir b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
new file mode 100644
index 0000000..4cae2dc
--- /dev/null
+++ b/tests/hlo-legalize-to-lhlo-only-dynamic.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=convert-to-lmhlo-only=true \
+// RUN:   -canonicalize -lhlo-legalize-tensor-load-op %s -o - | FileCheck %s
+
+// CHECK-LABEL: func @dynamic_reshape
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
+func @dynamic_reshape(%lhs: tensor<?x?xf32>, %rhs: tensor<3xindex>) -> tensor<?x?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
+  // CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
+  // CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
+  // CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
+  // CHECK: "lmhlo.dynamic_reshape"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
+  // CHECK: return %[[OUTPUT]]
+  %result = "mhlo.dynamic_reshape"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_broadcast_in_dim
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[SHAPE:.*]]: memref<3xindex>) -> memref<?x?x?xf32>
+func @dynamic_broadcast_in_dim(%operand: tensor<?x?xf32>, %shape: tensor<3xindex>) -> tensor<?x?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.load %[[SHAPE]][%c0]
+  // CHECK: %[[DIM1:.*]] = memref.load %[[SHAPE]][%c1]
+  // CHECK: %[[DIM2:.*]] = memref.load %[[SHAPE]][%c2]
+  // CHECK: %[[OUTPUT:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]], %[[DIM2]])
+  // CHECK: "lmhlo.dynamic_broadcast_in_dim"(%[[ARG]], %[[SHAPE]], %[[OUTPUT]])
+  // CHECK: return %[[OUTPUT]]
+  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
+    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
+  } : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
\ No newline at end of file