Lower mhlo.not to a xor with all ones

PiperOrigin-RevId: 334361499
This commit is contained in:
Benjamin Kramer 2020-09-29 05:58:52 -07:00 committed by TensorFlow MLIR Team
parent 26ac5baae4
commit 6459f12235
6 changed files with 44 additions and 0 deletions

View File

@ -62,6 +62,7 @@ MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);

View File

@ -431,6 +431,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// lmhlo.not(x) -> x ^ -1
auto all_ones =
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -510,6 +510,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::MinOp>,
HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,

View File

@ -838,6 +838,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -945,6 +946,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,

View File

@ -344,6 +344,18 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// -----
// BOTH-LABEL: func @not
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
%tensor_operand = tensor_load %operand : memref<2x2xi32>
%tensor_result = "mhlo.not"(%tensor_operand)
: (tensor<2x2xi32>) -> tensor<2x2xi32>
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>

View File

@ -534,6 +534,19 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// -----
// CHECK-LABEL: func @not
func @not(%input: memref<2x2xi64>, %result: memref<2x2xi64>) {
"lmhlo.not"(%input, %result) : (memref<2x2xi64>, memref<2x2xi64>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i64, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[N1:.*]] = constant -1 : i64
// CHECK-NEXT: %[[RESULT:.*]] = xor %[[N1]], %[[OPERAND_IN]] : i64
// CHECK-NEXT: linalg.yield %[[RESULT]] : i64
// -----
// CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {