Lower Expm1 kernel to math.ExpM1.
PiperOrigin-RevId: 358152908
This commit is contained in:
parent
90b222cf6e
commit
37e31f8b26
|
@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
|
|||
MAP_HLO_TO_LHLO(DivOp);
|
||||
MAP_HLO_TO_LHLO(DotOp);
|
||||
MAP_HLO_TO_LHLO(ExpOp);
|
||||
MAP_HLO_TO_LHLO(Expm1Op);
|
||||
MAP_HLO_TO_LHLO(FloorOp);
|
||||
MAP_HLO_TO_LHLO(GatherOp);
|
||||
MAP_HLO_TO_LHLO(ImagOp);
|
||||
|
|
|
@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
|
|
|
@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
|||
HloToLhloOpConverter<mhlo::DivOp>,
|
||||
HloToLhloOpConverter<mhlo::DotOp>,
|
||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||
HloToLhloOpConverter<mhlo::Expm1Op>,
|
||||
HloToLhloOpConverter<mhlo::FloorOp>,
|
||||
HloToLhloOpConverter<mhlo::GatherOp>,
|
||||
HloToLhloOpConverter<mhlo::ImagOp>,
|
||||
|
|
|
@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<lmhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::Expm1Op>,
|
||||
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||
|
@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
|
||||
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||
|
|
|
@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @expm1
|
||||
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
|
||||
return %result : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @log
|
||||
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
%result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
|
|
|
@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_expm1
|
||||
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: expm1
|
||||
%0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_log
|
||||
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
|
|
Loading…
Reference in New Issue