[KERNEL_GEN] Convert LHLO AddOp, SubOp (ComplexType) to complex ops.

PiperOrigin-RevId: 347805898
This commit is contained in:
Alexander Belyaev 2020-12-16 05:44:38 -08:00 committed by TensorFlow MLIR Team
parent 61244b136c
commit 65222893ae
2 changed files with 52 additions and 8 deletions

View File

@ -37,6 +37,7 @@ template <>
struct LhloToScalarOp<lmhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
using COp = ::mlir::AddCFOp;
};
template <>
struct LhloToScalarOp<lmhlo::CompareOp> {
@ -62,20 +63,18 @@ template <>
struct LhloToScalarOp<lmhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <typename LhloBinaryOpTy>
struct ScalarOp {
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
using COp = ::mlir::SubCFOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
// Alias for the map from LHLO binary op type to STD complex op type.
template <typename LhloOp>
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
FloatType, ScalarFOp<lmhlo::AddOp>,
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
@ -580,6 +589,17 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
FloatType, ScalarFOp<lmhlo::SubOp>,
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,

View File

@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: complex_add
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: addcf
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_mul
func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
// -----
// CHECK-LABEL: complex_sub
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: subcf
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_abs
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic