Lower Expm1 kernel to math.ExpM1.
PiperOrigin-RevId: 358152908
This commit is contained in:
		
							parent
							
								
									90b222cf6e
								
							
						
					
					
						commit
						37e31f8b26
					
				|  | @ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp); | |||
| MAP_HLO_TO_LHLO(DivOp); | ||||
| MAP_HLO_TO_LHLO(DotOp); | ||||
| MAP_HLO_TO_LHLO(ExpOp); | ||||
| MAP_HLO_TO_LHLO(Expm1Op); | ||||
| MAP_HLO_TO_LHLO(FloorOp); | ||||
| MAP_HLO_TO_LHLO(GatherOp); | ||||
| MAP_HLO_TO_LHLO(ImagOp); | ||||
|  |  | |||
|  | @ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc, | |||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc, | ||||
|                                                     ArrayRef<Type> result_types, | ||||
|                                                     ArrayRef<Value> args, | ||||
|                                                     OpBuilder* b) { | ||||
|   return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}( | ||||
|       loc, result_types, args, b); | ||||
| } | ||||
| 
 | ||||
| template <> | ||||
| inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc, | ||||
|                                                    ArrayRef<Type> result_types, | ||||
|  |  | |||
|  | @ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, | |||
|       HloToLhloOpConverter<mhlo::DivOp>, | ||||
|       HloToLhloOpConverter<mhlo::DotOp>, | ||||
|       HloToLhloOpConverter<mhlo::ExpOp>, | ||||
|       HloToLhloOpConverter<mhlo::Expm1Op>, | ||||
|       HloToLhloOpConverter<mhlo::FloorOp>, | ||||
|       HloToLhloOpConverter<mhlo::GatherOp>, | ||||
|       HloToLhloOpConverter<mhlo::ImagOp>, | ||||
|  |  | |||
|  | @ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, | |||
|                    PointwiseToLinalgConverter<lmhlo::CosOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::DivOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::ExpOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::Expm1Op>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::FloorOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::ImagOp>, | ||||
|                    PointwiseToLinalgConverter<lmhlo::IsFiniteOp>, | ||||
|  | @ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, | |||
|       PointwiseToLinalgConverter<mhlo::CosOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::DivOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ExpOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::Expm1Op, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::FloorOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::ImagOp, false>, | ||||
|       PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, | ||||
|  |  | |||
|  | @ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @expm1 | ||||
| func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   %result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> | ||||
|   // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}}) | ||||
|   return %result : tensor<2x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @log | ||||
| func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> | ||||
|  |  | |||
|  | @ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | |||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @float_expm1 | ||||
| func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   // CHECK: linalg.generic | ||||
|   // CHECK: expm1 | ||||
|   %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> | ||||
|   return %0 : tensor<2x2xf32> | ||||
| } | ||||
| 
 | ||||
| // ----- | ||||
| 
 | ||||
| // CHECK-LABEL: func @float_log | ||||
| func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { | ||||
|   // CHECK: linalg.generic | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue