Lower mhlo.floor to lmhlo to linalg

PiperOrigin-RevId: 329304882
This commit is contained in:
Benjamin Kramer 2020-08-31 08:15:32 -07:00 committed by TensorFlow MLIR Team
parent acf7413d40
commit 049f034116
6 changed files with 37 additions and 0 deletions

View File

@ -52,6 +52,7 @@ MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp); MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp); MAP_HLO_TO_LHLO(IotaOp);

View File

@ -336,6 +336,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
loc, result_types, args, b); loc, result_types, args, b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
loc, result_types, args, b);
}
/// Implements the conversion of HLO op to scalar op (to use within region of a /// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max. /// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args> template <typename... Args>

View File

@ -487,6 +487,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::DivOp>, HloToLhloOpConverter<mhlo::DivOp>,
HloToLhloOpConverter<mhlo::DotOp>, HloToLhloOpConverter<mhlo::DotOp>,
HloToLhloOpConverter<mhlo::ExpOp>, HloToLhloOpConverter<mhlo::ExpOp>,
HloToLhloOpConverter<mhlo::FloorOp>,
HloToLhloOpConverter<mhlo::GatherOp>, HloToLhloOpConverter<mhlo::GatherOp>,
HloToLhloOpConverter<mhlo::ImagOp>, HloToLhloOpConverter<mhlo::ImagOp>,
HloToLhloOpConverter<mhlo::IotaOp>, HloToLhloOpConverter<mhlo::IotaOp>,

View File

@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::CosOp>, PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<lmhlo::DivOp>, PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<lmhlo::ExpOp>, PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>, PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>, PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>, PointwiseToLinalgConverter<lmhlo::MaxOp>,
@ -929,6 +930,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::CosOp, false>, PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<mhlo::DivOp, false>, PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<mhlo::ExpOp, false>, PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>, PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>, PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>, PointwiseToLinalgConverter<mhlo::MaxOp, false>,

View File

@ -320,6 +320,18 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// BOTH-LABEL: func @floor
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.floor"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @neg // BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>

View File

@ -496,6 +496,18 @@ func @sin(%input: memref<2x2xf32>,
// ----- // -----
// CHECK-LABEL: func @floor
func @floor(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.floor"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = floorf %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @negf // CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()