Add support for lowering mhlo.scatter ops to Linalg.

This only handles plain tensor updates, not add/min/max update computations. It also requires the index depth to be 1 because of a limitation in Linalg: without packing the indices together, we cannot compare more than one index component per iteration.

PiperOrigin-RevId: 375137721
parent f5dc73ab60
commit 1ba4c714c9
@@ -2091,6 +2091,102 @@ struct TorchIndexSelectOpOnTensorsConversion
  }
};

struct ScatterUpdateOnTensorsConversion
    : public OpConversionPattern<mhlo::ScatterOp> {
  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    mhlo::ScatterOp::Adaptor adaptor(args);

    // Check if it is a tensor_scatter_nd_update-like op.
    auto& body_ops = op.getRegion().front().getOperations();
    if (body_ops.size() != 1) return failure();
    auto ret_arg = body_ops.front().getOperand(0).dyn_cast<BlockArgument>();
    if (!ret_arg || ret_arg.getArgNumber() != 1) return failure();

    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    auto indices_ty =
        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !indices_ty) return failure();

    // Linalg operations put all the computation into the innermost loop. Since
    // we also iterate over scatter_indices() with some loops, we can only check
    // one scatter index in one iteration. If there are multiple indices (i.e.,
    // the index depth is greater than 1), we don't have a way to keep the
    // comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]],
    // we should use the update value only if (i == 0 and j == 1). However, we
    // cannot get both indices in one iteration unless we pack them together.
    auto index_vector_dim =
        op.scatter_dimension_numbers().index_vector_dim().getInt();
    if (indices_ty.getDimSize(index_vector_dim) != 1)
      return rewriter.notifyMatchFailure(op, "require index depth to be 1");
    if (index_vector_dim != indices_ty.getRank() - 1) {
      return rewriter.notifyMatchFailure(
          op, "require index_vector_dim to be the last dim");
    }

    // One of the indices dims is the index depth vector.
    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
    SmallVector<AffineMap, 3> indexing_maps;
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    {
      SmallVector<AffineExpr> exprs;
      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      // The index depth is 1.
      exprs.push_back(rewriter.getAffineConstantExpr(0));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));

      exprs.pop_back();
      auto update_window_dims =
          Extract1DVector(op.scatter_dimension_numbers().update_window_dims());
      for (auto d : update_window_dims)
        exprs.push_back(rewriter.getAffineDimExpr(d));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    indexing_maps.push_back(indexing_maps.front());

    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
                         .cast<ShapedType>();
    auto scatter_dims_to_operand_dims = Extract1DVector(
        op.scatter_dimension_numbers().scatter_dims_to_operand_dims());
    assert(scatter_dims_to_operand_dims.size() == 1);
    // No init_tensor is needed because the output is initialized with the
    // operand.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
        /*inputs=*/
        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
                   adaptor.updates()},
        /*outputs=*/adaptor.operand(), indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          Value cmp_idx =
              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
          Value idx = b.create<IndexCastOp>(loc, b.getIndexType(), args[1]);
          Value pred = b.create<CmpIOp>(loc, b.getI1Type(), CmpIPredicate::eq,
                                        cmp_idx, idx);
          // Use the output arg (not the operand) so that updates written in
          // earlier iterations are not overwritten with the init value again.
          Value res = b.create<SelectOp>(loc, args[2].getType(), pred, args[2],
                                         args[3]);
          b.create<linalg::YieldOp>(loc, res);
        });
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           TypeConverter& typeConverter,
                                           OwningRewritePatternList* patterns) {
@@ -2353,6 +2449,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
      DepthwiseConvOpOnTensorsConversion,
      ReduceOnTensorsConversion,
      ReduceWindowOpOnTensorsConversion,
      ScatterUpdateOnTensorsConversion,
      TorchIndexSelectOpOnTensorsConversion,
      PadOpOnTensorsConversion>(type_converter, context);
  // clang-format on
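As an illustration of the index-depth restriction checked above (the function name and shapes below are made up for this note and are not part of the commit or its tests), a scatter with index depth 2 would be rejected with "require index depth to be 1", since both index components would have to be compared within a single linalg iteration:

func @scatter_update_2d_index(%arg0: tensor<4x5xi32>, %arg1: tensor<1x2xi32>,
                              %arg2: tensor<1xi32>) -> tensor<4x5xi32> {
  // The last dim of %arg1 is the index vector and has size 2, so
  // ScatterUpdateOnTensorsConversion does not match this op.
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = {
      index_vector_dim = 1 : i64,
      inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
      scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
      update_window_dims = dense<> : tensor<0xi64>
    },
    unique_indices = false
  } : (tensor<4x5xi32>, tensor<1x2xi32>, tensor<1xi32>) -> tensor<4x5xi32>
  return %0 : tensor<4x5xi32>
}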
@@ -2379,3 +2379,83 @@ func @unsigned_compare(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xi1> {
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "GT"} : (tensor<2x2xui32>, tensor<2x2xui32>) -> tensor<2x2xi1>
  return %0 : tensor<2x2xi1>
}

// -----

func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
                            %arg2: tensor<1xi32>) -> tensor<3xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = {
      index_vector_dim = 1 : i64,
      inserted_window_dims = dense<0> : tensor<1xi64>,
      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
      update_window_dims = dense<> : tensor<0xi64>
    },
    unique_indices = false
  } : (tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}
// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d1, 0)>
// CHECK-DAG:   #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK:       func @scatter_update_scalar
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK:         %[[RES:.*]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel"]
// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
// CHECK-SAME:     outs(%[[ARG0]] : tensor<3xi32>) {
// CHECK:         ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
// CHECK:           %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK:           %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK:           %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK:           %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK:           linalg.yield %[[SELECT]] : i32
// CHECK:         } -> tensor<3xi32>
// CHECK:         return %[[RES]] : tensor<3xi32>

// -----

func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,
                           %arg2: tensor<2x3xi32>) -> tensor<6x3xi32> {
  %0 = "mhlo.scatter"(%arg0, %arg1, %arg2) ( {
  ^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):  // no predecessors
    "mhlo.return"(%arg4) : (tensor<i32>) -> ()
  }) {
    indices_are_sorted = false,
    scatter_dimension_numbers = {
      index_vector_dim = 1 : i64,
      inserted_window_dims = dense<0> : tensor<1xi64>,
      scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
      update_window_dims = dense<1> : tensor<1xi64>
    },
    unique_indices = false
  } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>) -> tensor<6x3xi32>
  return %0 : tensor<6x3xi32>
}
// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2, 0)>
// CHECK-DAG:   #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2, d1)>
// CHECK:       func @scatter_update_slice
// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9_]*]]
// CHECK:         %[[RES:.*]] = linalg.generic
// CHECK-SAME:      indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]]
// CHECK-SAME:      iterator_types = ["parallel", "parallel", "parallel"]
// CHECK-SAME:      ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x3xi32>)
// CHECK-SAME:     outs(%[[ARG0]] : tensor<6x3xi32>) {
// CHECK:         ^bb0(%{{.*}}: i32, %[[IDX_I32:.*]]: i32, %[[UPDATE:.*]]: i32, %[[OUT:.*]]: i32):  // no predecessors
// CHECK:           %[[CMP_IDX:.*]] = linalg.index 0 : index
// CHECK:           %[[IDX:.*]] = index_cast %[[IDX_I32]] : i32 to index
// CHECK:           %[[PRED:.*]] = cmpi eq, %[[CMP_IDX]], %[[IDX]] : index
// CHECK:           %[[SELECT:.*]] = select %[[PRED]], %[[UPDATE]], %[[OUT]] : i32
// CHECK:           linalg.yield %[[SELECT]] : i32
// CHECK:         } -> tensor<6x3xi32>
// CHECK:         return %[[RES]] : tensor<6x3xi32>
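For readability, here is a sketch of the full IR the new pattern produces for @scatter_update_scalar, reconstructed from the CHECK lines above; the exact SSA value names, attribute ordering, and printer details of the real output may differ.

#map0 = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1, 0)>
#map2 = affine_map<(d0, d1) -> (d1)>
func @scatter_update_scalar(%arg0: tensor<3xi32>, %arg1: tensor<1x1xi32>,
                            %arg2: tensor<1xi32>) -> tensor<3xi32> {
  // The operand is passed both as an input and as the output init value, so
  // no init_tensor is needed.
  %0 = linalg.generic {
         indexing_maps = [#map0, #map1, #map2, #map0],
         iterator_types = ["parallel", "parallel"]}
       ins(%arg0, %arg1, %arg2 : tensor<3xi32>, tensor<1x1xi32>, tensor<1xi32>)
       outs(%arg0 : tensor<3xi32>) {
  ^bb0(%operand: i32, %idx_i32: i32, %update: i32, %out: i32):
    // Keep the update value only on the row selected by the scatter index;
    // otherwise keep the value already in the output.
    %cmp_idx = linalg.index 0 : index
    %idx = index_cast %idx_i32 : i32 to index
    %pred = cmpi eq, %cmp_idx, %idx : index
    %res = select %pred, %update, %out : i32
    linalg.yield %res : i32
  } -> tensor<3xi32>
  return %0 : tensor<3xi32>
}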