diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index bd1cb53..4a62a40 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -536,6 +536,9 @@ class HloDynamicBroadcastInDimConverter Value shape = adaptor.output_dimensions(); auto shape_type = shape.getType().cast(); int64_t result_rank = shape_type.getDimSize(0); + // HLO dimension types can be any integer, as well as index. + bool convert_to_index = + shape_type.getElementType() != rewriter.getIndexType(); auto result_type = op.getType().dyn_cast(); if (!result_type) return failure(); @@ -545,7 +548,11 @@ class HloDynamicBroadcastInDimConverter for (int i = 0; i < result_rank; ++i) { if (!result_type.isDynamicDim(i)) continue; Value index = rewriter.create(loc, i); - dyn_dims.push_back(rewriter.create(loc, shape, index)); + Value dim = rewriter.create(loc, shape, index); + if (convert_to_index) { + dim = rewriter.create(loc, rewriter.getIndexType(), dim); + } + dyn_dims.push_back(dim); } int64_t nloops = result_type.getRank(); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index acd425d..5085d49 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -1054,6 +1054,19 @@ func @dynamic_broadcast_in_dim(%vector: tensor<42xf32>, %shape: tensor<3xindex>) // ----- +// CHECK-LABEL: func @dynamic_broadcast_in_dim( +// Note: this test requires no checks. The linalg init_tensor verifier will +// fail if the %shape i32 -> index cast is not performed properly. +func @dynamic_broadcast_in_dim(%scalar: tensor, %shape: tensor<2xi32>) + -> tensor { + %result = "mhlo.dynamic_broadcast_in_dim"(%scalar, %shape) { + broadcast_dimensions = dense<> : tensor<0xi64> + } : (tensor, tensor<2xi32>) -> tensor + return %result : tensor +} + +// ----- + func @dot_matmul(%arg0: tensor<2x3xf32>, %arg1: tensor<3x?xf32>) -> tensor<2x?xf32> { %0 = "mhlo.dot"(%arg0, %arg1) : (tensor<2x3xf32>,