diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index a311ae7..791f456 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -37,6 +37,7 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; + using COp = ::mlir::AddCFOp; }; template <> struct LhloToScalarOp { @@ -62,20 +63,18 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; -}; - -template -struct ScalarOp { - using FOp = typename LhloToScalarOp::FOp; - using IOp = typename LhloToScalarOp::IOp; + using COp = ::mlir::SubCFOp; }; // Alias for the map from LHLO binary op type to STD floating-point op type. template -using ScalarFOp = typename ScalarOp::FOp; +using ScalarFOp = typename LhloToScalarOp::FOp; // Alias for the map from LHLO binary op type to STD integer op type. template -using ScalarIOp = typename ScalarOp::IOp; +using ScalarIOp = typename LhloToScalarOp::IOp; +// Alias for the map from LHLO binary op type to STD complex op type. +template +using ScalarCOp = typename LhloToScalarOp::COp; template struct MapLhloOpToStdScalarOpImpl { @@ -143,6 +142,16 @@ inline Value MapLhloOpToStdScalarOp(Location loc, } return nullptr; } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl, + FloatType, ScalarFOp, + ComplexType, ScalarCOp>{}( + loc, result_types, args, b); +} template <> inline Value MapLhloOpToStdScalarOp(Location loc, @@ -580,6 +589,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl, + FloatType, ScalarFOp, + ComplexType, ScalarCOp>{}( + loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 517dca8..1932de0 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -29,6 +29,18 @@ func @integer_add(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: complex_add +func @complex_add(%lhs: tensor<2x2xcomplex>, + %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { + // CHECK: linalg.generic + // CHECK: addcf + %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex>, + tensor<2x2xcomplex>) -> tensor<2x2xcomplex> + return %0 : tensor<2x2xcomplex> +} + +// ----- + // CHECK-LABEL: func @float_mul func @float_mul(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { @@ -112,6 +124,18 @@ func @integer_sub(%lhs: tensor<2x2xi32>, // ----- +// CHECK-LABEL: complex_sub +func @complex_sub(%lhs: tensor<2x2xcomplex>, + %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { + // CHECK: linalg.generic + // CHECK: subcf + %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex>, + tensor<2x2xcomplex>) -> tensor<2x2xcomplex> + return %0 : tensor<2x2xcomplex> +} + +// ----- + // CHECK-LABEL: func @float_abs func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic