Lower mhlo.not to a xor with all ones
PiperOrigin-RevId: 334361499
This commit is contained in:
parent
26ac5baae4
commit
6459f12235
|
@ -62,6 +62,7 @@ MAP_HLO_TO_LHLO(MaxOp);
|
|||
MAP_HLO_TO_LHLO(MinOp);
|
||||
MAP_HLO_TO_LHLO(MulOp);
|
||||
MAP_HLO_TO_LHLO(NegOp);
|
||||
MAP_HLO_TO_LHLO(NotOp);
|
||||
MAP_HLO_TO_LHLO(RealOp);
|
||||
MAP_HLO_TO_LHLO(ReduceOp);
|
||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||
|
|
|
@ -431,6 +431,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||
// lmhlo.not(x) -> x ^ -1
|
||||
auto all_ones =
|
||||
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
|
||||
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
|
|
|
@ -510,6 +510,7 @@ void populateHLOToLHLOConversionPattern(
|
|||
HloToLhloOpConverter<mhlo::MinOp>,
|
||||
HloToLhloOpConverter<mhlo::MulOp>,
|
||||
HloToLhloOpConverter<mhlo::NegOp>,
|
||||
HloToLhloOpConverter<mhlo::NotOp>,
|
||||
HloToLhloOpConverter<mhlo::RealOp>,
|
||||
HloToLhloOpConverter<mhlo::RemOp>,
|
||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||
|
|
|
@ -838,6 +838,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||
|
@ -945,6 +946,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
|
|
|
@ -344,6 +344,18 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @not
|
||||
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xi32>
|
||||
%tensor_result = "mhlo.not"(%tensor_operand)
|
||||
: (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// BOTH-LABEL: func @rsqrt
|
||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
|
|
|
@ -534,6 +534,19 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @not
|
||||
func @not(%input: memref<2x2xi64>, %result: memref<2x2xi64>) {
|
||||
"lmhlo.not"(%input, %result) : (memref<2x2xi64>, memref<2x2xi64>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i64, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[N1:.*]] = constant -1 : i64
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = xor %[[N1]], %[[OPERAND_IN]] : i64
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i64
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rem
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
|
|
Loading…
Reference in New Issue