PR #50211: [MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/50211

Adds support for the hlo-to-lhlo (bufferization) conversion of RealDynamicSliceOp and ReduceOp.

Copybara import of the project:

-- c417b336670a1fc256f7026dfe8080e46d13d79a by Wenyi Zhao <reyizero@gmail.com>:

[MLIR][DISC] Bufferize RealDynamicSliceOp and ReduceOp

PiperOrigin-RevId: 378972113
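For reference, this is roughly what the converted IR looks like for a 1-D real_dynamic_slice after bufferization: the per-dimension result extent is reified as (limit - start + stride - 1) / stride, the output buffer is allocated from those extents, and the op is rewritten to its lmhlo counterpart writing into that buffer. A minimal sketch adapted from the new lit test in this change (the function and SSA value names are illustrative; the actual test below exercises the 2-D case):

func @slice_1d(%arg: memref<?xf32>, %start: memref<1xi32>,
               %limit: memref<1xi32>, %stride: memref<1xi32>) -> memref<?xf32> {
  %c0 = constant 0 : index
  %c1_i32 = constant 1 : i32
  %s = memref.load %start[%c0] : memref<1xi32>
  %l = memref.load %limit[%c0] : memref<1xi32>
  %d = memref.load %stride[%c0] : memref<1xi32>
  // size = (limit - start + stride - 1) / stride
  %t0 = subi %l, %s : i32
  %t1 = addi %d, %t0 : i32
  %t2 = subi %t1, %c1_i32 : i32
  %t3 = divi_signed %t2, %d : i32
  %size = index_cast %t3 : i32 to index
  %out = memref.alloc(%size) : memref<?xf32>
  "lmhlo.real_dynamic_slice"(%arg, %start, %limit, %stride, %out)
      : (memref<?xf32>, memref<1xi32>, memref<1xi32>, memref<1xi32>, memref<?xf32>) -> ()
  return %out : memref<?xf32>
}

ReduceOp is handled analogously, with the output shape assembled from the operand dimensions that are not being reduced (see the new @row_reduce and @column_reduce tests).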
parent 95ba03534f
commit 8388303fd2
@@ -967,7 +967,7 @@ def HLO_AllToAllOp : HLO_Op<"all_to_all",
   let results = (outs HLO_Tensor);
 }
 
-def HLO_ReduceOp: HLO_Op<"reduce", [
+def HLO_ReduceOp: HLO_ShapedInterfaceOp<"reduce", [
       RecursiveSideEffects,
       SameVariadicOperandSize,
       SingleBlockImplicitTerminator<"ReturnOp">,
@@ -2199,7 +2199,7 @@ def HLO_ReducePrecisionOp :
   let results = (outs HLO_FpTensor:$output);
 }
 
-def HLO_RealDynamicSliceOp: HLO_Op<
+def HLO_RealDynamicSliceOp: HLO_ShapedInterfaceOp<
       "real_dynamic_slice",
       [NoSideEffect, AllElementTypesMatch<["operand", "result"]>,
        AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
@@ -75,6 +75,7 @@ MAP_HLO_TO_LHLO(NegOp);
 MAP_HLO_TO_LHLO(NotOp);
 MAP_HLO_TO_LHLO(OrOp);
 MAP_HLO_TO_LHLO(PowOp);
+MAP_HLO_TO_LHLO(RealDynamicSliceOp);
 MAP_HLO_TO_LHLO(RealOp);
 MAP_HLO_TO_LHLO(ReduceOp);
 MAP_HLO_TO_LHLO(ReshapeOp);
@@ -1626,6 +1626,52 @@ static LogicalResult Verify(RealDynamicSliceOp op) {
   return success();
 }
 
+LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  RealDynamicSliceOp::Adaptor adaptor(operands);
+  Value operand = adaptor.operand();
+  Value start_indices = adaptor.start_indices();
+  Value limit_indices = adaptor.limit_indices();
+  Value strides = adaptor.strides();
+
+  auto operand_type = operand.getType().dyn_cast<RankedTensorType>();
+  // Unranked types are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type =
+      start_indices.getType().cast<ShapedType>().getElementType();
+  Value one = builder.create<ConstantIndexOp>(loc, 1);
+  one = MaybeCastTo(builder, loc, one, shape_scalar_type);
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    Value offset = builder.create<ConstantIndexOp>(loc, element.index());
+    Value value_start =
+        builder.create<tensor::ExtractOp>(loc, start_indices, offset);
+    Value value_limit =
+        builder.create<tensor::ExtractOp>(loc, limit_indices, offset);
+    Value value_stride =
+        builder.create<tensor::ExtractOp>(loc, strides, offset);
+    // size = (limit - start + stride - 1) / stride
+    shape_values.push_back(builder.create<SignedDivIOp>(
+        loc,
+        builder.create<SubIOp>(
+            loc,
+            builder.create<AddIOp>(
+                loc, value_stride,
+                builder.create<SubIOp>(loc, value_limit, value_start)),
+            one),
+        value_stride));
+  }
+
+  reifiedReturnShapes.push_back(builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values));
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // InfeedOp
 //===----------------------------------------------------------------------===//
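A note on the size computation above: (limit - start + stride - 1) / stride is the usual ceiling division ceil((limit - start) / stride). For example, start = 2, limit = 10, stride = 3 gives (10 - 2 + 3 - 1) / 3 = 3 elements.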
@@ -2072,6 +2118,46 @@ void ReduceOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                            MLIRContext* context) {
   results.insert<LowerBoolSplatConstantsIntoRegion>(context);
 }
 
+LogicalResult ReduceOp::reifyReturnTypeShapes(
+    OpBuilder& builder, ValueRange operands,
+    SmallVectorImpl<Value>& reifiedReturnShapes) {
+  ReduceOp::Adaptor adaptor(operands);
+  auto inputs = adaptor.inputs();
+
+  auto operand_type = inputs[0].getType().dyn_cast<RankedTensorType>();
+  // Unranked types are not supported at the moment.
+  if (!operand_type) return failure();
+
+  Location loc = this->getLoc();
+  SmallVector<Value, 4> shape_values;
+  SmallVector<int64_t, 4> dimensions(this->dimensions().getValues<int64_t>());
+  shape_values.reserve(operand_type.getRank());
+  Type shape_scalar_type = builder.getIndexType();
+  auto to_shape_scalar_type = [&](Value v) {
+    return MaybeCastTo(builder, loc, v, shape_scalar_type);
+  };
+
+  for (const auto& element : llvm::enumerate(operand_type.getShape())) {
+    int64_t idx = element.index();
+    auto it = std::find(dimensions.begin(), dimensions.end(), idx);
+    if (it != dimensions.end()) {
+      continue;
+    }
+    Value value_dim = to_shape_scalar_type(
+        builder.create<memref::DimOp>(loc, inputs[0], element.index()));
+    shape_values.push_back(value_dim);
+  }
+
+  Value output_shape = builder.create<tensor::FromElementsOp>(
+      loc, shape_scalar_type, shape_values);
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    reifiedReturnShapes.push_back(output_shape);
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // RngNormalOp
 //===----------------------------------------------------------------------===//
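In the reduce reification above, only the operand dimensions not listed in `dimensions` contribute to the output shape, and that one shape is pushed once per result. For the row reduction in the new test below (dimension 1 of a memref<?x?xf32> is reduced), the output buffer is therefore sized from dimension 0 alone; a short sketch with illustrative SSA names:

  %c0 = constant 0 : index
  %d0 = memref.dim %arg, %c0 : memref<?x?xf32>   // kept: dimension 0
  %out = memref.alloc(%d0) : memref<?xf32>       // dropped: reduced dimension 1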
@@ -434,11 +434,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
              << "tensor to buffer conversion expects a single block "
                 "in the region containing the operation";
     }
-    const auto& original_results = op.getResults();
     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
-    for (auto result : original_results) {
-      buffer_args.push_back(InsertAlloc(loc, result, &rewriter));
-    }
+    if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
     auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
                                                    op->getAttrs());
 
@@ -671,7 +668,8 @@ void populateDynamicHLOToLHLOOnlyConversionPattern(
   patterns->insert<HloToLhloOpConverter<mhlo::DynamicBroadcastInDimOp>,
                    HloToLhloOpConverter<mhlo::DynamicIotaOp>,
                    HloToLhloOpConverter<mhlo::DynamicPadOp>,
-                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>
+                   HloToLhloOpConverter<mhlo::DynamicReshapeOp>,
+                   HloToLhloOpConverter<mhlo::RealDynamicSliceOp>
   >(*converter, context);
   // clang-format on
 }
@@ -81,3 +81,72 @@ func @dynamic_pad(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>, %arg2: tensor<2xin
   %0 = "mhlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<?x?xf32>, tensor<f32>, tensor<2xindex>, tensor<2xindex>, tensor<2xindex>) -> tensor<?x?xf32>
   return %0: tensor<?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @real_dynamic_slice
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>,
+// CHECK-SAME:  %[[START:.*]]: memref<2xi32>, %[[LIMIT:.*]]: memref<2xi32>, %[[STRIDE:.*]]: memref<2xi32>) -> memref<?x?xf32>
+func @real_dynamic_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>) -> tensor<?x?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[T0:.*]] = memref.load %[[START]][%c0] : memref<2xi32>
+  // CHECK: %[[T1:.*]] = memref.load %[[LIMIT]][%c0] : memref<2xi32>
+  // CHECK: %[[T2:.*]] = memref.load %[[STRIDE]][%c0] : memref<2xi32>
+  // CHECK: %[[T3:.*]] = subi %[[T1]], %[[T0]] : i32
+  // CHECK: %[[T4:.*]] = addi %[[T2]], %[[T3]] : i32
+  // CHECK: %[[T5:.*]] = subi %[[T4]], %c1_i32 : i32
+  // CHECK: %[[T6:.*]] = divi_signed %[[T5]], %[[T2]] : i32
+  // CHECK: %[[T7:.*]] = memref.load %[[START]][%c1] : memref<2xi32>
+  // CHECK: %[[T8:.*]] = memref.load %[[LIMIT]][%c1] : memref<2xi32>
+  // CHECK: %[[T9:.*]] = memref.load %[[STRIDE]][%c1] : memref<2xi32>
+  // CHECK: %[[T10:.*]] = subi %[[T8]], %[[T7]] : i32
+  // CHECK: %[[T11:.*]] = addi %[[T9]], %[[T10]] : i32
+  // CHECK: %[[T12:.*]] = subi %[[T11]], %c1_i32 : i32
+  // CHECK: %[[T13:.*]] = divi_signed %[[T12]], %[[T9]] : i32
+  // CHECK: %[[T14:.*]] = index_cast %[[T6]] : i32 to index
+  // CHECK: %[[T15:.*]] = index_cast %[[T13]] : i32 to index
+  // CHECK: %[[T16:.*]] = memref.alloc(%[[T14]], %[[T15]]) : memref<?x?xf32>
+  // CHECK: "lmhlo.real_dynamic_slice"(%[[ARG]], %[[START]], %[[LIMIT]], %[[STRIDE]], %[[T16]])
+  %0 = "mhlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<?x?xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<?x?xf32>
+  return %0: tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @row_reduce
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
+func @row_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM0:.*]] = memref.dim %[[ARG]], %c0 : memref<?x?xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM0]]) : memref<?xf32>
+  // CHECK: lmhlo.reduce
+  // CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
+  // CHECK: return %[[OUT]] : memref<?xf32>
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<1> : tensor<1xi64>}
+      : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  return %0: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @column_reduce
+// CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32>, %[[VAL:.*]]: memref<f32>) -> memref<?xf32>
+func @column_reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<f32>) -> tensor<?xf32> {
+  // CHECK-NOT: tensor_load
+  // CHECK: %[[DIM1:.*]] = memref.dim %[[ARG]], %c1 : memref<?x?xf32>
+  // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]]) : memref<?xf32>
+  // CHECK: lmhlo.reduce
+  // CHECK-SAME: %[[ARG]], %[[VAL]], %[[OUT]]
+  // CHECK: return %[[OUT]] : memref<?xf32>
+  %0 = "mhlo.reduce"(%arg0, %arg1) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):  // no predecessors
+    %1 = mhlo.add %arg2, %arg3 : tensor<f32>
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
+  }) {dimensions = dense<0> : tensor<1xi64>}
+      : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  return %0: tensor<?xf32>
+}