diff --git a/lib/utils/convert_op_folder.cc b/lib/utils/convert_op_folder.cc index f7177ec..623a015 100644 --- a/lib/utils/convert_op_folder.cc +++ b/lib/utils/convert_op_folder.cc @@ -27,6 +27,11 @@ namespace hlo { mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::Type new_type) { auto old_type = getElementTypeOrSelf(elements); + // TODO(kramerb): Add support when MLIR can represent const complex tensors. + if (old_type.isa() || new_type.isa()) { + return {}; + } + size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth(); if (old_type.isa()) { diff --git a/tests/convert.mlir b/tests/convert.mlir index 246cf41..ff8b94a 100644 --- a/tests/convert.mlir +++ b/tests/convert.mlir @@ -245,3 +245,53 @@ func @const_high_rank_tensor() -> tensor<2x3xi32> { return %0 : tensor<2x3xi32> } +// ----- + +// CHECK-LABEL: func @const_int_complex +func @const_int_complex() -> tensor<2xcomplex> { + %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @const_float_complex +func @const_float_complex() -> tensor<2xcomplex> { + %cst = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor<2xf32>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + + +// ----- + +// CHECK-LABEL: func @const_complex_int +func @const_complex_int() -> tensor { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_complex_float +func @const_complex_float() -> tensor { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_complex_complex +func @const_complex_complex() -> tensor> { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor> + return %0 : tensor> +}