Add support for lowering mhlo.scatter ops to Linalg.
This only supports pure updates, not add/min/max combiner computations. It requires the index depth to be 1 because of a limitation in Linalg: we cannot compare multiple indices in one iteration without packing them together.

PiperOrigin-RevId: 375137721
commit 1ba4c714c9
parent f5dc73ab60
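To illustrate the index-depth restriction: a scatter like the sketch below addresses the operand with two coordinates per update (index depth 2), so the new pattern rejects it with "require index depth to be 1". This is a hypothetical example written for this note, not one of the tests added in this change:

    func @scatter_update_point(%arg0: tensor<3x3xi32>, %arg1: tensor<1x2xi32>,
                               %arg2: tensor<1xi32>) -> tensor<3x3xi32> {
      // Index depth 2: dim 1 of the 1x2 indices is the index vector, so one
      // update would need both comparisons (i == idx[0] and j == idx[1]) in a
      // single linalg iteration.
      %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
      ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
        "mhlo.return"(%arg4) : (tensor<i32>) -> ()
      }) {
        indices_are_sorted = false,
        scatter_dimension_numbers = {
          index_vector_dim = 1 : i64,
          inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
          scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
          update_window_dims = dense<> : tensor<0xi64>
        },
        unique_indices = false
      } : (tensor<3x3xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<3x3xi32>
      return %0 : tensor<3x3xi32>
    }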
@@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion
  }
};

struct ScatterUpdateOnTensorsConversion
    : public OpConversionPattern<mhlo::ScatterOp> {
  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::ScatterOp::Adaptor adaptor(args);

    // Check if it is a tensor_scatter_nd_update-like op.
    auto& body_ops = op.getRegion().front().getOperations();
    if (body_ops.size() != 1) return failure();
    auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
    if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();

    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    auto indices_ty =
        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !indices_ty) return failure();

    // Linalg operations put all the computation in the innermost loop. Since
    // we also iterate over scatter_indices() with some loops, we can only
    // check one scatter index per iteration. If there are multiple indices
    // (i.e., the index depth is greater than 1), we have no way to keep the
    // comparison state across iterations. E.g., if the index depth is 2, like
    // indices = [[0, 1]], we should use the update value only if (i == 0 and
    // j == 1). However, we cannot get both indices in one iteration unless we
    // pack them together.
    auto index_vector_dim =
        op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (indices_ty.getDimSize(index_vector_dim) != 1)
      return rewriter.notifyMatchFailure(op, "require index depth to be 1");
    if (index_vector_dim != indices_ty.getRank() - 1) {
      return rewriter.notifyMatchFailure(
          op, "require index_vector_dim to be the last dim");
    }

    // One of the indices dims is the index depth vector.
    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
    SmallVector<AffineMap, 3> indexing_maps;
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      // The index depth is 1.
      exprs.push_back(rewriter.getAffineConstantExpr(0));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));

      exprs.pop_back();
      auto update_window_dims =
          Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
      for (auto d : update_window_dims)
        exprs.push_back(rewriter.getAffineDimExpr(d));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    indexing_maps.push_back(indexing_maps.front());

    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
                         .cast<ShapedType>();
    auto scatter_dims_to_operand_dims = Extract1DVector(
        op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
    assert(scatter_dims_to_operand_dims.size() == 1);
    // We do not need an init_tensor because the output is initialized from
    // the operand.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
        /*inputs=*/
        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
                   adaptor.updates()},
        /*outputs=*/adaptor.operand(), indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          Value cmp_idx =
              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
          Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
          Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
                                        cmp_idx, idx);
          // Read the output arg, so values already written by earlier
          // iterations are not reset to the init value.
          Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
                                         args[3]);
          b.create<linalg::YieldOp>(loc, res);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           TypeConverter& typeConverter,
                                           OwningRewritePatternList* patterns) {
@@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      DepthwiseConvOpOnTensorsConversion,
      ReduceOnTensorsConversion,
      ReduceWindowOpOnTensorsConversion,
      ScatterUpdateOnTensorsConversion,
      TorchIndexSelectOpOnTensorsConversion,
      PadOpOnTensorsConversion>(type_converter, context);
  // clang-format on
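To make the map construction above concrete: for the rank-2 operand and rank-2 scatter_indices in the scatter_update_slice test below, nloops = 2 + 2 - 1 = 3, and the four maps (operand, indices, updates, output) come out as follows, a sketch matching the test's CHECK-DAG lines; the comments are my reading of each map:

    #map0 = affine_map<(d0, d1, d2) -> (d0, d1)>  // operand, also reused for the output
    #map1 = affine_map<(d0, d1, d2) -> (d2, 0)>   // indices; the depth-1 index vector is pinned to constant 0
    #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>  // updates, with update_window_dims = [1]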
@@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xi1>
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
  return %0 : tensor<2x2xi1>
}

// -----

func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
                            %arg2: tensor<1xi32>) -> tensor<3xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = {
      index_vector_dim = 1 : i64,
      inserted_window_dims = dense<0> : tensor<1xi64>,
      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
      update_window_dims = dense<> : tensor<0xi64>
    },
    unique_indices = false
  } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @scatter_update_scalar
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[RES:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
// CHECK-SAME: outs(%[[ARG0]] : tensor<3xi32>) {
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK: linalg.yield %[[SELECT]] : i32
// CHECK: } -> tensor<3xi32>
// CHECK: return %[[RES]] : tensor<3xi32>
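
// Read together, the CHECK lines above describe IR of roughly this shape
// (a sketch assembled from the checks, not verbatim compiler output):
//   %res = linalg.generic {
//            indexing_maps = [#map0, #map1, #map2, #map0],
//            iterator_types = ["parallel", "parallel"]}
//          ins(%arg0, %arg1, %arg2 : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
//          outs(%arg0 : tensor<3xi32>) {
//   ^bb0(%operand: i32, %idx_i32: i32, %update: i32, %out: i32):
//     %i = linalg.index 0 : index
//     %idx = index_cast %idx_i32 : i32 to index
//     %pred = cmpi eq, %i, %idx : index
//     %sel = select %pred, %update, %out : i32
//     linalg.yield %sel : i32
//   } -> tensor<3xi32>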

// -----

func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
                           %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = {
      index_vector_dim = 1 : i64,
      inserted_window_dims = dense<0> : tensor<1xi64>,
      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
      update_window_dims = dense<1> : tensor<1xi64>
    },
    unique_indices = false
  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
  return %0 : tensor<6x3xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: func @scatter_update_slice
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[RES:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>)
// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) {
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK: linalg.yield %[[SELECT]] : i32
// CHECK: } -> tensor<6x3xi32>
// CHECK: return %[[RES]] : tensor<6x3xi32>
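
// A worked reading of this test (illustrative runtime values, not part of
// the checks): if %arg1 held dense<[[0], [2]]>, loop d2 would scan the two
// indices, so rows 0 and 2 of %arg0 are replaced by rows 0 and 1 of %arg2,
// while the remaining rows keep the operand values threaded through
// %[[OUT]].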