diff --git a/BUILD b/BUILD index 9a22ef8..0b19c55 100644 --- a/BUILD +++ b/BUILD @@ -593,6 +593,7 @@ cc_library( ":lhlo", ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 7ce33fb..77600e0 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -23,6 +23,7 @@ limitations under the License. #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinTypes.h" @@ -41,7 +42,7 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; - using COp = ::mlir::AddCFOp; + using COp = ::mlir::complex::AddOp; }; template <> struct LhloToScalarOp { @@ -67,7 +68,7 @@ template <> struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; - using COp = ::mlir::SubCFOp; + using COp = ::mlir::complex::SubOp; }; // Alias for the map from LHLO binary op type to STD floating-point op type. @@ -261,8 +262,8 @@ template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, - b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, + args, b); } template <> @@ -270,7 +271,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); } template <> @@ -278,7 +280,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); } template <> diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 2cee0b4..7f630a5 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1298,8 +1298,8 @@ struct LhloLegalizeToLinalgPass void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -1312,14 +1312,16 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 4b2f354..ade42d8 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -33,7 +33,7 @@ func @integer_add(%lhs: tensor<2x2xi32>, func @complex_add(%lhs: tensor<2x2xcomplex>, %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { // CHECK: linalg.generic - // CHECK: addcf + // CHECK: complex.add %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) -> tensor<2x2xcomplex> return %0 : tensor<2x2xcomplex> @@ -128,7 +128,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>, func @complex_sub(%lhs: tensor<2x2xcomplex>, %rhs: tensor<2x2xcomplex>) -> tensor<2x2xcomplex> { // CHECK: linalg.generic - // CHECK: subcf + // CHECK: complex.sub %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex>, tensor<2x2xcomplex>) -> tensor<2x2xcomplex> return %0 : tensor<2x2xcomplex> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 3c75795..dfcc3c7 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -700,7 +700,7 @@ func @complex(%real: memref<2x2xf32>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex): -// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex +// CHECK-NEXT: %[[RESULT:.*]] = complex.create %[[RE]], %[[IM]] : complex // CHECK-NEXT: linalg.yield %[[RESULT]] : complex // ----- @@ -714,7 +714,7 @@ func @real(%cplx: memref<2x2xcomplex>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[REAL_OUT:.*]]: f32): -// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX_IN:.*]] : complex // CHECK-NEXT: linalg.yield %[[REAL]] : f32 // ----- @@ -728,7 +728,7 @@ func @imag(%cplx: memref<2x2xcomplex>, } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[IMAG_OUT:.*]]: f32): -// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX_IN:.*]] : complex // CHECK-NEXT: linalg.yield %[[IMAG]] : f32 // -----