Extract some duplicated code into a helper function.
- Extract code to create result memref's into a ConvertResults function. - Also fix a bug when using reifyReturnTypes: use correct index for result_shape instead of always using the first element. PiperOrigin-RevId: 341852227
This commit is contained in:
		
							parent
							
								
									d4f2c767d3
								
							
						
					
					
						commit
						745c8aa0b1
					
				|  | @ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result, | |||
|   return alloc; | ||||
| } | ||||
| 
 | ||||
| /// Converts the results of the operation `op` to memref types and append them
 | ||||
| /// to the `results` vector.
 | ||||
| LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results, | ||||
|                              ConversionPatternRewriter& rewriter) { | ||||
|   for (auto result : llvm::enumerate(op->getResults())) { | ||||
|     RankedTensorType resultType = | ||||
|         result.value().getType().dyn_cast<RankedTensorType>(); | ||||
|     if (!resultType) return failure(); | ||||
| 
 | ||||
|     if (resultType.hasStaticShape()) { | ||||
|       results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter)); | ||||
|       continue; | ||||
|     } | ||||
|     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); | ||||
|     if (!shape_type_op) return failure(); | ||||
| 
 | ||||
|     SmallVector<Value, 1> results_shape; | ||||
|     auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); | ||||
|     if (failed(status)) return failure(); | ||||
|     results.push_back( | ||||
|         InsertDynamicAllocAndDealloc(op->getLoc(), result.value(), | ||||
|                                      results_shape[result.index()], &rewriter)); | ||||
|   } | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| template <typename HloOpTy> | ||||
| class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> { | ||||
|  public: | ||||
|  | @ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> { | |||
|       HloOpTy hloOp, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     Operation* op = hloOp.getOperation(); | ||||
|     const auto& original_results = op->getResults(); | ||||
|     SmallVector<Value, 4> buffer_args(operands.begin(), operands.end()); | ||||
|     for (auto result : llvm::enumerate(original_results)) { | ||||
|       RankedTensorType resultType = | ||||
|           result.value().getType().dyn_cast<RankedTensorType>(); | ||||
|       if (!resultType) { | ||||
|         return failure(); | ||||
|       } | ||||
|       if (resultType.hasStaticShape()) { | ||||
|         buffer_args.push_back( | ||||
|             InsertAlloc(op->getLoc(), result.value(), &rewriter)); | ||||
|       } else { | ||||
|         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); | ||||
|         if (!shape_type_op) return failure(); | ||||
| 
 | ||||
|         SmallVector<Value, 1> results_shape; | ||||
|         auto status = | ||||
|             shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); | ||||
|         if (failed(status)) return failure(); | ||||
|         buffer_args.push_back(InsertDynamicAllocAndDealloc( | ||||
|             op->getLoc(), result.value(), results_shape.front(), &rewriter)); | ||||
|       } | ||||
|     } | ||||
|     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure(); | ||||
|     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None, | ||||
|                                                 buffer_args, op->getAttrs()); | ||||
|     rewriter.replaceOp( | ||||
|  | @ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> { | |||
|       mhlo::DotOp hloOp, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     Operation* op = hloOp.getOperation(); | ||||
|     const auto& original_results = op->getResults(); | ||||
|     SmallVector<Value, 2> buffer_args(operands.begin(), operands.end()); | ||||
|     for (auto result : llvm::enumerate(original_results)) { | ||||
|       RankedTensorType resultType = | ||||
|           result.value().getType().dyn_cast<RankedTensorType>(); | ||||
|       if (!resultType) { | ||||
|         return failure(); | ||||
|       } | ||||
|       if (resultType.hasStaticShape()) { | ||||
|         buffer_args.push_back( | ||||
|             InsertAlloc(op->getLoc(), result.value(), &rewriter)); | ||||
|       } else { | ||||
|         SmallVector<Value, 1> results_shape; | ||||
|         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); | ||||
|         if (!shape_type_op) return failure(); | ||||
|         if (failed( | ||||
|                 shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) | ||||
|           return failure(); | ||||
|         buffer_args.push_back(InsertDynamicAllocAndDealloc( | ||||
|             op->getLoc(), result.value(), results_shape.front(), &rewriter)); | ||||
|       } | ||||
|     } | ||||
|     if (failed(ConvertResults(op, buffer_args, rewriter))) return failure(); | ||||
| 
 | ||||
|     // TODO(silvasean): Move this helper to MLIR core.
 | ||||
|     auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue