[MLIR][LHLO] Convert mhlo.dynamic_reshape -> lhlo.reshape_memref_cast.
PiperOrigin-RevId: 320149593
This commit is contained in:
		
							parent
							
								
									8692fde3f9
								
							
						
					
					
						commit
						e8cfdee592
					
				|  | @ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [ | |||
|   ); | ||||
|   let results = (outs AnyRankedOrUnrankedMemRef:$result); | ||||
| 
 | ||||
|   let builders = [OpBuilder< | ||||
|     "OpBuilder &builder, OperationState &result, MemRefType resultType, " # | ||||
|     "Value operand, Value shape", [{ | ||||
|        result.addOperands(operand); | ||||
|        result.addOperands(shape); | ||||
|        result.types.push_back(resultType); | ||||
|      }]>]; | ||||
| 
 | ||||
|   let extraClassDeclaration = [{ | ||||
|     MemRefType getType() { return getResult().getType().cast<MemRefType>(); } | ||||
|   }]; | ||||
|  |  | |||
|  | @ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct HloToLhloDynamicReshapeConverter | ||||
|     : public BaseOpConversion<mhlo::DynamicReshapeOp> { | ||||
|  public: | ||||
|   using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion; | ||||
| 
 | ||||
|   LogicalResult matchAndRewrite( | ||||
|       mhlo::DynamicReshapeOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter& rewriter) const final { | ||||
|     Type result_type; | ||||
|     if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) { | ||||
|       result_type = | ||||
|           MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); | ||||
|     } else if (auto unranked_type = | ||||
|                    op.getType().dyn_cast<UnrankedTensorType>()) { | ||||
|       result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); | ||||
|     } else { | ||||
|       return failure(); | ||||
|     } | ||||
|     mhlo::DynamicReshapeOp::Adaptor adaptor(operands); | ||||
|     rewriter.replaceOpWithNewOp<xla_lhlo::ReshapeMemRefCastOp>( | ||||
|         op, result_type, adaptor.operand(), adaptor.output_shape()); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> { | ||||
|  public: | ||||
|   using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion; | ||||
|  | @ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern( | |||
|   // clang-format off
 | ||||
|   patterns->insert< | ||||
|       HloToLhloDynamicBroadcastInDimOpConverter, | ||||
|       HloToLhloDynamicReshapeConverter, | ||||
|       HloToLhloOpConverter<mhlo::AbsOp>, | ||||
|       HloToLhloOpConverter<mhlo::AddOp>, | ||||
|       HloToLhloOpConverter<mhlo::AndOp>, | ||||
|  |  | |||
|  | @ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { | |||
| } | ||||
| // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> | ||||
| // CHECK-NEXT: return [[ARG]] : memref<*xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @dynamic_reshape_from_unranked | ||||
| func @dynamic_reshape_from_unranked( | ||||
|          %operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> { | ||||
|   %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) | ||||
|       : (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32> | ||||
|   return %reshaped : tensor<?xf32> | ||||
| } | ||||
| // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) | ||||
| // CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) | ||||
| // CHECK-SAME:   : (memref<*xf32>, memref<1xi32>) -> memref<?xf32> | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @dynamic_reshape_to_unranked | ||||
| func @dynamic_reshape_to_unranked( | ||||
|          %operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> { | ||||
|   %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) | ||||
|       : (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32> | ||||
|   return %reshaped : tensor<*xf32> | ||||
| } | ||||
| // CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>) | ||||
| // CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) | ||||
| // CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue