[MHLO:linalg] Lower all dynamic broadcasts of static shapes to linalg.generic
We only need the memref_reinterpret_cast if we don't know whether a dimension gets expanded or not. With static shapes we know that a dimension can only be expanded if it's a static 1, so lower it in the same way we lower fully static broadcasts. PiperOrigin-RevId: 363859181
This commit is contained in:
parent
22b27dbaa2
commit
59fa7c0ef7
|
@ -518,15 +518,20 @@ class HloDynamicBroadcastInDimConverter
|
||||||
LogicalResult matchAndRewrite(
|
LogicalResult matchAndRewrite(
|
||||||
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter& rewriter) const final {
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
// Convert only if the producer is an HLO constant. Ideally the pattern
|
// If the input has a static shape we know exactly when the broadcast must
|
||||||
// (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be converted
|
// expand (the dimension is 1, which also trivially expands to 1) or will
|
||||||
// to an Tensor-dialect op similar to TF ConstantLikeOp.
|
// never expand (the dimension is not 1). This means we can lower the
|
||||||
if (!op.operand().getDefiningOp<mhlo::ConstOp>()) return failure();
|
// broadcast just as we would lower a fully static broadcast and go directly
|
||||||
|
// to linalg.generic. This also covers the important case of broadcasting a
|
||||||
|
// scalar.
|
||||||
|
|
||||||
|
// Ideally the pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`)
|
||||||
|
// should be converted to an Tensor-dialect op similar to TF ConstantLikeOp.
|
||||||
|
|
||||||
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
|
mhlo::DynamicBroadcastInDimOp::Adaptor adaptor(op);
|
||||||
Value operand = adaptor.operand();
|
Value operand = adaptor.operand();
|
||||||
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!operand_type || operand_type.getRank() != 0) return failure();
|
if (!operand_type || !operand_type.hasStaticShape()) return failure();
|
||||||
|
|
||||||
Value shape = adaptor.output_dimensions();
|
Value shape = adaptor.output_dimensions();
|
||||||
auto shape_type = shape.getType().cast<RankedTensorType>();
|
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||||
|
@ -544,13 +549,27 @@ class HloDynamicBroadcastInDimConverter
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t nloops = result_type.getRank();
|
int64_t nloops = result_type.getRank();
|
||||||
|
auto operand_shape = operand_type.getShape();
|
||||||
|
SmallVector<AffineExpr, 4> dim_exprs;
|
||||||
|
dim_exprs.reserve(nloops);
|
||||||
|
|
||||||
|
if (op.broadcast_dimensions()) {
|
||||||
|
for (const auto& broadcast_dim :
|
||||||
|
enumerate(op.broadcast_dimensions().getIntValues())) {
|
||||||
|
int64_t size = broadcast_dim.value().getSExtValue();
|
||||||
|
bool expansion_needed = operand_shape[broadcast_dim.index()] == 1;
|
||||||
|
dim_exprs.push_back(expansion_needed ? rewriter.getAffineConstantExpr(0)
|
||||||
|
: rewriter.getAffineDimExpr(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Value init = rewriter.create<linalg::InitTensorOp>(
|
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||||
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
|
loc, dyn_dims, result_type.getShape(), result_type.getElementType());
|
||||||
Operation* generic = rewriter.create<linalg::GenericOp>(
|
Operation* generic = rewriter.create<linalg::GenericOp>(
|
||||||
loc, TypeRange{init.getType()}, ValueRange{operand},
|
loc, TypeRange{init.getType()}, ValueRange{operand},
|
||||||
/*outputBuffers=*/ValueRange{init},
|
/*outputBuffers=*/ValueRange{init},
|
||||||
llvm::makeArrayRef(
|
llvm::makeArrayRef(
|
||||||
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, {},
|
{AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0, dim_exprs,
|
||||||
rewriter.getContext()),
|
rewriter.getContext()),
|
||||||
rewriter.getMultiDimIdentityMap(nloops)}),
|
rewriter.getMultiDimIdentityMap(nloops)}),
|
||||||
GetNParallelLoopsAttrs(nloops),
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
|
|
@ -997,19 +997,41 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
|
||||||
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
|
||||||
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||||
|
// CHECK-SAME: [[SCALAR:%.*]]: tensor<f32>
|
||||||
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
|
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
|
||||||
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
|
func @dynamic_broadcast_in_dim(%scalar: tensor<f32>, %shape: tensor<2xindex>)
|
||||||
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
|
-> tensor<?x32xf32> {
|
||||||
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
|
%result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) {
|
||||||
broadcast_dimensions = dense<> : tensor<0xi64>
|
broadcast_dimensions = dense<> : tensor<0xi64>
|
||||||
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
|
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
|
||||||
return %result : tensor<?x32xf32>
|
return %result : tensor<?x32xf32>
|
||||||
}
|
}
|
||||||
// CHECK: [[CST:%.*]] = constant
|
|
||||||
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
|
// CHECK-SAME: ins([[SCALAR]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||||
|
// CHECK-SAME: [[VECTOR:%.*]]: tensor<42xf32>
|
||||||
|
// CHECK-SAME: [[SHAPE:%.*]]: tensor<3xindex>
|
||||||
|
func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>)
|
||||||
|
-> tensor<?x?x?xf32> {
|
||||||
|
%result = "mhlo.dynamic_broadcast_in_dim"(%vector, %shape) {
|
||||||
|
broadcast_dimensions = dense<1> : tensor<1xi64>
|
||||||
|
} : (tensor<42xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
|
||||||
|
return %result : tensor<?x?x?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
|
// CHECK-SAME: ins([[VECTOR]] : tensor<42xf32>) outs([[INIT]] : tensor<?x?x?xf32>)
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue