diff --git a/BUILD b/BUILD index 062212e..466c509 100644 --- a/BUILD +++ b/BUILD @@ -668,6 +668,7 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", ], alwayslink = 1, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 9ea80fd..60102cb 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -437,6 +438,55 @@ class HloBroadcastInDimConverter } }; +class HloDynamicBroadcastInDimConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + // Convert only if the producer is an HLO constant. Ideally the pattern + // (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted + // to a Tensor-dialect op similar to TF ConstantLikeOp. 
+ if (!op.operand().getDefiningOp()) return failure(); + + mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op); + Value operand = adaptor.operand(); + auto operand_type = operand.getType().dyn_cast(); + if (!operand_type || operand_type.getRank() != 0) return failure(); + + Value shape = adaptor.output_dimensions(); + auto shape_type = shape.getType().cast(); + int64_t result_rank = shape_type.getDimSize(0); + + SmallVector dyn_dims; + Location loc = op.getLoc(); + for (int i = 0; i < result_rank; ++i) { + Value index = rewriter.create(loc, i); + dyn_dims.push_back(rewriter.create(loc, shape, index)); + } + auto result_type = op.getType().cast(); + + int64_t nloops = result_type.getRank(); + Value init = rewriter.create( + loc, dyn_dims, result_type.getShape(), result_type.getElementType()); + Operation* generic = rewriter.create( + loc, TypeRange{init.getType()}, ValueRange{operand}, + /*outputBuffers=*/ValueRange{init}, + llvm::makeArrayRef( + {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {}, + rewriter.getContext()), + rewriter.getMultiDimIdentityMap(nloops)}), + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nested_builder, Location nested_loc, ValueRange args) { + nested_builder.create(loc, *args.begin()); + }); + rewriter.replaceOp(op, generic->getResults()); + return success(); + } +}; + class LhloBroadcastInDimConverter : public OpConversionPattern { public: @@ -1067,7 +1117,7 @@ struct HloLegalizeToLinalgPass OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + tensor::TensorDialect, scf::SCFDialect>(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -1091,8 +1141,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { patterns ->insert, - ConstConverter, HloBroadcastInDimConverter, - IotaConverter, + ConstConverter, HloDynamicBroadcastInDimConverter, + HloBroadcastInDimConverter, 
IotaConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 63abc02..31b89d2 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -808,3 +808,25 @@ func @integer_pow(%lhs: tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: func @dynamic_broadcast_in_dim( +// CHECK-SAME: [[SHAPE:%.*]]: tensor<1xindex> +func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor { + %cst = mhlo.constant dense<0x7F800000> : tensor + %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) { + broadcast_dimensions = dense<> : tensor<0xi64> + } : (tensor, tensor<1xindex>) -> tensor + return %result : tensor +} +// CHECK: [[CST:%.*]] = constant +// CHECK: [[INIT:%.*]] = linalg.init_tensor +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-SAME: ins([[CST]] : tensor) outs([[INIT]] : tensor) +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32