diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
index 06651a7..c6f54b0 100644
--- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
+++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc
@@ -1773,6 +1773,101 @@ struct ReduceWindowOpOnTensorsConversion
   }
 };
 
+/// Converts mhlo.torch_index_select op to a linalg.indexed_generic op.
+struct TorchIndexSelectOpOnTensorsConversion
+    : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
+  using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::TorchIndexSelectOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    mhlo::TorchIndexSelectOp::Adaptor adaptor(args);
+    int axis = static_cast<int>(op.dim());
+    int batch = static_cast<int>(op.batch_dims());
+    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
+    int num_indices = static_cast<int>(index_shaped_type.getRank());
+    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
+    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
+    if (batch < 0) batch += num_indices;
+
+    Location loc = op.getLoc();
+    auto result_type = op.getResult().getType().cast<ShapedType>();
+    int rank = static_cast<int>(result_type.getRank());
+
+    SmallVector<AffineMap, 2> indexing_maps;
+    SmallVector<AffineExpr, 4> exprs;
+    for (int i = 0; i < batch; ++i) {
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    for (int i = 0, e = num_indices - batch; i < e; ++i) {
+      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
+    }
+    indexing_maps.emplace_back(
+        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
+    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));
+
+    // The output shape is
+    //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
+    SmallVector<Value, 4> dyn_sizes;
+    for (int i = 0; i < rank; ++i) {
+      if (!result_type.isDynamicDim(i)) continue;
+      if (i < axis) {
+        dyn_sizes.push_back(
+            rewriter.create<memref::DimOp>(loc, adaptor.input(), i));
+      } else if (i < (axis + num_indices - batch)) {
+        int idx = i - axis + batch;
+        dyn_sizes.push_back(
+            rewriter.create<memref::DimOp>(loc, adaptor.index(), idx));
+      } else {
+        int idx = i - (axis + num_indices - batch) + axis + 1;
+        dyn_sizes.push_back(
+            rewriter.create<memref::DimOp>(loc, adaptor.input(), idx));
+      }
+    }
+    Value init_op = rewriter.create<linalg::InitTensorOp>(
+        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
+    auto linalg_op = rewriter.create<linalg::IndexedGenericOp>(
+        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
+        /*inputs=*/adaptor.index(),
+        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank));
+
+    SmallVector<Type, 4> body_arg_types;
+    SmallVector<Value, 2> linalg_op_args = {adaptor.index()};
+    // Add a block to the region.
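+    // The block takes one index argument per loop dimension, followed by one
+    // scalar argument per input and per output tensor.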
+    auto* region = &linalg_op.region();
+    auto* block = rewriter.createBlock(region, region->end());
+    body_arg_types.append(rank, rewriter.getIndexType());
+    for (auto block_args : linalg_op_args) {
+      body_arg_types.push_back(
+          block_args.getType().cast<ShapedType>().getElementType());
+    }
+    block->addArguments(body_arg_types);
+    block->addArguments(result_type.getElementType());
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(block);
+
+    SmallVector<Value, 4> indices;
+    Value casted_value = rewriter.create<IndexCastOp>(
+        loc, block->getArgument(rank), rewriter.getIndexType());
+    for (int i = 0; i < axis; ++i) {
+      indices.push_back(block->getArgument(i));
+    }
+    indices.push_back(casted_value);
+    for (int i = axis + num_indices - batch; i < rank; ++i) {
+      indices.push_back(block->getArgument(i));
+    }
+
+    Value res =
+        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
+    rewriter.create<linalg::YieldOp>(loc, res);
+
+    rewriter.replaceOp(op, linalg_op.getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            OwningRewritePatternList* patterns) {
   // clang-format off
@@ -1968,6 +2063,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       DepthwiseConvOpOnTensorsConversion,
       ReduceOnTensorsConversion,
       ReduceWindowOpOnTensorsConversion,
+      TorchIndexSelectOpOnTensorsConversion,
       PadOpOnTensorsConversion>(context);
   // clang-format on
   patterns->insert<ReduceRegionXLAOpConversion<mhlo::AddOp>,
diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir
index 53b4cfe..0013d97 100644
--- a/tests/hlo-legalize-to-linalg.mlir
+++ b/tests/hlo-legalize-to-linalg.mlir
@@ -1826,7 +1826,6 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
   return %1 : tensor<1x8x8x64xf32>
 }
 
-// -----
 // CHECK-LABEL: func @reduce_window_max_nhwc
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
 // CHECK-DAG: %[[CST:.+]] = constant dense<0xFF800000> : tensor<f32>
@@ -1839,3 +1838,125 @@ func @reduce_window_max_nhwc_with_cst(%arg0: tensor<1x18x18x64xf32>) -> tensor<1
 // CHECK-SAME: strides = dense<2> : vector<2xi64>}
 // CHECK-SAME: ins(%[[ARG0]], %[[WINDOW]] : tensor<1x18x18x64xf32>, tensor<3x3xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<1x8x8x64xf32>) -> tensor<1x8x8x64xf32>
+
+// -----
+
+func @torch_select_index(%arg0: tensor<5x1x5xi32>,
+                         %arg1: tensor<2xi32>) -> tensor<2x1x5xi32> {
+  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
+    dim = 0 : i64,
+    batch_dims = 0 : i64
+  } : (tensor<5x1x5xi32>, tensor<2xi32>) -> tensor<2x1x5xi32>
+  return %0 : tensor<2x1x5xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @torch_select_index
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
+// CHECK: linalg.indexed_generic {
+// CHECK-SAME: indexing_maps
+// CHECK-SAME: #[[MAP0]], #[[MAP1]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INDEX]] : tensor<2xi32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[I:.+]]: index, %[[J:.+]]: index, %[[K:.+]]: index
+// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: i32):
+// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[J]], %[[K]]] : tensor<5x1x5xi32>
+// CHECK: linalg.yield %[[VAL2]] : i32
+
+// -----
+
+func @torch_select_index_scalar(%arg0: tensor<4x8xf32>,
+                                %arg1: tensor<i32>) -> tensor<8xf32> {
+  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
+    batch_dims = 0 : i64,
+    dim = 0 : i64
+  } : (tensor<4x8xf32>, tensor<i32>) -> tensor<8xf32>
+  return %0 : tensor<8xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @torch_select_index_scalar
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
+// CHECK: %[[T0:.+]] = linalg.init_tensor [8] : tensor<8xf32>
+// CHECK: linalg.indexed_generic {
+// CHECK-SAME: indexing_maps
+// CHECK-SAME: #[[MAP0]], #[[MAP1]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[INDEX]] : tensor<i32>) outs(%[[T0]] : tensor<8xf32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[VAL:[a-zA-Z0-9_]+]]: i32, %{{.+}}: f32):
+// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[CAST]], %[[I]]] : tensor<4x8xf32>
+// CHECK: linalg.yield %[[VAL2]] : f32
+
+// -----
+
+func @torch_select_index_batch(%arg0: tensor<4x7x8x2xf32>,
+                               %arg1: tensor<4x1xi32>) -> tensor<4x7x1x2xf32> {
+  %0 = "mhlo.torch_index_select"(%arg0, %arg1) {
+    dim = 2 : i64,
+    batch_dims = 1 : i64
+  } : (tensor<4x7x8x2xf32>, tensor<4x1xi32>) -> tensor<4x7x1x2xf32>
+  return %0 : tensor<4x7x1x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @torch_select_index_batch
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
+// CHECK: linalg.indexed_generic {
+// CHECK-SAME: indexing_maps
+// CHECK-SAME: #[[MAP0]], #[[MAP1]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INDEX]] : tensor<4x1xi32>)
+// CHECK-NEXT: ^{{.+}}(
+// CHECK-SAME: %[[I:[a-zA-Z0-9_]+]]: index, %[[J:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[K:[a-zA-Z0-9_]+]]: index, %[[L:.+]]: index,
+// CHECK-SAME: %[[VAL:.+]]: i32, %{{.+}}: f32):
+// CHECK: %[[CAST:.+]] = index_cast %[[VAL]] : i32 to index
+// CHECK: %[[VAL2:.+]] = tensor.extract %[[INPUT]][%[[I]], %[[J]], %[[CAST]], %[[L]]] : tensor<4x7x8x2xf32>
+// CHECK: linalg.yield %[[VAL2]] : f32
+
+// -----
+
+func @torch_index_select_dynamic(%input: tensor<?x?x?x?xf32>,
+                                 %index: tensor<?x?xi32>) -> tensor<?x?x?x?xf32> {
+  %0 = "mhlo.torch_index_select"(%input, %index) {
+    batch_dims = 1 : i64,
+    dim = 2 : i64
+  } : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func @torch_index_select_dynamic
+// CHECK-SAME: %[[INPUT:[a-zA-Z0-9_]*]]
+// CHECK-SAME: %[[INDEX:[a-zA-Z0-9_]*]]
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[D0:.+]] = memref.dim %[[INPUT]], %[[C0]]
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[D1:.+]] = memref.dim %[[INPUT]], %[[C1]]
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[D2:.+]] = memref.dim %[[INDEX]], %[[C1]]
+// CHECK: %[[C3:.+]] = constant 3 : index
+// CHECK: %[[D3:.+]] = memref.dim %[[INPUT]], %[[C3]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]], %[[D2]], %[[D3]]]
+// CHECK: %[[RESULT:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[INDEX]] : tensor<?x?xi32>)
+// CHECK-SAME: outs(%[[INIT]] : tensor<?x?x?x?xf32>)
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index,
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: i32, %{{[a-zA-Z0-9_]+}}: f32)
+// CHECK: %[[POS:.+]] = index_cast %[[ARG4]]
+// CHECK: %[[YIELD:.+]] = tensor.extract %[[INPUT]][%[[ARG0]], %[[ARG1]], %[[POS]], %[[ARG3]]]
+// CHECK: linalg.yield %[[YIELD]]