[KERNEL_GEN] Convert LHLO AddOp, SubOp (ComplexType) to complex ops.
PiperOrigin-RevId: 347805898
This commit is contained in:
parent
61244b136c
commit
65222893ae
|
@ -37,6 +37,7 @@ template <>
|
|||
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
using COp = ::mlir::AddCFOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||
|
@ -62,20 +63,18 @@ template <>
|
|||
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
|
||||
template <typename LhloBinaryOpTy>
|
||||
struct ScalarOp {
|
||||
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
|
||||
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
|
||||
using COp = ::mlir::SubCFOp;
|
||||
};
|
||||
|
||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
|
||||
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
|
||||
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
|
||||
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
|
||||
// Alias for the map from LHLO binary op type to STD complex op type.
|
||||
template <typename LhloOp>
|
||||
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||
|
||||
template <typename... Args>
|
||||
struct MapLhloOpToStdScalarOpImpl {
|
||||
|
@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||
FloatType, ScalarFOp<lmhlo::AddOp>,
|
||||
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||
|
@ -580,6 +589,17 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||
FloatType, ScalarFOp<lmhlo::SubOp>,
|
||||
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
|
|
|
@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: complex_add
|
||||
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
|
||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addcf
|
||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_mul
|
||||
func @float_mul(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
|
@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: complex_sub
|
||||
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
|
||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: subcf
|
||||
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||
return %0 : tensor<2x2xcomplex<f32>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @float_abs
|
||||
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
|
|
Loading…
Reference in New Issue