Support complex types when converting HLO divide op.

We can lower it to DivOp in the complex dialect. Also add tests to
hlo-legalize-to-linalg.mlir for the CompareOp lowering of complex types;
these were forgotten in a previous commit.

PiperOrigin-RevId: 375669125
commit 5816920258 (parent 8e28008e38)
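For illustration, here is roughly what the new lowering produces — a hedged
sketch assembled from the tests below (the linalg.generic attributes and SSA
names are illustrative, not verbatim compiler output):

    %0 = "mhlo.divide"(%lhs, %rhs)
        : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>

    // after hlo-legalize-to-linalg, becomes approximately:

    %init = linalg.init_tensor [2] : tensor<2xcomplex<f32>>
    %0 = linalg.generic { /* identity indexing maps, "parallel" iterator */ }
        ins(%lhs, %rhs : tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>)
        outs(%init : tensor<2xcomplex<f32>>) {
    ^bb0(%l: complex<f32>, %r: complex<f32>, %out: complex<f32>):
      %d = complex.div %l, %r : complex<f32>
      linalg.yield %d : complex<f32>
    } -> tensor<2xcomplex<f32>>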
@@ -58,6 +58,7 @@ struct LhloToScalarOp<lmhlo::DivOp> {
   using FOp = ::mlir::DivFOp;
   using IOp = ::mlir::SignedDivIOp;
   using UOp = ::mlir::UnsignedDivIOp;
+  using COp = ::mlir::complex::DivOp;
 };
 template <>
 struct LhloToScalarOp<lmhlo::MulOp> {
@@ -192,6 +193,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
   }
   return nullptr;
 }
+
 template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc,
                                                   ArrayRef<Type> result_types,
@@ -301,6 +303,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
   return args.front();
 }
 
+template <>
+inline Value MapLhloOpToStdScalarOp<lmhlo::DivOp>(Location loc,
+                                                  ArrayRef<Type> result_types,
+                                                  ArrayRef<Type> arg_types,
+                                                  ArrayRef<Value> args,
+                                                  OpBuilder* b) {
+  return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::DivOp>,
+                                 isUnsignedIntegerType, ScalarUOp<lmhlo::DivOp>,
+                                 isFloatType, ScalarFOp<lmhlo::DivOp>,
+                                 isComplexType, ScalarCOp<lmhlo::DivOp>>{}(
+      loc, result_types, arg_types, args, b);
+}
+
 template <>
 inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
                                                   ArrayRef<Type> result_types,
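MapLhloOpToScalarOpImpl walks its (type predicate, scalar op) pairs in order
and emits the first op whose predicate matches the element type, so divide now
covers all four element-type classes. A hedged sketch of the scalar op chosen
per element type, in the standard/complex dialect syntax of this revision
(value names are illustrative):

    %si = divi_signed %a, %b : i32            // isSignedIntegerType
    %ui = divi_unsigned %a, %b : i32          // isUnsignedIntegerType; ui32 is
                                              // carried as signless i32 here
    %f  = divf %x, %y : f32                   // isFloatType
    %c  = complex.div %p, %q : complex<f32>   // isComplexType (new in this change)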
@@ -323,6 +323,36 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
 
 // -----
 
+// CHECK-LABEL: func @complex_cmp_eq
+func @complex_cmp_eq(%lhs: tensor<2xcomplex<f32>>,
+                     %rhs: tensor<2xcomplex<f32>>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
+      : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xi1>)
+  return %0 : tensor<2xi1>
+}
+// CHECK: linalg.init_tensor [2] : tensor<2xi1>
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f32>, %[[RHS_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]: i1):
+// CHECK-NEXT: %[[RESULT:.*]] = complex.eq %[[LHS_IN]], %[[RHS_IN]] : complex<f32>
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
+// -----
+
+// CHECK-LABEL: func @complex_cmp_neq
+func @complex_cmp_neq(%lhs: tensor<2xcomplex<f64>>,
+                      %rhs: tensor<2xcomplex<f64>>) -> tensor<2xi1> {
+  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
+      : (tensor<2xcomplex<f64>>, tensor<2xcomplex<f64>>) -> (tensor<2xi1>)
+  return %0 : tensor<2xi1>
+}
+// CHECK: linalg.init_tensor [2] : tensor<2xi1>
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: i1):
+// CHECK-NEXT: %[[RESULT:.*]] = complex.neq %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
+// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
+
 // -----
 
 // CHECK-LABEL: func @float_cos
 func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // CHECK: linalg.generic
@@ -2352,6 +2382,17 @@ func @unsigned_divide(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<
 
 // -----
 
+// CHECK-LABEL: complex_divide
+func @complex_divide(%lhs: tensor<2xcomplex<f32>>,
+                     %rhs: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHECK: linalg.generic
+  // CHECK: complex.div
+  %0 = "mhlo.divide"(%lhs, %rhs) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  return %0 : tensor<2xcomplex<f32>>
+}
+
+// -----
+
 // CHECK-LABEL: unsigned_remainder
 func @unsigned_remainder(%lhs: tensor<2x2xui32>, %rhs: tensor<2x2xui32>) -> tensor<2x2xui32> {
   // CHECK: linalg.generic
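These tests are FileCheck-driven; each `// -----` marker splits the file into
an independent input when the pass pipeline runs in split-input-file mode. The
RUN line of hlo-legalize-to-linalg.mlir is along these lines (reconstructed
from the test's conventions rather than copied from the file):

    // RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s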
@@ -158,8 +158,8 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
 
 // -----
 
-// CHECK-LABEL: func @is_finte
-func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
+// CHECK-LABEL: func @is_finite
+func @is_finite(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
   "lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> ()
   return
 }
@@ -228,6 +228,20 @@ func @complex_cmp_neq(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>
 
 // -----
 
+// CHECK-LABEL: func @complex_divide
+func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
+                     %result: memref<2xcomplex<f64>>) {
+  "lmhlo.divide"(%lhs, %rhs, %result)
+      : (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
+  return
+}
+// CHECK: linalg.generic
+// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
+// CHECK-NEXT: %[[RESULT:.*]] = complex.div %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
+// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
+
+// -----
+
 // CHECK-LABEL: func @select
 func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
              %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
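Note the destination-passing style of the lmhlo form: mhlo.divide returns a
tensor value, while lmhlo.divide writes into a caller-provided result buffer.
Schematically (types abbreviated, not valid verbatim IR):

    %0 = "mhlo.divide"(%lhs, %rhs) : (tensor<...>, tensor<...>) -> tensor<...>           // value-based
    "lmhlo.divide"(%lhs, %rhs, %result) : (memref<...>, memref<...>, memref<...>) -> ()  // buffer-based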