[MLIR] Migrate TF from STD complex ops to ComplexDialect.
PiperOrigin-RevId: 352966408
This commit is contained in:
parent
46112c95c6
commit
7aa64ee0b7
1
BUILD
1
BUILD
|
@ -593,6 +593,7 @@ cc_library(
|
||||||
":lhlo",
|
":lhlo",
|
||||||
":map_hlo_to_lhlo_op",
|
":map_hlo_to_lhlo_op",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
|
"@llvm-project//mlir:ComplexDialect",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:SCFDialect",
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
@ -41,7 +42,7 @@ template <>
|
||||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||||
using FOp = ::mlir::AddFOp;
|
using FOp = ::mlir::AddFOp;
|
||||||
using IOp = ::mlir::AddIOp;
|
using IOp = ::mlir::AddIOp;
|
||||||
using COp = ::mlir::AddCFOp;
|
using COp = ::mlir::complex::AddOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||||
|
@ -67,7 +68,7 @@ template <>
|
||||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||||
using FOp = ::mlir::SubFOp;
|
using FOp = ::mlir::SubFOp;
|
||||||
using IOp = ::mlir::SubIOp;
|
using IOp = ::mlir::SubIOp;
|
||||||
using COp = ::mlir::SubCFOp;
|
using COp = ::mlir::complex::SubOp;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||||
|
@ -261,8 +262,8 @@ template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
|
return MapLhloOpToStdScalarOpImpl<complex::CreateOp>{}(loc, result_types,
|
||||||
b);
|
args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -270,7 +271,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
|
return MapLhloOpToStdScalarOpImpl<complex::ReOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -278,7 +280,8 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
|
return MapLhloOpToStdScalarOpImpl<complex::ImOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -1298,8 +1298,8 @@ struct LhloLegalizeToLinalgPass
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
AffineDialect>();
|
StandardOpsDialect, AffineDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
@ -1312,14 +1312,16 @@ struct LhloLegalizeToLinalgPass
|
||||||
struct HloLegalizeToLinalgPass
|
struct HloLegalizeToLinalgPass
|
||||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
|
registry.insert<linalg::LinalgDialect, scf::SCFDialect,
|
||||||
|
complex::ComplexDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
target.addLegalDialect<complex::ComplexDialect, linalg::LinalgDialect,
|
||||||
tensor::TensorDialect, scf::SCFDialect>();
|
StandardOpsDialect, tensor::TensorDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
|
|
@ -33,7 +33,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
|
||||||
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
|
func @complex_add(%lhs: tensor<2x2xcomplex<f32>>,
|
||||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: addcf
|
// CHECK: complex.add
|
||||||
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
%0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||||
return %0 : tensor<2x2xcomplex<f32>>
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
@ -128,7 +128,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
|
||||||
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
|
func @complex_sub(%lhs: tensor<2x2xcomplex<f32>>,
|
||||||
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: subcf
|
// CHECK: complex.sub
|
||||||
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
%0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xcomplex<f32>>,
|
||||||
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>>
|
||||||
return %0 : tensor<2x2xcomplex<f32>>
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
|
|
@ -700,7 +700,7 @@ func @complex(%real: memref<2x2xf32>,
|
||||||
}
|
}
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>):
|
// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>):
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32>
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.create %[[RE]], %[[IM]] : complex<f32>
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -714,7 +714,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
|
||||||
}
|
}
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32):
|
||||||
// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex<f32>
|
// CHECK-NEXT: %[[REAL:.*]] = complex.re %[[CPLX_IN:.*]] : complex<f32>
|
||||||
// CHECK-NEXT: linalg.yield %[[REAL]] : f32
|
// CHECK-NEXT: linalg.yield %[[REAL]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
@ -728,7 +728,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
|
||||||
}
|
}
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32):
|
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32):
|
||||||
// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32>
|
// CHECK-NEXT: %[[IMAG:.*]] = complex.im %[[CPLX_IN:.*]] : complex<f32>
|
||||||
// CHECK-NEXT: linalg.yield %[[IMAG]] : f32
|
// CHECK-NEXT: linalg.yield %[[IMAG]] : f32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue