Add support for lowering mhlo.scatter ops to Linalg.

This only supports update-style scatters (the region simply returns the update
value), not add/min/max computations. It also requires the index depth to be 1,
due to a limitation in Linalg: we cannot compare multiple indices without
packing them together.

PiperOrigin-RevId: 375137721
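For illustration, the new pattern only matches a scatter whose region is a
single return of the update value, as exercised by the tests in this change;
a scatter that combines old and new values is not rewritten. A minimal sketch
(the %old/%update names and the add variant are illustrative, not part of this
change):

  // Matched: the region yields the update operand (block argument 1) unchanged.
  ^bb0(%old: tensor<i32>, %update: tensor<i32>):
    "mhlo.return"(%update) : (tensor<i32>) -> ()

  // Not matched: the region combines values, e.g. old + update.
  ^bb0(%old: tensor<i32>, %update: tensor<i32>):
    %sum = "mhlo.add"(%old, %update) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    "mhlo.return"(%sum) : (tensor<i32>) -> ()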
parent f5dc73ab60
commit 1ba4c714c9
@@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion
   }
 };
 
+struct ScatterUpdateOnTensorsConversion
+    : public OpConversionPattern<mhlo::ScatterOp> {
+  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      mhlo::ScatterOp op, ArrayRef<Value> args,
+      ConversionPatternRewriter& rewriter) const final {
+    mhlo::ScatterOp::Adaptor adaptor(args);
+
+    // Check if it is a tensor_scatter_nd_update-like op: the region must be a
+    // single op that returns the update value (block argument 1) unchanged.
+    auto& body_ops = op.getRegion().front().getOperations();
+    if (body_ops.size() != 1) return failure();
+    auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
+    if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();
+
+    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
+    auto indices_ty =
+        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
+    if (!operand_ty || !indices_ty) return failure();
+
+    // Linalg operations put all the computation in the innermost loop. Since
+    // we also iterate over scatter_indices() with some loops, we can only
+    // check one scatter index per iteration. If there are multiple indices
+    // (i.e., the index depth is greater than 1), we have no way to keep the
+    // comparison state across iterations. E.g., if the index depth is 2, as
+    // in indices = [[0, 1]], we should use the update value only if
+    // (i == 0 and j == 1); however, we cannot see both indices in a single
+    // iteration unless we pack them together.
+    auto index_vector_dim =
+        op.scatter_dimension_numbers().index_vector_dim().getInt();
+    if (indices_ty.getDimSize(index_vector_dim) != 1)
+      return rewriter.notifyMatchFailure(op, "require index depth to be 1");
+    if (index_vector_dim != indices_ty.getRank() - 1) {
+      return rewriter.notifyMatchFailure(
+          op, "require index_vector_dim to be the last dim");
+    }
+
+    // One of the indices dims is the index depth vector.
+    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
+    SmallVector<AffineMap, 3> indexing_maps;
+    {
+      // Indexing map for the operand: the leading loop dims.
+      SmallVector<AffineExpr> exprs;
+      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
+        exprs.push_back(rewriter.getAffineDimExpr(i));
+      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
+                                             rewriter.getContext()));
+    }
+    {
+      // Indexing map for scatter_indices: the trailing loop dims.
+      SmallVector<AffineExpr> exprs;
+      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
+        exprs.push_back(rewriter.getAffineDimExpr(i));
+      // The index depth is 1.
+      exprs.push_back(rewriter.getAffineConstantExpr(0));
+      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
+                                             rewriter.getContext()));
+
+      // Indexing map for updates: the same loop dims, followed by
+      // update_window_dims.
+      exprs.pop_back();
+      auto update_window_dims =
+          Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
+      for (auto d : update_window_dims)
+        exprs.push_back(rewriter.getAffineDimExpr(d));
+      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
+                                             rewriter.getContext()));
+    }
+    indexing_maps.push_back(indexing_maps.front());
+
+    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
+                         .cast<ShapedType>();
+    auto scatter_dims_to_operand_dims = Extract1DVector(
+        op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
+    assert(scatter_dims_to_operand_dims.size() == 1);
+    // No init_tensor is needed because the output is initialized from the
+    // operand.
+    auto linalg_op = rewriter.create<linalg::GenericOp>(
+        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
+        /*inputs=*/
+        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
+                   adaptor.updates()},
+        /*outputs=*/adaptor.operand(), indexing_maps,
+        GetNParallelLoopsAttrs(nloops),
+        [&](OpBuilder& b, Location loc, ValueRange args) {
+          Value cmp_idx =
+              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
+          Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
+          Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
+                                        cmp_idx, idx);
+          // Select against the output argument, so update values that were
+          // already written are not reset to the init value.
+          Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
+                                         args[3]);
+          b.create<linalg::YieldOp>(loc, res);
+        });
+    rewriter.replaceOp(op, linalg_op.getResults());
+    return success();
+  }
+};
+
 void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                            TypeConverter& typeConverter,
                                            OwningRewritePatternList* patterns) {
@@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
       DepthwiseConvOpOnTensorsConversion,
       ReduceOnTensorsConversion,
       ReduceWindowOpOnTensorsConversion,
+      ScatterUpdateOnTensorsConversion,
       TorchIndexSelectOpOnTensorsConversion,
       PadOpOnTensorsConversion>(type_converter, context);
   // clang-format on
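To make the indexing-map construction concrete: in the scatter_update_slice
test below, the operand has rank 2 and the indices have rank 2, so
nloops = 2 + 2 - 1 = 3, and the maps the pattern builds (pinned down by the
test's CHECK-DAG lines) are:

  #map0 = affine_map<(d0, d1, d2) -> (d0, d1)>  // operand and output
  #map1 = affine_map<(d0, d1, d2) -> (d2, 0)>   // scatter_indices: depth-1 vector dim fixed to 0
  #map2 = affine_map<(d0, d1, d2) -> (d2, d1)>  // updates: update_window_dims = [1]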
@@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor
   %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
   return %0 : tensor<2x2xi1>
 }
+
+// -----
+
+func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
+                            %arg2: tensor<1xi32>) -> tensor<3xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 1 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<> : tensor<0xi64>
+    },
+    unique_indices = false
+  } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32>
+  return %0 : tensor<3xi32>
+}
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)>
+// CHECK-DAG:   #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK:       func @scatter_update_scalar
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]*]]
+// CHECK:         %[[RES:.*]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel"]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
+// CHECK-SAME:     outs(%[[ARG0]] : tensor<3xi32>) {
+// CHECK:         ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
+// CHECK:           %[[CMP_IDX:.*]] = linalg.index 0 : index
+// CHECK:           %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
+// CHECK:           %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
+// CHECK:           %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
+// CHECK:           linalg.yield %[[SELECT]] : i32
+// CHECK:         } -> tensor<3xi32>
+// CHECK:         return %[[RES]] : tensor<3xi32>
+
+// -----
+
+func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
+                           %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
+  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
+    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
+  }) {
+    indices_are_sorted = false,
+    scatter_dimension_numbers = {
+      index_vector_dim = 1 : i64,
+      inserted_window_dims = dense<0> : tensor<1xi64>,
+      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
+      update_window_dims = dense<1> : tensor<1xi64>
+    },
+    unique_indices = false
+  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
+  return %0 : tensor<6x3xi32>
+}
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)>
+// CHECK-DAG:   #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK:       func @scatter_update_slice
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]*]]
+// CHECK:         %[[RES:.*]] = linalg.generic
+// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
+// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>)
+// CHECK-SAME:     outs(%[[ARG0]] : tensor<6x3xi32>) {
+// CHECK:         ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
+// CHECK:           %[[CMP_IDX:.*]] = linalg.index 0 : index
+// CHECK:           %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
+// CHECK:           %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
+// CHECK:           %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
+// CHECK:           linalg.yield %[[SELECT]] : i32
+// CHECK:         } -> tensor<6x3xi32>
+// CHECK:         return %[[RES]] : tensor<6x3xi32>
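These cases run under the test file's existing RUN line; assuming the usual
setup for legalize-to-linalg.mlir (a sketch, the flag spelling comes from that
file's header, not from this change), something like:

  mlir-hlo-opt -hlo-legalize-to-linalg -split-input-file legalize-to-linalg.mlir | FileCheck legalize-to-linalg.mlir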