Lower mhlo.floor to lmhlo to linalg
PiperOrigin-RevId: 329304882
This commit is contained in:
		
							parent
							
								
									acf7413d40
								
							
						
					
					
						commit
						049f034116
					
				|  | @ -52,6 +52,7 @@ MAP_HLO_TO_LHLO(CosOp); | |||
| MAP_HLO_TO_LHLO(DivOp); | ||||
| MAP_HLO_TO_LHLO(DotOp); | ||||
| MAP_HLO_TO_LHLO(ExpOp); | ||||
| MAP_HLO_TO_LHLO(FloorOp); | ||||
| MAP_HLO_TO_LHLO(GatherOp); | ||||
| MAP_HLO_TO_LHLO(ImagOp); | ||||
| MAP_HLO_TO_LHLO(IotaOp); | ||||
|  |  | |||
|  | @ -336,6 +336,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc, | |||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc, | ||||
|                                                     ArrayRef<Type> result_types, | ||||
|                                                     ArrayRef<Value> args, | ||||
|                                                     OpBuilder* b) { | ||||
|   return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}( | ||||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| /// Implements the conversion of HLO op to scalar op (to use within region of a
 | ||||
| /// linalg.generic op) for compare-select style operations like min/max.
 | ||||
| template <typename... Args> | ||||
|  |  | |||
|  | @ -487,6 +487,7 @@ void populateHLOToLHLOConversionPattern( | |||
|       HloToLhloOpConverter<mhlo::DivOp>, | ||||
|       HloToLhloOpConverter<mhlo::DotOp>, | ||||
|       HloToLhloOpConverter<mhlo::ExpOp>, | ||||
|       HloToLhloOpConverter<mhlo::FloorOp>, | ||||
|       HloToLhloOpConverter<mhlo::GatherOp>, | ||||
|       HloToLhloOpConverter<mhlo::ImagOp>, | ||||
|       HloToLhloOpConverter<mhlo::IotaOp>, | ||||
|  |  | |||
|  | @ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, | |||
|                    PointwiseToLinalgConverter<lmhlo::CosOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::DivOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::ExpOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::FloorOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::ImagOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::LogOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::MaxOp>, | ||||
|  | @ -929,6 +930,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, | |||
|                PointwiseToLinalgConverter<mhlo::CosOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::DivOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ExpOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::FloorOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::ImagOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::LogOp, false>, | ||||
|                PointwiseToLinalgConverter<mhlo::MaxOp, false>, | ||||
|  |  | |||
|  | @ -320,6 +320,18 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // BOTH-LABEL: func @floor | ||||
| func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||
|   %tensor_operand = tensor_load %operand : memref<2x2xf32> | ||||
|   %tensor_result = "mhlo.floor"(%tensor_operand) | ||||
|       : (tensor<2x2xf32>) -> tensor<2x2xf32> | ||||
|   // BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}}) | ||||
|   tensor_store %tensor_result, %result : memref<2x2xf32> | ||||
|   return | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // BOTH-LABEL: func @neg | ||||
| func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||
|   %tensor_operand = tensor_load %operand : memref<2x2xf32> | ||||
|  |  | |||
|  | @ -496,6 +496,18 @@ func @sin(%input: memref<2x2xf32>, | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @floor | ||||
| func @floor(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||
|   "lmhlo.floor"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () | ||||
|   return | ||||
| } | ||||
| // CHECK: linalg.generic | ||||
| // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): | ||||
| // CHECK-NEXT:   %[[RESULT:.*]] = floorf %[[OPERAND_IN]] : f32 | ||||
| // CHECK-NEXT:   linalg.yield %[[RESULT]] : f32 | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @negf | ||||
| func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { | ||||
|   "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue