{mhlo.is_finite, lmhlo.is_finite} -> linalg.generic conversion

PiperOrigin-RevId: 334414295
This commit is contained in:
Ahmed S. Taei 2020-09-29 10:49:53 -07:00 committed by TensorFlow MLIR Team
parent 336ee14538
commit 39389587d2
5 changed files with 47 additions and 0 deletions

View File

@ -57,6 +57,7 @@ MAP_HLO_TO_LHLO(FloorOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(IsFiniteOp);
MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);

View File

@ -345,6 +345,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
if (args[0].getType().isa<FloatType>()) {
auto pos_inf = APFloat::getInf(
args[0].getType().cast<FloatType>().getFloatSemantics());
auto const_pos_inf =
b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf));
Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]);
return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x,
const_pos_inf);
}
return nullptr;
}
/// Implements the conversion of HLO op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>

View File

@ -848,6 +848,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
@ -955,6 +956,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context);

View File

@ -252,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// -----
// CHECK-LABEL: func @is_finte
func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
%0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {

View File

@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// -----
// CHECK-LABEL: func @is_finte
func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
"lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32
// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {