diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc
index 8e7673f..6607c9d 100644
--- a/lib/Dialect/mhlo/IR/hlo_ops.cc
+++ b/lib/Dialect/mhlo/IR/hlo_ops.cc
@@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
       return failure();
 
     const auto& dnums = gather.dimension_numbers();
-    if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
-        dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
+    if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
       return failure();
 
     // TODO(tberghammer): Remove when the verifier catches this case what is
@@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
     }
 
     llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
-    rewriter.replaceOpWithNewOp<SliceOp>(
-        gather, gather.getType(), gather.getOperand(0),
+    llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
+    for (int64_t i = 0; i < slice_end.size(); ++i) {
+      slice_shape[i] = slice_end[i] - slice_start[i];
+    }
+    Type element_type = gather.getType().cast<TensorType>().getElementType();
+    auto slice_type = RankedTensorType::get(slice_shape, element_type);
+    Value result = rewriter.create<SliceOp>(
+        gather.getLoc(), slice_type, gather.getOperand(0),
         GetI64ElementsAttr(slice_start, &rewriter),
         GetI64ElementsAttr(slice_end, &rewriter),
         GetI64ElementsAttr(slice_stride, &rewriter));
+
+    if (dnums.collapsed_slice_dims().getNumElements() > 0) {
+      auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
+          dnums.collapsed_slice_dims().getIntValues(),
+          [](const llvm::APInt& i) { return i.getSExtValue(); }));
+      llvm::SmallVector<int64_t, 8> reshape_shape;
+      for (int64_t i = 0; i < slice_shape.size(); ++i) {
+        if (llvm::count(collapsed_slice_dims, i) == 0) {
+          reshape_shape.push_back(slice_shape[i]);
+        }
+      }
+      auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
+      result =
+          rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
+    }
+
+    result.setType(gather.getType());
+    rewriter.replaceOp(gather, result);
     return success();
   }
 };
diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir
index 6d88145..b065138 100644
--- a/tests/canonicalize.mlir
+++ b/tests/canonicalize.mlir
@@ -951,6 +951,22 @@ func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32
   // CHECK: return %[[RET]] : tensor<5x6x4xf32>
 }
 
+// CHECK-LABEL: gather_to_slice_reshape
+func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> {
+  %0 = constant dense<[1, 2]> : tensor<2xi32>
+  %1 = "mhlo.gather"(%arg0, %0) {
+    dimension_numbers = {collapsed_slice_dims = dense<[2]> : tensor<1xi64>,
+                         index_vector_dim = 0 : i64,
+                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+                         start_index_map = dense<[0, 2]> : tensor<2xi64>},
+    indices_are_sorted = false,
+    slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
+  return %1 : tensor<3x6xf32>
+  // CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32>
+  // CHECK: %[[V1:.*]] = "mhlo.reshape"(%[[V0]]) : (tensor<3x6x1xf32>) -> tensor<3x6xf32>
+  // CHECK: return %[[V1]] : tensor<3x6xf32>
+}
+
 // CHECK-LABEL: func @fold_and_same
 func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
   %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>