{mhlo.is_finite, lmhlo.is_finite} -> linalg.generic conversion
PiperOrigin-RevId: 334414295
This commit is contained in:
parent
336ee14538
commit
39389587d2
|
@ -57,6 +57,7 @@ MAP_HLO_TO_LHLO(FloorOp);
|
||||||
MAP_HLO_TO_LHLO(GatherOp);
|
MAP_HLO_TO_LHLO(GatherOp);
|
||||||
MAP_HLO_TO_LHLO(ImagOp);
|
MAP_HLO_TO_LHLO(ImagOp);
|
||||||
MAP_HLO_TO_LHLO(IotaOp);
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
|
MAP_HLO_TO_LHLO(IsFiniteOp);
|
||||||
MAP_HLO_TO_LHLO(LogOp);
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
MAP_HLO_TO_LHLO(MaxOp);
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
MAP_HLO_TO_LHLO(MinOp);
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
|
|
|
@ -345,6 +345,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
if (args[0].getType().isa<FloatType>()) {
|
||||||
|
auto pos_inf = APFloat::getInf(
|
||||||
|
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||||
|
auto const_pos_inf =
|
||||||
|
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
|
||||||
|
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
|
||||||
|
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
|
||||||
|
const_pos_inf);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
||||||
/// linalg.generic op) for compare-select style operations like min/max.
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
|
|
|
@ -848,6 +848,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||||
ReverseConverter<lmhlo::ReverseOp>,
|
ReverseConverter<lmhlo::ReverseOp>,
|
||||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||||
|
@ -955,6 +956,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||||
|
|
|
@ -252,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @is_finte
|
||||||
|
func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||||
|
%0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||||
|
return %0 : tensor<2x2xi1>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32
|
||||||
|
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
|
||||||
|
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @select
|
// CHECK-LABEL: func @select
|
||||||
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
||||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
|
@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @is_finte
|
||||||
|
func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||||
|
"lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
|
||||||
|
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_cmp
|
// CHECK-LABEL: func @float_cmp
|
||||||
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
%result: memref<2x2xi1>) {
|
%result: memref<2x2xi1>) {
|
||||||
|
|
Loading…
Reference in New Issue