Canonicalize mhlo.gather to mhlo.slice if it has a single set of constant indices
PiperOrigin-RevId: 330380755
parent dde1ed56cc
commit 73b4861f2c
@@ -1065,6 +1065,8 @@ def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
   );
 
   let results = (outs HLO_Tensor);
+
+  let hasCanonicalizer = 1;
 }
 
 def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
@@ -165,6 +165,57 @@ static LogicalResult Verify(DotGeneralOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherOp
+//===----------------------------------------------------------------------===//
+
+// Converts gather ops to slice ops in case we have a single set of constant
+// indices.
+struct GatherSlice : public OpRewritePattern<GatherOp> {
+  using OpRewritePattern<GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherOp gather,
+                                PatternRewriter& rewriter) const override {
+    DenseIntElementsAttr index;
+    if (!matchPattern(gather.start_indices(), m_Constant(&index)))
+      return failure();
+
+    const auto& dnums = gather.dimension_numbers();
+    if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
+        dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
+      return failure();
+
+    // TODO(tberghammer): Remove when the verifier catches this case, which is
+    // invalid if all of the previous conditions hold.
+    if (index.getNumElements() != dnums.start_index_map().getNumElements())
+      return failure();
+
+    auto slice_end =
+        llvm::to_vector<8>(gather.slice_sizes().getValues<int64_t>());
+    llvm::SmallVector<int64_t, 8> slice_start(slice_end.size(), 0);
+    for (auto it : llvm::zip(dnums.start_index_map().getIntValues(),
+                             index.getIntValues())) {
+      int64_t map_index = std::get<0>(it).getSExtValue();
+      int64_t offset = std::get<1>(it).getSExtValue();
+      slice_start[map_index] += offset;
+      slice_end[map_index] += offset;
+    }
+
+    llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
+    rewriter.replaceOpWithNewOp<SliceOp>(
+        gather, gather.getType(), gather.getOperand(0),
+        GetI64ElementsAttr(slice_start, &rewriter),
+        GetI64ElementsAttr(slice_end, &rewriter),
+        GetI64ElementsAttr(slice_stride, &rewriter));
+    return success();
+  }
+};
+
+void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
+                                           MLIRContext* context) {
+  results.insert<GatherSlice>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // GetDimensionSizeOp
 //===----------------------------------------------------------------------===//
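To make the bound computation in the pattern above concrete, here is a plain-C++ sketch (standalone arithmetic only, not the MLIR API) of what GatherSlice computes, evaluated on the attribute values from the two tests added below; the helper name ComputeSliceBounds is illustrative and does not exist in the codebase:

// Standalone sketch of GatherSlice's slice-bound arithmetic, using the
// attribute values from the gather_to_slice and gather_scalar_index_to_slice
// tests below. Not MLIR API; purely illustrative.
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the loop in GatherSlice::matchAndRewrite: slice_end starts from
// slice_sizes, slice_start from zero, and each constant start index shifts
// the dimension it maps to via start_index_map.
void ComputeSliceBounds(std::vector<int64_t> slice_sizes,
                        const std::vector<int64_t>& start_index_map,
                        const std::vector<int64_t>& start_indices) {
  std::vector<int64_t> slice_end = std::move(slice_sizes);
  std::vector<int64_t> slice_start(slice_end.size(), 0);
  for (size_t i = 0; i < start_index_map.size(); ++i) {
    slice_start[start_index_map[i]] += start_indices[i];
    slice_end[start_index_map[i]] += start_indices[i];
  }
  std::cout << "start:";
  for (int64_t v : slice_start) std::cout << ' ' << v;
  std::cout << "  limit:";
  for (int64_t v : slice_end) std::cout << ' ' << v;
  std::cout << '\n';
}

int main() {
  // gather_to_slice: prints "start: 1 0 2  limit: 4 6 7",
  // matching that test's CHECK line.
  ComputeSliceBounds({3, 6, 5}, {0, 2}, {1, 2});
  // gather_scalar_index_to_slice: prints "start: 0 0 1  limit: 5 6 5".
  ComputeSliceBounds({5, 6, 4}, {2}, {1});
  return 0;
}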
@@ -665,3 +665,32 @@ func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor
   return %1 : tensor<4xf32>
 }
 
+// CHECK-LABEL: gather_to_slice
+func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> {
+  %0 = constant dense<[1, 2]> : tensor<2xi32>
+  %1 = "mhlo.gather"(%arg0, %0) {
+    dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>,
+                         index_vector_dim = 0 : i64,
+                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+                         start_index_map = dense<[0, 2]> : tensor<2xi64>},
+    indices_are_sorted = false,
+    slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32>
+  return %1 : tensor<3x6x5xf32>
+  // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32>
+  // CHECK: return %[[RET]] : tensor<3x6x5xf32>
+}
+
+// CHECK-LABEL: gather_scalar_index_to_slice
+func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> {
+  %0 = constant dense<1> : tensor<i32>
+  %1 = "mhlo.gather"(%arg0, %0) {
+    dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>,
+                         index_vector_dim = 0 : i64,
+                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
+                         start_index_map = dense<[2]> : tensor<1xi64>},
+    indices_are_sorted = false,
+    slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<i32>) -> tensor<5x6x4xf32>
+  return %1 : tensor<5x6x4xf32>
+  // CHECK: %[[RET:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
+  // CHECK: return %[[RET]] : tensor<5x6x4xf32>
+}
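The rewrite only fires when a gather satisfies GatherSlice's match guards (constant start indices, empty collapsed_slice_dims, index_vector_dim of 0, index rank at most 1, and one index per start_index_map entry). The following is a minimal standalone restatement of those guards in plain C++, checked against the attribute values of the two tests above; the TestCase struct and its field names are illustrative only, not part of the MLIR API:

// Standalone restatement of GatherSlice's early-return conditions,
// evaluated on the attribute values of the two tests above. Illustrative
// sketch only; not MLIR API.
#include <cstdint>
#include <iostream>
#include <vector>

struct TestCase {
  const char* name;
  int64_t num_collapsed_slice_dims;
  int64_t index_vector_dim;
  int64_t start_indices_rank;
  int64_t num_start_indices;
  int64_t num_start_index_map_entries;
};

// Mirrors the guards in GatherSlice::matchAndRewrite.
bool MatchesGatherSliceGuards(const TestCase& tc) {
  if (tc.num_collapsed_slice_dims != 0) return false;
  if (tc.index_vector_dim != 0) return false;
  if (tc.start_indices_rank > 1) return false;
  return tc.num_start_indices == tc.num_start_index_map_entries;
}

int main() {
  std::vector<TestCase> cases = {
      // gather_to_slice: indices are tensor<2xi32>, start_index_map [0, 2].
      {"gather_to_slice", 0, 0, 1, 2, 2},
      // gather_scalar_index_to_slice: indices are tensor<i32>, start_index_map [2].
      {"gather_scalar_index_to_slice", 0, 0, 0, 1, 1},
  };
  // Both cases print "rewritten to mhlo.slice", consistent with their CHECKs.
  for (const TestCase& tc : cases)
    std::cout << tc.name << ": "
              << (MatchesGatherSliceGuards(tc) ? "rewritten to mhlo.slice"
                                               : "left as mhlo.gather")
              << '\n';
  return 0;
}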