Add support for lowering mhlo.scatter ops to Linalg.

This only works for updating tensors, not add/min/max computations. It requires
the index depth to be 1 because of the limitation in Linalg. We can not compare
multiple indices without packing indices.

PiperOrigin-RevId: 375137721
This commit is contained in:
Hanhan Wang 2021-05-21 12:16:14 -07:00 committed by TensorFlow MLIR Team
parent f5dc73ab60
commit 1ba4c714c9
2 changed files with 177 additions and 0 deletions

View File

@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion
}
};
struct ScatterUpdateOnTensorsConversion
: public OpConversionPattern<mhlo::ScatterOp> {
using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ScatterOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
mhlo::ScatterOp::Adaptor adaptor(args);
// Check if it is a tensor_scatter_nd_update-like op.
auto& body_ops = op.getRegion().front().getOperations();
if (body_ops.size() != 1) return failure();
auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();
auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
auto indices_ty =
adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
if (!operand_ty || !indices_ty) return failure();
// Linalg operations put all the computation to the innermost loop. Since we
// also iterate over scatter_indices() with some loops, we can only check
// one scatter index in one iteration. If there are multiple indices (ie,
// the index depth is greater than 1), we don't have a way to keep the
// comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]],
// we should use the update value only if (i == 0 and j == 1). However, we
// can not get both indices in one iteration unless we pack them together.
auto index_vector_dim =
op.scatter_dimension_numbers().index_vector_dim().getInt();
if (indices_ty.getDimSize(index_vector_dim) != 1)
return rewriter.notifyMatchFailure(op, "require index depth to be 1");
if (index_vector_dim != indices_ty.getRank() - 1) {
return rewriter.notifyMatchFailure(
op, "require index_vector_dim to be the last dim");
}
// One of indices dims is index depth vector.
int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
SmallVector<AffineMap, 3> indexing_maps;
{
SmallVector<AffineExpr> exprs;
for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
rewriter.getContext()));
}
{
SmallVector<AffineExpr> exprs;
for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
// The index depth is 1.
exprs.push_back(rewriter.getAffineConstantExpr(0));
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
rewriter.getContext()));
exprs.pop_back();
auto update_window_dims =
Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
for (auto d : update_window_dims)
exprs.push_back(rewriter.getAffineDimExpr(d));
indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
rewriter.getContext()));
}
indexing_maps.push_back(indexing_maps.front());
auto result_ty = this->typeConverter->convertType(op.getResult().getType())
.cast<ShapedType>();
auto scatter_dims_to_operand_dims = Extract1DVector(
op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
assert(scatter_dims_to_operand_dims.size() == 1);
// Do not need init_tensor because we'd like to initialize the output as
// operand.
auto linalg_op = rewriter.create<linalg::GenericOp>(
op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
/*inputs=*/
ValueRange{adaptor.operand(), adaptor.scatter_indices(),
adaptor.updates()},
/*outputs=*/adaptor.operand(), indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& b, Location loc, ValueRange args) {
Value cmp_idx =
b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
cmp_idx, idx);
// Use the output arg, so some update values won't be init value
// again.
Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
args[3]);
b.create<linalg::YieldOp>(loc, res);
});
rewriter.replaceOp(op, linalg_op.getResults());
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
TypeConverter& typeConverter,
OwningRewritePatternList* patterns) {
@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
DepthwiseConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
ReduceWindowOpOnTensorsConversion,
ScatterUpdateOnTensorsConversion,
TorchIndexSelectOpOnTensorsConversion,
PadOpOnTensorsConversion>(type_converter, context);
// clang-format on

View File

@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor
%0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// -----
func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
%arg2: tensor<1xi32>) -> tensor<3xi32> {
%0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"mhlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<> : tensor<0xi64>
},
unique_indices = false
} : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: func @scatter_update_scalar
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[RES:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
// CHECK-SAME: outs(%[[ARG0]] : tensor<3xi32>) {
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK: linalg.yield %[[SELECT]] : i32
// CHECK: } -> tensor<3xi32>
// CHECK: return %[[RES]] : tensor<3xi32>
// -----
func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
%arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
%0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>): // no predecessors
"mhlo.return"(%arg4) : (tensor<i32>) -> ()
}) {
indices_are_sorted = false,
scatter_dimension_numbers = {
index_vector_dim = 1 : i64,
inserted_window_dims = dense<0> : tensor<1xi64>,
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
update_window_dims = dense<1> : tensor<1xi64>
},
unique_indices = false
} : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
return %0 : tensor<6x3xi32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK: func @scatter_update_slice
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK: %[[RES:.*]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>)
// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) {
// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors
// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK: linalg.yield %[[SELECT]] : i32
// CHECK: } -> tensor<6x3xi32>
// CHECK: return %[[RES]] : tensor<6x3xi32>