Lower mhlo.floor to lmhlo to linalg
PiperOrigin-RevId: 329304882
This commit is contained in:
parent
acf7413d40
commit
049f034116
|
@ -52,6 +52,7 @@ MAP_HLO_TO_LHLO(CosOp);
|
|||
MAP_HLO_TO_LHLO(DivOp);
|
||||
MAP_HLO_TO_LHLO(DotOp);
|
||||
MAP_HLO_TO_LHLO(ExpOp);
|
||||
MAP_HLO_TO_LHLO(FloorOp);
|
||||
MAP_HLO_TO_LHLO(GatherOp);
|
||||
MAP_HLO_TO_LHLO(ImagOp);
|
||||
MAP_HLO_TO_LHLO(IotaOp);
|
||||
|
|
|
@ -336,6 +336,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
||||
/// linalg.generic op) for compare-select style operations like min/max.
|
||||
template <typename... Args>
|
||||
|
|
|
@ -487,6 +487,7 @@ void populateHLOToLHLOConversionPattern(
|
|||
HloToLhloOpConverter<mhlo::DivOp>,
|
||||
HloToLhloOpConverter<mhlo::DotOp>,
|
||||
HloToLhloOpConverter<mhlo::ExpOp>,
|
||||
HloToLhloOpConverter<mhlo::FloorOp>,
|
||||
HloToLhloOpConverter<mhlo::GatherOp>,
|
||||
HloToLhloOpConverter<mhlo::ImagOp>,
|
||||
HloToLhloOpConverter<mhlo::IotaOp>,
|
||||
|
|
|
@ -822,6 +822,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<lmhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||
|
@ -929,6 +930,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
|
|
|
@ -320,6 +320,18 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @floor
|
||||
func @floor(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.floor"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "lmhlo.floor"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @neg
|
||||
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
|
|
|
@ -496,6 +496,18 @@ func @sin(%input: memref<2x2xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @floor
|
||||
func @floor(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"lmhlo.floor"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = floorf %[[OPERAND_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @negf
|
||||
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue