Support collapse_slice_dims in the mhlo.gather->mhlo.slice canonicalizer
PiperOrigin-RevId: 334774763
This commit is contained in:
parent
ae51900562
commit
4b1809784a
|
@ -185,8 +185,7 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
const auto& dnums = gather.dimension_numbers();
|
const auto& dnums = gather.dimension_numbers();
|
||||||
if (dnums.collapsed_slice_dims().getNumElements() != 0 ||
|
if (dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
||||||
dnums.index_vector_dim().getInt() != 0 || index.getType().getRank() > 1)
|
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// TODO(tberghammer): Remove when the verifier catches this case what is
|
// TODO(tberghammer): Remove when the verifier catches this case what is
|
||||||
|
@ -206,11 +205,35 @@ struct GatherSlice : public OpRewritePattern<GatherOp> {
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
llvm::SmallVector<int64_t, 8> slice_stride(slice_end.size(), 1);
|
||||||
rewriter.replaceOpWithNewOp<SliceOp>(
|
llvm::SmallVector<int64_t, 8> slice_shape(slice_end.size());
|
||||||
gather, gather.getType(), gather.getOperand(0),
|
for (int64_t i = 0; i < slice_end.size(); ++i) {
|
||||||
|
slice_shape[i] = slice_end[i] - slice_start[i];
|
||||||
|
}
|
||||||
|
Type element_type = gather.getType().cast<TensorType>().getElementType();
|
||||||
|
auto slice_type = RankedTensorType::get(slice_shape, element_type);
|
||||||
|
Value result = rewriter.create<SliceOp>(
|
||||||
|
gather.getLoc(), slice_type, gather.getOperand(0),
|
||||||
GetI64ElementsAttr(slice_start, &rewriter),
|
GetI64ElementsAttr(slice_start, &rewriter),
|
||||||
GetI64ElementsAttr(slice_end, &rewriter),
|
GetI64ElementsAttr(slice_end, &rewriter),
|
||||||
GetI64ElementsAttr(slice_stride, &rewriter));
|
GetI64ElementsAttr(slice_stride, &rewriter));
|
||||||
|
|
||||||
|
if (dnums.collapsed_slice_dims().getNumElements() > 0) {
|
||||||
|
auto collapsed_slice_dims = llvm::to_vector<8>(llvm::map_range(
|
||||||
|
dnums.collapsed_slice_dims().getIntValues(),
|
||||||
|
[](const llvm::APInt& i) { return i.getSExtValue(); }));
|
||||||
|
llvm::SmallVector<int64_t, 8> reshape_shape;
|
||||||
|
for (int64_t i = 0; i < slice_shape.size(); ++i) {
|
||||||
|
if (llvm::count(collapsed_slice_dims, i) == 0) {
|
||||||
|
reshape_shape.push_back(slice_shape[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto reshape_type = RankedTensorType::get(reshape_shape, element_type);
|
||||||
|
result =
|
||||||
|
rewriter.create<ReshapeOp>(gather.getLoc(), reshape_type, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.setType(gather.getType());
|
||||||
|
rewriter.replaceOp(gather, result);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -951,6 +951,22 @@ func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32
|
||||||
// CHECK: return %[[RET]] : tensor<5x6x4xf32>
|
// CHECK: return %[[RET]] : tensor<5x6x4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: gather_to_slice_reshape
|
||||||
|
func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> {
|
||||||
|
%0 = constant dense<[1, 2]> : tensor<2xi32>
|
||||||
|
%1 = "mhlo.gather"(%arg0, %0) {
|
||||||
|
dimension_numbers = {collapsed_slice_dims = dense<[2]> : tensor<1xi64>,
|
||||||
|
index_vector_dim = 0 : i64,
|
||||||
|
offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
|
||||||
|
start_index_map = dense<[0, 2]> : tensor<2xi64>},
|
||||||
|
indices_are_sorted = false,
|
||||||
|
slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
|
||||||
|
return %1 : tensor<3x6xf32>
|
||||||
|
// CHECK: %[[V0:.*]] = "mhlo.slice"(%arg0) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32>
|
||||||
|
// CHECK: %[[V1:.*]] = "mhlo.reshape"(%[[V0]]) : (tensor<3x6x1xf32>) -> tensor<3x6xf32>
|
||||||
|
// CHECK: return %[[V1]] : tensor<3x6xf32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @fold_and_same
|
// CHECK-LABEL: func @fold_and_same
|
||||||
func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
|
||||||
%0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
%0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
|
||||||
|
|
Loading…
Reference in New Issue