[KERNEL_GEN] Convert LHLO AddOp, SubOp (ComplexType) to complex ops.

PiperOrigin-RevId: 347805898
This commit is contained in:
Alexander Belyaev 2020-12-16 05:44:38 -08:00 committed by TensorFlow MLIR Team
parent 61244b136c
commit 65222893ae
2 changed files with 52 additions and 8 deletions

View File

@ -37,6 +37,7 @@ template <>
struct LhloToScalarOp<lmhlo::AddOp> {
  using FOp = ::mlir::AddFOp;
  using IOp = ::mlir::AddIOp;
  using COp = ::mlir::AddCFOp;
};
template <>
struct LhloToScalarOp<lmhlo::CompareOp> {
@ -62,20 +63,18 @@ template <>
struct LhloToScalarOp<lmhlo::SubOp> {
  using FOp = ::mlir::SubFOp;
  using IOp = ::mlir::SubIOp;
  using COp = ::mlir::SubCFOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
// Alias for the map from LHLO binary op type to STD complex op type.
template <typename LhloOp>
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
  }
  return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  // Element-type-based dispatch for lmhlo.add: integer operands lower to
  // AddIOp, floating-point operands to AddFOp, and complex operands to
  // AddCFOp.
  using AddMapper =
      MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
                                 FloatType, ScalarFOp<lmhlo::AddOp>,
                                 ComplexType, ScalarCOp<lmhlo::AddOp>>;
  AddMapper mapper;
  return mapper(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
@ -580,6 +589,17 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
      loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
                                                  ArrayRef<Type> result_types,
                                                  ArrayRef<Value> args,
                                                  OpBuilder* b) {
  // lmhlo.subtract lowers to the standard-dialect op matching the element
  // type: SubIOp for integers, SubFOp for floats, SubCFOp for complex.
  MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
                             FloatType, ScalarFOp<lmhlo::SubOp>,
                             ComplexType, ScalarCOp<lmhlo::SubOp>>
      mapper;
  return mapper(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
                                                   ArrayRef<Type> result_types,

View File

@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>,
// -----
// Test: lowering of mhlo.add on complex<f32> tensors should produce a
// linalg.generic whose scalar body uses addcf.
// CHECK-LABEL: complex_add
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: addcf
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_mul
func @float_mul(%lhs: tensor<2x2xf32>,
                %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
// -----
// Test: lowering of mhlo.subtract on complex<f32> tensors should produce a
// linalg.generic whose scalar body uses subcf.
// CHECK-LABEL: complex_sub
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: subcf
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_abs
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic