Add support for lowering mhlo.scatter ops to Linalg.
This only works for updating tensors, not add/min/max computations. It requires the index depth to be 1 because of the limitation in Linalg. We can not compare multiple indices without packing indices. PiperOrigin-RevId: 375137721
This commit is contained in:
parent
f5dc73ab60
commit
1ba4c714c9
|
@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion
|
|||
}
|
||||
};
|
||||
|
||||
struct ScatterUpdateOnTensorsConversion
|
||||
: public OpConversionPattern<mhlo::ScatterOp> {
|
||||
using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ScatterOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
mhlo::ScatterOp::Adaptor adaptor(args);
|
||||
|
||||
// Check if it is a tensor_scatter_nd_update-like op.
|
||||
auto& body_ops = op.getRegion().front().getOperations();
|
||||
if (body_ops.size() != 1) return failure();
|
||||
auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
|
||||
if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();
|
||||
|
||||
auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
|
||||
auto indices_ty =
|
||||
adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
|
||||
if (!operand_ty || !indices_ty) return failure();
|
||||
|
||||
// Linalg operations put all the computation to the innermost loop. Since we
|
||||
// also iterate over scatter_indices() with some loops, we can only check
|
||||
// one scatter index in one iteration. If there are multiple indices (ie,
|
||||
// the index depth is greater than 1), we don't have a way to keep the
|
||||
// comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]],
|
||||
// we should use the update value only if (i == 0 and j == 1). However, we
|
||||
// can not get both indices in one iteration unless we pack them together.
|
||||
auto index_vector_dim =
|
||||
op.scatter_dimension_numbers().index_vector_dim().getInt();
|
||||
if (indices_ty.getDimSize(index_vector_dim) != 1)
|
||||
return rewriter.notifyMatchFailure(op, "require index depth to be 1");
|
||||
if (index_vector_dim != indices_ty.getRank() - 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "require index_vector_dim to be the last dim");
|
||||
}
|
||||
|
||||
// One of indices dims is index depth vector.
|
||||
int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
|
||||
SmallVector<AffineMap, 3> indexing_maps;
|
||||
{
|
||||
SmallVector<AffineExpr> exprs;
|
||||
for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
|
||||
rewriter.getContext()));
|
||||
}
|
||||
{
|
||||
SmallVector<AffineExpr> exprs;
|
||||
for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(i));
|
||||
// The index depth is 1.
|
||||
exprs.push_back(rewriter.getAffineConstantExpr(0));
|
||||
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
|
||||
rewriter.getContext()));
|
||||
|
||||
exprs.pop_back();
|
||||
auto update_window_dims =
|
||||
Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
|
||||
for (auto d : update_window_dims)
|
||||
exprs.push_back(rewriter.getAffineDimExpr(d));
|
||||
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
|
||||
rewriter.getContext()));
|
||||
}
|
||||
indexing_maps.push_back(indexing_maps.front());
|
||||
|
||||
auto result_ty = this->typeConverter->convertType(op.getResult().getType())
|
||||
.cast<ShapedType>();
|
||||
auto scatter_dims_to_operand_dims = Extract1DVector(
|
||||
op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
|
||||
assert(scatter_dims_to_operand_dims.size() == 1);
|
||||
// Do not need init_tensor because we'd like to initialize the output as
|
||||
// operand.
|
||||
auto linalg_op = rewriter.create<linalg::GenericOp>(
|
||||
op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
|
||||
/*inputs=*/
|
||||
ValueRange{adaptor.operand(), adaptor.scatter_indices(),
|
||||
adaptor.updates()},
|
||||
/*outputs=*/adaptor.operand(), indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& b, Location loc, ValueRange args) {
|
||||
Value cmp_idx =
|
||||
b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
|
||||
Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
|
||||
Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
|
||||
cmp_idx, idx);
|
||||
// Use the output arg, so some update values won't be init value
|
||||
// again.
|
||||
Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
|
||||
args[3]);
|
||||
b.create<linalg::YieldOp>(loc, res);
|
||||
});
|
||||
rewriter.replaceOp(op, linalg_op.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
TypeConverter& typeConverter,
|
||||
OwningRewritePatternList* patterns) {
|
||||
|
@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
DepthwiseConvOpOnTensorsConversion,
|
||||
ReduceOnTensorsConversion,
|
||||
ReduceWindowOpOnTensorsConversion,
|
||||
ScatterUpdateOnTensorsConversion,
|
||||
TorchIndexSelectOpOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(type_converter, context);
|
||||
// clang-format on
|
||||
|
|
|
@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor
|
|||
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
|
||||
%arg2: tensor<1xi32>) -> tensor<3xi32> {
|
||||
%0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
|
||||
"mhlo.return"(%arg4) : (tensor<i32>) -> ()
|
||||
}) {
|
||||
indices_are_sorted = false,
|
||||
scatter_dimension_numbers = {
|
||||
index_vector_dim = 1 : i64,
|
||||
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||
update_window_dims = dense<> : tensor<0xi64>
|
||||
},
|
||||
unique_indices = false
|
||||
} : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32>
|
||||
return %0 : tensor<3xi32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
// CHECK: func @scatter_update_scalar
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[RES:.*]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
|
||||
// CHECK-SAME: outs(%[[ARG0]] : tensor<3xi32>) {
|
||||
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors
|
||||
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
|
||||
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
|
||||
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
|
||||
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
|
||||
// CHECK: linalg.yield %[[SELECT]] : i32
|
||||
// CHECK: } -> tensor<3xi32>
|
||||
// CHECK: return %[[RES]] : tensor<3xi32>
|
||||
|
||||
// -----
|
||||
|
||||
func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
|
||||
%arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
|
||||
%0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
|
||||
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
|
||||
"mhlo.return"(%arg4) : (tensor<i32>) -> ()
|
||||
}) {
|
||||
indices_are_sorted = false,
|
||||
scatter_dimension_numbers = {
|
||||
index_vector_dim = 1 : i64,
|
||||
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||
update_window_dims = dense<1> : tensor<1xi64>
|
||||
},
|
||||
unique_indices = false
|
||||
} : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
|
||||
return %0 : tensor<6x3xi32>
|
||||
}
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
// CHECK: func @scatter_update_slice
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
|
||||
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
|
||||
// CHECK: %[[RES:.*]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
|
||||
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>)
|
||||
// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) {
|
||||
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors
|
||||
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
|
||||
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
|
||||
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
|
||||
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
|
||||
// CHECK: linalg.yield %[[SELECT]] : i32
|
||||
// CHECK: } -> tensor<6x3xi32>
|
||||
// CHECK: return %[[RES]] : tensor<6x3xi32>
|
||||
|
|
Loading…
Reference in New Issue