From 59fa7c0ef74c3b9b6549f99c46597527ce114f58 Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Fri, 19 Mar 2021 03:51:14 -0700
Subject: [PATCH] [MHLO:linalg] Lower all dynamic broadcasts of static shapes
 to linalg.generic

We only need the memref_reinterpret_cast if we don't know whether a
dimension gets expanded or not. With static shapes we know that a
dimension can only be expanded if it's a static 1, so lower it in the
same way we lower fully static broadcasts.

PiperOrigin-RevId: 363859181
---
 .../mhlo/transforms/legalize_to_linalg.cc | 31 ++++++++++++++----
 tests/hlo-legalize-to-linalg.mlir         | 32 ++++++++++++++++---
 2 files changed, 52 insertions(+), 11 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index c6f54b0..493c6ad 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -518,15 +518,20 @@ class HloDynamicBroadcastInDimConverter
   LogicalResult matchAndRewrite(
       mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const final {
-    // Convert only if the producer is an HLO constant. Ideally the pattern
-    // (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
-    // to an Tensor-dialect op similar to TF ConstantLikeOp.
-    if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
+    // If the input has a static shape we know exactly when the broadcast must
+    // expand (the dimension is 1, which also trivially expands to 1) or will
+    // never expand (the dimension is not 1). This means we can lower the
+    // broadcast just as we would lower a fully static broadcast and go directly
+    // to linalg.generic. This also covers the important case of broadcasting a
+    // scalar.
+
+    // Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
+    // should be converted to a Tensor-dialect op similar to TF ConstantLikeOp.
 
     mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
     Value operand = adaptor.operand();
     auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
-    if (!operand_type || operand_type.getRank() != 0) return failure();
+    if (!operand_type || !operand_type.hasStaticShape()) return failure();
 
     Value shape = adaptor.output_dimensions();
     auto shape_type = shape.getType().cast<RankedTensorType>();
@@ -544,13 +549,27 @@ class HloDynamicBroadcastInDimConverter
     }
 
     int64_t nloops = result_type.getRank();
+    auto operand_shape = operand_type.getShape();
+    SmallVector<AffineExpr, 4> dim_exprs;
+    dim_exprs.reserve(nloops);
+
+    if (op.broadcast_dimensions()) {
+      for (const auto& broadcast_dim :
+           enumerate(op.broadcast_dimensions().getIntValues())) {
+        int64_t size = broadcast_dim.value().getSExtValue();
+        bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
+        dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
+                                             : rewriter.getAffineDimExpr(size));
+      }
+    }
+
     Value init = rewriter.create<linalg::InitTensorOp>(
         loc, dyn_dims, result_type.getShape(), result_type.getElementType());
     Operation* generic = rewriter.create<linalg::GenericOp>(
         loc, TypeRange{init.getType()}, ValueRange{operand},
         /*outputBuffers=*/ValueRange{init},
         llvm::makeArrayRef(
-            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
+            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
                             rewriter.getContext()),
              rewriter.getMultiDimIdentityMap(nloops)}),
         GetNParallelLoopsAttrs(nloops),
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 0013d97..bb097eb 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -997,19 +997,41 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
 // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
 
 // CHECK-LABEL: func @dynamic_broadcast_in_dim(
+// CHECK-SAME: [[SCALAR:%.*]]: tensor<f32>
 // CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
-func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x?xf32> {
-  %cst = mhlo.constant dense<0x7F800000> : tensor<f32>
-  %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
+func @dynamic_broadcast_in_dim(%scalar: tensor<f32>, %shape: tensor<2xindex>)
+    -> tensor<?x?xf32> {
+  %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
     broadcast_dimensions = dense<> : tensor<0xi64>
   } : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
   return %result : tensor<?x?xf32>
 }
-// CHECK: [[CST:%.*]] = constant
 // CHECK: [[INIT:%.*]] = linalg.init_tensor
 // CHECK: linalg.generic
 // CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
-// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x?xf32>)
+// CHECK-SAME: ins([[SCALAR]] : tensor<f32>) outs([[INIT]] : tensor<?x?xf32>)
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
+// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
+
+// -----
+
+// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK-LABEL: func @dynamic_broadcast_in_dim(
+// CHECK-SAME: [[VECTOR:%.*]]: tensor<42xf32>
+// CHECK-SAME: [[SHAPE:%.*]]: tensor<3xindex>
+func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
+    -> tensor<?x?x?xf32> {
+  %result = "mhlo.dynamic_broadcast_in_dim"(%vector, %shape) {
+    broadcast_dimensions = dense<1> : tensor<1xi64>
+  } : (tensor<42xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
+  return %result : tensor<?x?x?xf32>
+}
+// CHECK: [[INIT:%.*]] = linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+// CHECK-SAME: ins([[VECTOR]] : tensor<42xf32>) outs([[INIT]] : tensor<?x?x?xf32>)
 // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
 
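
For illustration, here is a hand-written sketch of roughly the IR the updated
pattern produces for the second test case above (the tensor<42xf32> broadcast).
This is not verbatim pass output: the map aliases, SSA names, and function name
are made up, and reading the dynamic result sizes from the shape operand with
tensor.extract is an assumption about converter code not visible in this hunk.
Operand dimension 0 has static size 42 (not 1), so expansion_needed is false
and the operand map reads result dimension d1 directly; for the scalar test
case the operand map would simply be affine_map<(d0, d1) -> ()>.

  #operand_map = affine_map<(d0, d1, d2) -> (d1)>
  #result_map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

  func @broadcast_vector_sketch(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
      -> tensor<?x?x?xf32> {
    // Read the dynamic output sizes from the shape operand.
    %c0 = constant 0 : index
    %c1 = constant 1 : index
    %c2 = constant 2 : index
    %d0 = tensor.extract %shape[%c0] : tensor<3xindex>
    %d1 = tensor.extract %shape[%c1] : tensor<3xindex>
    %d2 = tensor.extract %shape[%c2] : tensor<3xindex>
    %init = linalg.init_tensor [%d0, %d1, %d2] : tensor<?x?x?xf32>
    // No operand dimension is a static 1, so the operand map contains no
    // constant-0 expression; the body just forwards the operand value.
    %result = linalg.generic {
        indexing_maps = [#operand_map, #result_map],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%vector : tensor<42xf32>) outs(%init : tensor<?x?x?xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
    } -> tensor<?x?x?xf32>
    return %result : tensor<?x?x?xf32>
  }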