diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 24afdb3..5c8184d 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion } }; +struct ScatterUpdateOnTensorsConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ScatterOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + mhlo::ScatterOp::Adaptor adaptor(args); + + // Check if it is a tensor_scatter_nd_update-like op. + auto& body_ops = op.getRegion().front().getOperations(); + if (body_ops.size() != 1) return failure(); + auto ret_arg = body_ops.front().getOperand(0).dyn_cast(); + if (!ret_arg || ret_arg.getArgNumber() != 1) return failure(); + + auto operand_ty = adaptor.operand().getType().dyn_cast(); + auto indices_ty = + adaptor.scatter_indices().getType().dyn_cast(); + if (!operand_ty || !indices_ty) return failure(); + + // Linalg operations put all the computation to the innermost loop. Since we + // also iterate over scatter_indices() with some loops, we can only check + // one scatter index in one iteration. If there are multiple indices (ie, + // the index depth is greater than 1), we don't have a way to keep the + // comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]], + // we should use the update value only if (i == 0 and j == 1). However, we + // can not get both indices in one iteration unless we pack them together. + auto index_vector_dim = + op.scatter_dimension_numbers().index_vector_dim().getInt(); + if (indices_ty.getDimSize(index_vector_dim) != 1) + return rewriter.notifyMatchFailure(op, "require index depth to be 1"); + if (index_vector_dim != indices_ty.getRank() - 1) { + return rewriter.notifyMatchFailure( + op, "require index_vector_dim to be the last dim"); + } + + // One of indices dims is index depth vector. + int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1; + SmallVector indexing_maps; + { + SmallVector exprs; + for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i) + exprs.push_back(rewriter.getAffineDimExpr(i)); + indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + { + SmallVector exprs; + for (int64_t i = operand_ty.getRank(); i < nloops; ++i) + exprs.push_back(rewriter.getAffineDimExpr(i)); + // The index depth is 1. + exprs.push_back(rewriter.getAffineConstantExpr(0)); + indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, + rewriter.getContext())); + + exprs.pop_back(); + auto update_window_dims = + Extract1DVector(op.scatter_dimension_numbers().update_window_dims()); + for (auto d : update_window_dims) + exprs.push_back(rewriter.getAffineDimExpr(d)); + indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs, + rewriter.getContext())); + } + indexing_maps.push_back(indexing_maps.front()); + + auto result_ty = this->typeConverter->convertType(op.getResult().getType()) + .cast(); + auto scatter_dims_to_operand_dims = Extract1DVector( + op.scatter_dimension_numbers().scatter_dims_to_operand_dims()); + assert(scatter_dims_to_operand_dims.size() == 1); + // Do not need init_tensor because we'd like to initialize the output as + // operand. + auto linalg_op = rewriter.create( + op.getLoc(), /*resultTensors=*/ArrayRef{result_ty}, + /*inputs=*/ + ValueRange{adaptor.operand(), adaptor.scatter_indices(), + adaptor.updates()}, + /*outputs=*/adaptor.operand(), indexing_maps, + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& b, Location loc, ValueRange args) { + Value cmp_idx = + b.create(loc, scatter_dims_to_operand_dims[0]); + Value idx = b.create(loc, b.getIndexType(), args[1]); + Value pred = b.create(loc, b.getI1Type(), CmpIPredicate::eq, + cmp_idx, idx); + // Use the output arg, so some update values won't be init value + // again. + Value res = b.create(loc, args[2].getType(), pred, args[2], + args[3]); + b.create(loc, res); + }); + rewriter.replaceOp(op, linalg_op.getResults()); + return success(); + } +}; + void populateLHLOToLinalgConversionPattern(MLIRContext* context, TypeConverter& typeConverter, OwningRewritePatternList* patterns) { @@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, DepthwiseConvOpOnTensorsConversion, ReduceOnTensorsConversion, ReduceWindowOpOnTensorsConversion, + ScatterUpdateOnTensorsConversion, TorchIndexSelectOpOnTensorsConversion, PadOpOnTensorsConversion>(type_converter, context); // clang-format on diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 955187e..f46cc97 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1> return %0 : tensor<2x2xi1> } + +// ----- + +func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>, + %arg2: tensor<1xi32>) -> tensor<3xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + "mhlo.return"(%arg4) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK: func @scatter_update_scalar +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[RES:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<3xi32>) { +// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors +// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index +// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index +// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index +// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32 +// CHECK: linalg.yield %[[SELECT]] : i32 +// CHECK: } -> tensor<3xi32> +// CHECK: return %[[RES]] : tensor<3xi32> + +// ----- + +func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, + %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> { + %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + "mhlo.return"(%arg4) : (tensor) -> () + }) { + indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<1> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32> + return %0 : tensor<6x3xi32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)> +// CHECK: func @scatter_update_slice +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]] +// CHECK: %[[RES:.*]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) +// CHECK-SAME: outs(%[[ARG0]] : tensor<6x3xi32>) { +// CHECK: ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32): // no predecessors +// CHECK: %[[CMP_IDX:.*]] = linalg.index 0 : index +// CHECK: %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index +// CHECK: %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index +// CHECK: %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32 +// CHECK: linalg.yield %[[SELECT]] : i32 +// CHECK: } -> tensor<6x3xi32> +// CHECK: return %[[RES]] : tensor<6x3xi32>