[KERNEL_GEN] Convert LHLO AddOp, SubOp (ComplexType) to complex ops.
PiperOrigin-RevId: 347805898
This commit is contained in:
parent
61244b136c
commit
65222893ae
|
@ -37,6 +37,7 @@ template <>
|
||||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||||
using FOp = ::mlir::AddFOp;
|
using FOp = ::mlir::AddFOp;
|
||||||
using IOp = ::mlir::AddIOp;
|
using IOp = ::mlir::AddIOp;
|
||||||
|
using COp = ::mlir::AddCFOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||||
|
@ -62,20 +63,18 @@ template <>
|
||||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||||
using FOp = ::mlir::SubFOp;
|
using FOp = ::mlir::SubFOp;
|
||||||
using IOp = ::mlir::SubIOp;
|
using IOp = ::mlir::SubIOp;
|
||||||
};
|
using COp = ::mlir::SubCFOp;
|
||||||
|
|
||||||
template <typename LhloBinaryOpTy>
|
|
||||||
struct ScalarOp {
|
|
||||||
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
|
|
||||||
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
|
using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp;
|
||||||
// Alias for the map from LHLO binary op type to STD integer op type.
|
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||||
template <typename LhloOp>
|
template <typename LhloOp>
|
||||||
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
|
using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp;
|
||||||
|
// Alias for the map from LHLO binary op type to STD complex op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp;
|
||||||
|
|
||||||
template <typename... Args>
|
template <typename... Args>
|
||||||
struct MapLhloOpToStdScalarOpImpl {
|
struct MapLhloOpToStdScalarOpImpl {
|
||||||
|
@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>,
|
||||||
|
FloatType, ScalarFOp<lmhlo::AddOp>,
|
||||||
|
ComplexType, ScalarCOp<lmhlo::AddOp>>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||||
|
@ -580,6 +589,17 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>,
|
||||||
|
FloatType, ScalarFOp<lmhlo::SubOp>,
|
||||||
|
ComplexType, ScalarCOp<lmhlo::SubOp>>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: complex_add
|
||||||
|
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
|
||||||
|
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: addcf
|
||||||
|
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||||
|
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_mul
|
// CHECK-LABEL: func @float_mul
|
||||||
func @float_mul(%lhs: tensor<2x2xf32>,
|
func @float_mul(%lhs: tensor<2x2xf32>,
|
||||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: complex_sub
|
||||||
|
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
|
||||||
|
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: subcf
|
||||||
|
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||||
|
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_abs
|
// CHECK-LABEL: func @float_abs
|
||||||
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
Loading…
Reference in New Issue