Add tests for lowering HLO_ExpOp for complex types to Linalg.
PiperOrigin-RevId: 379944871
This commit is contained in:
parent
470ac45f45
commit
2ab16024cf
|
@ -169,6 +169,17 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_exp
|
||||||
|
func @complex_exp(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: complex.exp
|
||||||
|
%0 = "mhlo.exponential"(%arg0) : (tensor<2x2xcomplex<f32>>)
|
||||||
|
-> tensor<2x2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2x2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @float_expm1
|
// CHECK-LABEL: func @float_expm1
|
||||||
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
|
|
|
@ -135,6 +135,20 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @complex_exp
|
||||||
|
func @complex_exp(%input: memref<2x2xcomplex<f32>>,
|
||||||
|
%result: memref<2x2xcomplex<f32>>) {
|
||||||
|
"lmhlo.exponential"(%input, %result)
|
||||||
|
: (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = complex.exp %[[OPERAND_IN]] : complex<f32>
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @log
|
// CHECK-LABEL: func @log
|
||||||
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue