Add plumbing for or and xor to hlo to lhlo and linalg lowerings.
PiperOrigin-RevId: 346311314
This commit is contained in:
parent
dd5895d083
commit
c3790af758
|
@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
|
||||||
MAP_HLO_TO_LHLO(MulOp);
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
MAP_HLO_TO_LHLO(NegOp);
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
MAP_HLO_TO_LHLO(NotOp);
|
MAP_HLO_TO_LHLO(NotOp);
|
||||||
|
MAP_HLO_TO_LHLO(OrOp);
|
||||||
MAP_HLO_TO_LHLO(RealOp);
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
MAP_HLO_TO_LHLO(ReduceOp);
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
|
||||||
MAP_HLO_TO_LHLO(SubOp);
|
MAP_HLO_TO_LHLO(SubOp);
|
||||||
MAP_HLO_TO_LHLO(TanhOp);
|
MAP_HLO_TO_LHLO(TanhOp);
|
||||||
MAP_HLO_TO_LHLO(TransposeOp);
|
MAP_HLO_TO_LHLO(TransposeOp);
|
||||||
|
MAP_HLO_TO_LHLO(XorOp);
|
||||||
|
|
||||||
#undef MAP_HLO_TO_LHLO
|
#undef MAP_HLO_TO_LHLO
|
||||||
|
|
||||||
|
|
|
@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace impl
|
} // namespace impl
|
||||||
|
|
||||||
struct HloOpToStdScalarOp {
|
struct HloOpToStdScalarOp {
|
||||||
|
|
|
@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
HloToLhloOpConverter<mhlo::MulOp>,
|
HloToLhloOpConverter<mhlo::MulOp>,
|
||||||
HloToLhloOpConverter<mhlo::NegOp>,
|
HloToLhloOpConverter<mhlo::NegOp>,
|
||||||
HloToLhloOpConverter<mhlo::NotOp>,
|
HloToLhloOpConverter<mhlo::NotOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::OrOp>,
|
||||||
HloToLhloOpConverter<mhlo::RealOp>,
|
HloToLhloOpConverter<mhlo::RealOp>,
|
||||||
HloToLhloOpConverter<mhlo::RemOp>,
|
HloToLhloOpConverter<mhlo::RemOp>,
|
||||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||||
|
@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
HloToLhloOpConverter<mhlo::SubOp>,
|
HloToLhloOpConverter<mhlo::SubOp>,
|
||||||
HloToLhloOpConverter<mhlo::TanhOp>,
|
HloToLhloOpConverter<mhlo::TanhOp>,
|
||||||
HloToLhloOpConverter<mhlo::TransposeOp>,
|
HloToLhloOpConverter<mhlo::TransposeOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::XorOp>,
|
||||||
HloToLhloReduceOpConverter,
|
HloToLhloReduceOpConverter,
|
||||||
HloToLhloReturnOpConverter,
|
HloToLhloReturnOpConverter,
|
||||||
HloToLhloTensorLoadOpConverter,
|
HloToLhloTensorLoadOpConverter,
|
||||||
|
|
|
@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
PointwiseToLinalgConverter<lmhlo::FloorOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::OrOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||||
|
@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
|
PointwiseToLinalgConverter<lmhlo::XorOp>,
|
||||||
ReduceConverter,
|
ReduceConverter,
|
||||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||||
ReverseConverter<lmhlo::ReverseOp>,
|
ReverseConverter<lmhlo::ReverseOp>,
|
||||||
|
@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
PointwiseToLinalgConverter<mhlo::FloorOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||||
|
@ -1059,7 +1063,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
|
PointwiseToLinalgConverter<mhlo::XorOp, false>,
|
||||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||||
ReverseConverter<mhlo::ReverseOp, false>,
|
ReverseConverter<mhlo::ReverseOp, false>,
|
||||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||||
|
|
|
@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @and
|
||||||
|
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||||
|
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @ceil
|
// CHECK-LABEL: func @ceil
|
||||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||||
|
@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @or
|
||||||
|
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||||
|
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @rsqrt
|
// CHECK-LABEL: func @rsqrt
|
||||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||||
|
@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @remainder
|
// CHECK-LABEL: func @remainder
|
||||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
|
%result: memref<2x2xf32>) {
|
||||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||||
|
@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @xor
|
||||||
|
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
|
||||||
|
%result: memref<2x2xi32>) {
|
||||||
|
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
|
||||||
|
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
|
||||||
|
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
|
||||||
|
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
|
||||||
|
tensor_store %tensor_result, %result : memref<2x2xi32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// Dynamic shape binary element-wise operation.
|
// Dynamic shape binary element-wise operation.
|
||||||
// CHECK-LABEL: func @add_dyn
|
// CHECK-LABEL: func @add_dyn
|
||||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||||
|
|
|
@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @integer_or
|
||||||
|
func @integer_or(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: or
|
||||||
|
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||||
|
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @integer_xor
|
||||||
|
func @integer_xor(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: xor
|
||||||
|
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||||
|
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_cmp
|
// CHECK-LABEL: func @float_cmp
|
||||||
func @float_cmp(%lhs: tensor<2x2xf32>,
|
func @float_cmp(%lhs: tensor<2x2xf32>,
|
||||||
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
|
||||||
|
|
Loading…
Reference in New Issue