Lower mhlo.not to a xor with all ones
PiperOrigin-RevId: 334361499
This commit is contained in:
parent
26ac5baae4
commit
6459f12235
|
@ -62,6 +62,7 @@ MAP_HLO_TO_LHLO(MaxOp);
|
||||||
MAP_HLO_TO_LHLO(MinOp);
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
MAP_HLO_TO_LHLO(MulOp);
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
MAP_HLO_TO_LHLO(NegOp);
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
|
MAP_HLO_TO_LHLO(NotOp);
|
||||||
MAP_HLO_TO_LHLO(RealOp);
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
MAP_HLO_TO_LHLO(ReduceOp);
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
|
|
@ -431,6 +431,21 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// lmhlo.not(x) -> x ^ -1
|
||||||
|
auto all_ones =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth());
|
||||||
|
return b->create<::mlir::XOrOp>(loc, all_ones, args[0]);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -510,6 +510,7 @@ void populateHLOToLHLOConversionPattern(
|
||||||
HloToLhloOpConverter<mhlo::MinOp>,
|
HloToLhloOpConverter<mhlo::MinOp>,
|
||||||
HloToLhloOpConverter<mhlo::MulOp>,
|
HloToLhloOpConverter<mhlo::MulOp>,
|
||||||
HloToLhloOpConverter<mhlo::NegOp>,
|
HloToLhloOpConverter<mhlo::NegOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::NotOp>,
|
||||||
HloToLhloOpConverter<mhlo::RealOp>,
|
HloToLhloOpConverter<mhlo::RealOp>,
|
||||||
HloToLhloOpConverter<mhlo::RemOp>,
|
HloToLhloOpConverter<mhlo::RemOp>,
|
||||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||||
|
|
|
@ -838,6 +838,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||||
|
@ -945,6 +946,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||||
|
|
|
@ -344,6 +344,18 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// BOTH-LABEL: func @not
|
||||||
|
func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||||
|
%tensor_operand = tensor_load %operand : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.not"(%tensor_operand)
|
||||||
|
: (tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// BOTH: "lmhlo.not"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// BOTH-LABEL: func @rsqrt
|
// BOTH-LABEL: func @rsqrt
|
||||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||||
|
|
|
@ -534,6 +534,19 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @not
|
||||||
|
func @not(%input: memref<2x2xi64>, %result: memref<2x2xi64>) {
|
||||||
|
"lmhlo.not"(%input, %result) : (memref<2x2xi64>, memref<2x2xi64>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i64, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[N1:.*]] = constant -1 : i64
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = xor %[[N1]], %[[OPERAND_IN]] : i64
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i64
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @rem
|
// CHECK-LABEL: func @rem
|
||||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
%result: memref<2x2xf32>) {
|
%result: memref<2x2xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue