Filter static dimensions from dynamic_broadcast_in_dim's init_tensor
Otherwise we'd generate invalid IR for those cases. PiperOrigin-RevId: 360144122
This commit is contained in:
parent
e6a1f5f0f9
commit
e19ccf975e
|
@ -517,14 +517,16 @@ class HloDynamicBroadcastInDimConverter
|
||||||
auto shape_type = shape.getType().cast<RankedTensorType>();
|
auto shape_type = shape.getType().cast<RankedTensorType>();
|
||||||
int64_t result_rank = shape_type.getDimSize(0);
|
int64_t result_rank = shape_type.getDimSize(0);
|
||||||
|
|
||||||
|
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!result_type) return failure();
|
||||||
|
|
||||||
SmallVector<Value, 2> dyn_dims;
|
SmallVector<Value, 2> dyn_dims;
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
for (int i = 0; i < result_rank; ++i) {
|
for (int i = 0; i < result_rank; ++i) {
|
||||||
|
if (!result_type.isDynamicDim(i)) continue;
|
||||||
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
Value index = rewriter.create<ConstantIndexOp>(loc, i);
|
||||||
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
|
||||||
}
|
}
|
||||||
auto result_type = op.getType().dyn_cast<RankedTensorType>();
|
|
||||||
if (!result_type) return failure();
|
|
||||||
|
|
||||||
int64_t nloops = result_type.getRank();
|
int64_t nloops = result_type.getRank();
|
||||||
Value init = rewriter.create<linalg::InitTensorOp>(
|
Value init = rewriter.create<linalg::InitTensorOp>(
|
||||||
|
|
|
@ -954,6 +954,28 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()>
|
||||||
|
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @dynamic_broadcast_in_dim(
|
||||||
|
// CHECK-SAME: [[SHAPE:%.*]]: tensor<2xindex>
|
||||||
|
func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x32xf32> {
|
||||||
|
%cst = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||||
|
%result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
|
||||||
|
broadcast_dimensions = dense<> : tensor<0xi64>
|
||||||
|
} : (tensor<f32>, tensor<2xindex>) -> tensor<?x32xf32>
|
||||||
|
return %result : tensor<?x32xf32>
|
||||||
|
}
|
||||||
|
// CHECK: [[CST:%.*]] = constant
|
||||||
|
// CHECK: [[INIT:%.*]] = linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
|
||||||
|
// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x32xf32>)
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
func @dot_matmul(%arg0: tensor<2x3xf32>,
|
||||||
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
%arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
|
||||||
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
%0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,
|
||||||
|
|
Loading…
Reference in New Issue