Lower Expm1 kernel to math.ExpM1.

PiperOrigin-RevId: 358152908
This commit is contained in:
Adrian Kuegel 2021-02-18 04:53:34 -08:00 committed by TensorFlow MLIR Team
parent 90b222cf6e
commit 37e31f8b26
6 changed files with 32 additions and 0 deletions

View File

@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(Expm1Op);
MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);

View File

@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::Expm1Op>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::math::ExpM1Op>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>,
HloToLhloOpConverter<mhlo::Expm1Op>,
HloToLhloOpConverter<mhlo::FloorOp>,
HloToLhloOpConverter<mhlo::GatherOp>,
HloToLhloOpConverter<mhlo::ImagOp>,

View File

@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::Expm1Op>,
PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::Expm1Op, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,

View File

@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @expm1
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
return %result : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @log
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>

View File

@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @float_expm1
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: expm1
%0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_log
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic