Add plumbing for or and xor to hlo to lhlo and linalg lowerings.

PiperOrigin-RevId: 346311314
This commit is contained in:
Stephan Herhut 2020-12-08 06:38:26 -08:00 committed by TensorFlow MLIR Team
parent dd5895d083
commit c3790af758
6 changed files with 96 additions and 3 deletions

View File

@ -65,6 +65,7 @@ MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp); MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp); MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp); MAP_HLO_TO_LHLO(ReshapeOp);
@ -81,6 +82,7 @@ MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp); MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp); MAP_HLO_TO_LHLO(TanhOp);
MAP_HLO_TO_LHLO(TransposeOp); MAP_HLO_TO_LHLO(TransposeOp);
MAP_HLO_TO_LHLO(XorOp);
#undef MAP_HLO_TO_LHLO #undef MAP_HLO_TO_LHLO

View File

@ -481,6 +481,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc,
return nullptr; return nullptr;
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}(
loc, result_types, args, b);
}
template <> template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc, inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types, ArrayRef<Type> result_types,
@ -580,6 +589,15 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
loc, result_types, args, b); loc, result_types, args, b);
} }
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}(
loc, result_types, args, b);
}
} // namespace impl } // namespace impl
struct HloOpToStdScalarOp { struct HloOpToStdScalarOp {

View File

@ -629,6 +629,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::MulOp>, HloToLhloOpConverter<mhlo::MulOp>,
HloToLhloOpConverter<mhlo::NegOp>, HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>, HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::OrOp>,
HloToLhloOpConverter<mhlo::RealOp>, HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>, HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>, HloToLhloOpConverter<mhlo::RsqrtOp>,
@ -644,6 +645,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::SubOp>, HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>, HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloOpConverter<mhlo::TransposeOp>, HloToLhloOpConverter<mhlo::TransposeOp>,
HloToLhloOpConverter<mhlo::XorOp>,
HloToLhloReduceOpConverter, HloToLhloReduceOpConverter,
HloToLhloReturnOpConverter, HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter, HloToLhloTensorLoadOpConverter,

View File

@ -927,12 +927,14 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::ExpOp>, PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::FloorOp>, PointwiseToLinalgConverter<lmhlo::FloorOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>, PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>, PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>, PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<lmhlo::MinOp>, PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>, PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>, PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>, PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::OrOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>, PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>, PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>, PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -945,7 +947,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::SqrtOp>, PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>, PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>, PointwiseToLinalgConverter<lmhlo::TanhOp>,
PointwiseToLinalgConverter<lmhlo::IsFiniteOp>, PointwiseToLinalgConverter<lmhlo::XorOp>,
ReduceConverter, ReduceConverter,
ReshapeOpConverter<lmhlo::ReshapeOp>, ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>, ReverseConverter<lmhlo::ReverseOp>,
@ -1042,12 +1044,14 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::ExpOp, false>, PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<mhlo::FloorOp, false>, PointwiseToLinalgConverter<mhlo::FloorOp, false>,
PointwiseToLinalgConverter<mhlo::ImagOp, false>, PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>,
PointwiseToLinalgConverter<mhlo::LogOp, false>, PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<mhlo::MaxOp, false>, PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<mhlo::MinOp, false>, PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<mhlo::MulOp, false>, PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<mhlo::NegOp, false>, PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>, PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>, PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>, PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>, PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
@ -1059,7 +1063,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::SqrtOp, false>, PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<mhlo::SubOp, false>, PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<mhlo::TanhOp, false>, PointwiseToLinalgConverter<mhlo::TanhOp, false>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp, false>, PointwiseToLinalgConverter<mhlo::XorOp, false>,
ReshapeOpConverter<mhlo::ReshapeOp, false>, ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<mhlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<mhlo::TransposeOp, false>>(context); TransposeConverter<mhlo::TransposeOp, false>>(context);

View File

@ -316,6 +316,20 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// CHECK-LABEL: func @and
func @and(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.and"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @ceil // CHECK-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -389,6 +403,20 @@ func @not(%operand: memref<2x2xi32>, %result: memref<2x2xi32>) {
// ----- // -----
// CHECK-LABEL: func @or
func @or(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.or"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// CHECK-LABEL: func @rsqrt // CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
@ -480,7 +508,8 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// CHECK-LABEL: func @remainder // CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
@ -492,6 +521,20 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// ----- // -----
// CHECK-LABEL: func @xor
func @xor(%operand0: memref<2x2xi32>, %operand1: memref<2x2xi32>,
%result: memref<2x2xi32>) {
%tensor_operand0 = tensor_load %operand0 : memref<2x2xi32>
%tensor_operand1 = tensor_load %operand1 : memref<2x2xi32>
%tensor_result = "mhlo.xor"(%tensor_operand0, %tensor_operand1)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
// CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xi32>
return
}
// -----
// Dynamic shape binary element-wise operation. // Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn // CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) { func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {

View File

@ -194,6 +194,30 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// ----- // -----
// CHECK-LABEL: func @integer_or
func @integer_or(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: or
%0 = "mhlo.or"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @integer_xor
func @integer_xor(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: xor
%0 = "mhlo.xor"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp // CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>, func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {