PR #49454: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface.
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/49454 The new interface is more safe to be used during dialect conversion (e.g. converting from tensor world to buffer world). Copybara import of the project: -- a6968072d59bec3c3bbaef0121d297e807c37c91 by Wenyi Zhao <reyizero@gmail.com>: [MLIR][DISC] Upgrade to use the new `reifyReturnTypeShapes` interface. The new interface is more safe to be used during dialect conversion (e.g. converting from tensor world to buffer world). -- 55e7c6b7f2f99b99e226645a57e2433fae3e90ed by Wenyi Zhao <reyizero@gmail.com>: minor fix PiperOrigin-RevId: 375500273
This commit is contained in:
		
							parent
							
								
									28c4112f35
								
							
						
					
					
						commit
						b93e54d8a4
					
				|  | @ -441,7 +441,7 @@ class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, | |||
|     LogicalResult reifyReturnTypeShapes(OpBuilder& builder, ValueRange operands, | ||||
|         SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), | ||||
|           &reifiedReturnShapes); | ||||
|           operands, &reifiedReturnShapes); | ||||
|     } | ||||
|   }]; | ||||
| } | ||||
|  |  | |||
|  | @ -79,7 +79,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> { | |||
| //
 | ||||
| // and returns %4 as the shape value.
 | ||||
| LogicalResult deriveShapeFromFirstOperand( | ||||
|     OpBuilder *builder, Operation *op, | ||||
|     OpBuilder *builder, Operation *op, ValueRange operands, | ||||
|     SmallVectorImpl<Value> *reifiedReturnShapes); | ||||
| 
 | ||||
| // Type derivation function that returns a tensor type with a new element type.
 | ||||
|  |  | |||
|  | @ -142,6 +142,7 @@ class HLO_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits, | |||
|         OpBuilder& builder, ValueRange operands, | ||||
|         SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), | ||||
|                                                        operands, | ||||
|                                                        &reifiedReturnShapes); | ||||
|     } | ||||
|     bool inferInputOutputShapeEquality(int input, int output) { | ||||
|  | @ -456,6 +457,7 @@ class HLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> : | |||
|         OpBuilder& builder, ValueRange operands, | ||||
|         SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|       return ::mlir::mhlo::deriveShapeFromFirstOperand(&builder, getOperation(), | ||||
|                                                        operands, | ||||
|                                                        &reifiedReturnShapes); | ||||
|     } | ||||
|     bool inferInputsShapeEquality(int lhs, int rhs) { | ||||
|  |  | |||
|  | @ -155,11 +155,11 @@ LogicalResult InferBroadcastBinaryOpReturnTypeComponents( | |||
| } | ||||
| 
 | ||||
| LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( | ||||
|     OpBuilder& builder, Operation* op, | ||||
|     OpBuilder& builder, Operation* op, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   auto loc = op->getLoc(); | ||||
|   auto lhs = op->getOperand(0); | ||||
|   auto rhs = op->getOperand(1); | ||||
|   auto lhs = operands[0]; | ||||
|   auto rhs = operands[1]; | ||||
| 
 | ||||
|   // Check for "numpy"-style rank broadcast.
 | ||||
|   auto broadcast_dimensions = op->getAttr("broadcast_dimensions") | ||||
|  | @ -204,10 +204,10 @@ LogicalResult BroadcastComplexOp::inferReturnTypeComponents( | |||
|                                                     inferedReturnShapes); | ||||
| } | ||||
| LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( | ||||
|     OpBuilder& builder, ValueRange, | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), | ||||
|                                                 reifiedReturnShapes); | ||||
|                                                 operands, reifiedReturnShapes); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | @ -236,10 +236,10 @@ LogicalResult BroadcastCompareOp::inferReturnTypeComponents( | |||
| } | ||||
| 
 | ||||
| LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( | ||||
|     OpBuilder& builder, ValueRange, | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), | ||||
|                                                 reifiedReturnShapes); | ||||
|                                                 operands, reifiedReturnShapes); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  | @ -295,10 +295,10 @@ LogicalResult IsPosInfOp::inferReturnTypes( | |||
|         inferedReturnShapes);                                                 \ | ||||
|   }                                                                           \ | ||||
|   LogicalResult Op::reifyReturnTypeShapes(                                    \ | ||||
|       OpBuilder& builder, ValueRange,                                         \ | ||||
|       OpBuilder& builder, ValueRange operands,                                \ | ||||
|       SmallVectorImpl<Value>& reifiedReturnShapes) {                          \ | ||||
|     return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),    \ | ||||
|                                                   reifiedReturnShapes);       \ | ||||
|     return ReifyBroadcastBinaryOpReturnTypeShapes(                            \ | ||||
|         builder, getOperation(), operands, reifiedReturnShapes);              \ | ||||
|   } | ||||
| 
 | ||||
| #define BROADCAST_BINARY_OP_DEFS(Op)                                           \ | ||||
|  |  | |||
|  | @ -1002,8 +1002,10 @@ LogicalResult DynamicBroadcastInDimOp::inferReturnTypeComponents( | |||
| } | ||||
| 
 | ||||
| LogicalResult DynamicBroadcastInDimOp::reifyReturnTypeShapes( | ||||
|     OpBuilder&, ValueRange, SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   reifiedReturnShapes.push_back(output_dimensions()); | ||||
|     OpBuilder&, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   DynamicBroadcastInDimOp::Adaptor adaptor(operands); | ||||
|   reifiedReturnShapes.push_back(adaptor.output_dimensions()); | ||||
|   return success(); | ||||
| } | ||||
| 
 | ||||
|  | @ -2137,7 +2139,7 @@ LogicalResult SelectOp::inferReturnTypeComponents( | |||
| LogicalResult SelectOp::reifyReturnTypeShapes( | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   return deriveShapeFromFirstOperand(&builder, getOperation(), | ||||
|   return deriveShapeFromFirstOperand(&builder, getOperation(), operands, | ||||
|                                      &reifiedReturnShapes); | ||||
| } | ||||
| 
 | ||||
|  | @ -3278,7 +3280,7 @@ LogicalResult CompareOp::inferReturnTypeComponents( | |||
| LogicalResult CompareOp::reifyReturnTypeShapes( | ||||
|     OpBuilder& builder, ValueRange operands, | ||||
|     SmallVectorImpl<Value>& reifiedReturnShapes) { | ||||
|   return deriveShapeFromFirstOperand(&builder, getOperation(), | ||||
|   return deriveShapeFromFirstOperand(&builder, getOperation(), operands, | ||||
|                                      &reifiedReturnShapes); | ||||
| } | ||||
| 
 | ||||
|  | @ -3629,9 +3631,9 @@ void MhloDialect::printType(Type type, DialectAsmPrinter& os) const { | |||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| LogicalResult deriveShapeFromFirstOperand( | ||||
|     OpBuilder* builder, Operation* op, | ||||
|     OpBuilder* builder, Operation* op, ValueRange operands, | ||||
|     SmallVectorImpl<Value>* reifiedReturnShapes) { | ||||
|   Value operand = op->getOperand(0); | ||||
|   Value operand = operands.front(); | ||||
|   ShapedType operand_type = operand.getType().dyn_cast<ShapedType>(); | ||||
|   if (!operand_type) { | ||||
|     op->emitOpError() << "first operand is not a shaped type"; | ||||
|  |  | |||
|  | @ -94,6 +94,8 @@ Value InsertAlloc(Location loc, OpResult result, | |||
| /// to the `results` vector.
 | ||||
| LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results, | ||||
|                              ConversionPatternRewriter& rewriter) { | ||||
|   size_t num_operands = results.size(); | ||||
|   SmallVector<Value, 2> tensor_operands; | ||||
|   for (auto result : llvm::enumerate(op->getResults())) { | ||||
|     RankedTensorType resultType = | ||||
|         result.value().getType().dyn_cast<RankedTensorType>(); | ||||
|  | @ -106,9 +108,19 @@ LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results, | |||
|     auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); | ||||
|     if (!shape_type_op) return failure(); | ||||
| 
 | ||||
|     if (tensor_operands.empty()) { | ||||
|       for (auto operand : ArrayRef<Value>(results).take_front(num_operands)) { | ||||
|         auto tp = operand.getType().cast<ShapedType>(); | ||||
|         tensor_operands.push_back(rewriter.create<memref::TensorLoadOp>( | ||||
|             op->getLoc(), | ||||
|             RankedTensorType::get(tp.getShape(), tp.getElementType()), | ||||
|             operand)); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     SmallVector<Value, 1> results_shape; | ||||
|     auto status = shape_type_op.reifyReturnTypeShapes( | ||||
|         rewriter, shape_type_op->getOperands(), results_shape); | ||||
|     auto status = shape_type_op.reifyReturnTypeShapes(rewriter, tensor_operands, | ||||
|                                                       results_shape); | ||||
|     if (failed(status)) return failure(); | ||||
|     results.push_back( | ||||
|         InsertDynamicAllocAndDealloc(op->getLoc(), result.value(), | ||||
|  | @ -391,8 +403,8 @@ struct HloToLhloDotGeneralOpConverter | |||
|     } else { | ||||
|       SmallVector<Value, 1> results_shape; | ||||
|       auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op); | ||||
|       if (failed(shape_type_op.reifyReturnTypeShapes( | ||||
|               rewriter, shape_type_op->getOperands(), results_shape))) | ||||
|       if (failed(shape_type_op.reifyReturnTypeShapes(rewriter, operands, | ||||
|                                                      results_shape))) | ||||
|         return failure(); | ||||
| 
 | ||||
|       bufferArgs[2] = InsertDynamicAllocAndDealloc( | ||||
|  | @ -585,9 +597,9 @@ struct HloLegalizeToLhlo | |||
|     target.addLegalDialect<shape::ShapeDialect>(); | ||||
|     target.addLegalDialect<tensor::TensorDialect>(); | ||||
|     target.addIllegalDialect<mhlo::MhloDialect>(); | ||||
|     // Declare tensor_load and tensor_store illegal.
 | ||||
|     target.addIllegalOp<mlir::memref::TensorLoadOp, | ||||
|                         mlir::memref::TensorStoreOp>(); | ||||
|     // Declare tensor_store illegal. tensor_load may be used to reify output
 | ||||
|     // shape computation during dialect conversion and will be handled later.
 | ||||
|     target.addIllegalOp<mlir::memref::TensorStoreOp>(); | ||||
|     // buffer_cast is illegal if it has uses.
 | ||||
|     // TODO(b/175670649) Make buffer_cast illegal.
 | ||||
|     target.addDynamicallyLegalOp<mlir::memref::BufferCastOp>( | ||||
|  |  | |||
|  | @ -1,4 +1,4 @@ | |||
| // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation %s -o - | FileCheck %s | ||||
| // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -canonicalize %s -o - | FileCheck %s | ||||
| 
 | ||||
| // CHECK-LABEL: func @func_op_unranked_arg_result | ||||
| func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { | ||||
|  |  | |||
|  | @ -466,7 +466,7 @@ func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>) | |||
| func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> { | ||||
|   %result = "mhlo.add"(%lhs, %rhs) | ||||
|       : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex> | ||||
|   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> | ||||
|   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> | ||||
|   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]]) | ||||
|  | @ -482,7 +482,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> { | |||
| func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> { | ||||
|   %result = "mhlo.tanh"(%arg0) | ||||
|       : (tensor<?x?xf32>) -> tensor<?x?xf32> | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex> | ||||
|   // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[INPUT:.*]] : tensor<?x?xf32> -> tensor<2xindex> | ||||
|   // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<2xindex> | ||||
|   // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<2xindex> | ||||
|   // CHECK: %[[RESULT:.*]] = memref.alloc(%[[EE0]], %[[EE1]]) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue