Support complex types when converting HLO multiply op.
We can lower it to the MulOp in the complex dialect. PiperOrigin-RevId: 375675079
This commit is contained in:
parent
5816920258
commit
a847109ac7
|
@ -65,6 +65,7 @@ struct LhloToScalarOp<lmhlo::MulOp> {
|
||||||
using FOp = ::mlir::MulFOp;
|
using FOp = ::mlir::MulFOp;
|
||||||
using IOp = ::mlir::MulIOp;
|
using IOp = ::mlir::MulIOp;
|
||||||
using UOp = ::mlir::MulIOp;
|
using UOp = ::mlir::MulIOp;
|
||||||
|
using COp = ::mlir::complex::MulOp;
|
||||||
};
|
};
|
||||||
template <>
|
template <>
|
||||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||||
|
@ -631,6 +632,19 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
|
||||||
args, loc, b);
|
args, loc, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::MulOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Type> arg_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToScalarOpImpl<isSignedIntegerType, ScalarIOp<lmhlo::MulOp>,
|
||||||
|
isUnsignedIntegerType, ScalarUOp<lmhlo::MulOp>,
|
||||||
|
isFloatType, ScalarFOp<lmhlo::MulOp>,
|
||||||
|
isComplexType, ScalarCOp<lmhlo::MulOp>>{}(
|
||||||
|
loc, result_types, arg_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc,
|
||||||
ArrayRef<Type> result_types,
|
ArrayRef<Type> result_types,
|
||||||
|
|
|
@ -65,6 +65,19 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_mul
|
||||||
|
func @complex_mul(%lhs: tensor<2x2xcomplex<f32>>,
|
||||||
|
%rhs: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: complex.mul
|
||||||
|
%0 = "mhlo.multiply"(%lhs, %rhs)
|
||||||
|
: (tensor<2x2xcomplex<f32>>, tensor<2x2xcomplex<f32>>)
|
||||||
|
-> tensor<2x2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_remainder
|
// CHECK-LABEL: func @float_remainder
|
||||||
func @float_remainder(%lhs: tensor<2x2xf32>,
|
func @float_remainder(%lhs: tensor<2x2xf32>,
|
||||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
|
@ -242,6 +242,20 @@ func @complex_divide(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_multiply
|
||||||
|
func @complex_multiply(%lhs: memref<2xcomplex<f64>>, %rhs: memref<2xcomplex<f64>>,
|
||||||
|
%result: memref<2xcomplex<f64>>) {
|
||||||
|
"lmhlo.multiply"(%lhs, %rhs, %result)
|
||||||
|
: (memref<2xcomplex<f64>>, memref<2xcomplex<f64>>, memref<2xcomplex<f64>>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: complex<f64>, %[[RHS_IN:.*]]: complex<f64>, %[[RESULT_OUT:.*]]: complex<f64>):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.mul %[[LHS_IN]], %[[RHS_IN]] : complex<f64>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f64>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @select
|
// CHECK-LABEL: func @select
|
||||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
Loading…
Reference in New Issue