Add tests for lowering HLO_ExpOp for complex types to Linalg.

PiperOrigin-RevId: 379944871
This commit is contained in:
Adrian Kuegel 2021-06-17 06:33:13 -07:00 committed by TensorFlow MLIR Team
parent 470ac45f45
commit 2ab16024cf
2 changed files with 25 additions and 0 deletions

View File

@ -169,6 +169,17 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// ----- // -----
// CHECK-LABEL: func @complex_exp
func @complex_exp(%arg0: tensor<2x2xcomplex<f32>>) -> tensor<2x2xcomplex<f32>> {
// CHECK: linalg.generic
// CHECK: complex.exp
%0 = "mhlo.exponential"(%arg0) : (tensor<2x2xcomplex<f32>>)
-> tensor<2x2xcomplex<f32>>
return %0 : tensor<2x2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @float_expm1 // CHECK-LABEL: func @float_expm1
func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic

View File

@ -135,6 +135,20 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// CHECK-LABEL: func @complex_exp
func @complex_exp(%input: memref<2x2xcomplex<f32>>,
%result: memref<2x2xcomplex<f32>>) {
"lmhlo.exponential"(%input, %result)
: (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: complex<f32>, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = complex.exp %[[OPERAND_IN]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @log // CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()