diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 9e18a63..b8cac4c 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -517,14 +517,16 @@ class HloDynamicBroadcastInDimConverter
     auto shape_type = shape.getType().cast<RankedTensorType>();
     int64_t result_rank = shape_type.getDimSize(0);
 
+    auto result_type = op.getType().dyn_cast<RankedTensorType>();
+    if (!result_type) return failure();
+
     SmallVector<Value, 2> dyn_dims;
     Location loc = op.getLoc();
     for (int i = 0; i < result_rank; ++i) {
+      if (!result_type.isDynamicDim(i)) continue;
       Value index = rewriter.create<ConstantIndexOp>(loc, i);
       dyn_dims.push_back(rewriter.create<tensor::ExtractOp>(loc, shape, index));
     }
-    auto result_type = op.getType().dyn_cast<RankedTensorType>();
-    if (!result_type) return failure();
 
     int64_t nloops = result_type.getRank();
     Value init = rewriter.create<linalg::InitTensorOp>(
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 837406e..f55d6fc 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -954,6 +954,28 @@ func @dynamic_broadcast_in_dim(%shape: tensor<1xindex>) -> tensor<?xf32> {
 
 // -----
 
+// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK-LABEL: func @dynamic_broadcast_in_dim(
+// CHECK-SAME:  [[SHAPE:%.*]]: tensor<2xindex>
+func @dynamic_broadcast_in_dim(%shape: tensor<2xindex>) -> tensor<?x?xf32> {
+  %cst = mhlo.constant dense<0x7F800000> : tensor<f32>
+  %result = "mhlo.dynamic_broadcast_in_dim"(%cst, %shape) {
+    broadcast_dimensions = dense<> : tensor<0xi64>
+  } : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+// CHECK: [[CST:%.*]] = constant
+// CHECK: [[INIT:%.*]] = linalg.init_tensor
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
+// CHECK-SAME: ins([[CST]] : tensor<f32>) outs([[INIT]] : tensor<?x?xf32>)
+// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
+// CHECK-NEXT:   linalg.yield %[[OPERAND]] : f32
+
+// -----
+
 func @dot_matmul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> {
   %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,