{mhlo.is_finite, lmhlo.is_finite} -> linalg.generic conversion
PiperOrigin-RevId: 334414295
This commit is contained in:
parent
336ee14538
commit
39389587d2
|
@ -57,6 +57,7 @@ MAP_HLO_TO_LHLO(FloorOp);
|
|||
MAP_HLO_TO_LHLO(GatherOp);
|
||||
MAP_HLO_TO_LHLO(ImagOp);
|
||||
MAP_HLO_TO_LHLO(IotaOp);
|
||||
MAP_HLO_TO_LHLO(IsFiniteOp);
|
||||
MAP_HLO_TO_LHLO(LogOp);
|
||||
MAP_HLO_TO_LHLO(MaxOp);
|
||||
MAP_HLO_TO_LHLO(MinOp);
|
||||
|
|
|
@ -345,6 +345,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
if (args[0].getType().isa<FloatType>()) {
|
||||
auto pos_inf = APFloat::getInf(
|
||||
args[0].getType().cast<FloatType>().getFloatSemantics());
|
||||
auto const_pos_inf =
|
||||
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
|
||||
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
|
||||
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
|
||||
const_pos_inf);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Implements the conversion of HLO op to scalar op (to use within region of a
|
||||
/// linalg.generic op) for compare-select style operations like min/max.
|
||||
template <typename... Args>
|
||||
|
|
|
@ -848,6 +848,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||
ReverseConverter<lmhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||
|
@ -955,6 +956,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
|
|
|
@ -252,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @is_finte
|
||||
func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||
%0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %0 : tensor<2x2xi1>
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32
|
||||
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
|
||||
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
|
|
@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @is_finte
|
||||
func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
"lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
|
||||
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
|
||||
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_cmp
|
||||
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xi1>) {
|
||||
|
|
Loading…
Reference in New Issue