From 86b7eb434cf546ce5ba8a1edcfcca47108bd4956 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 11 May 2021 05:14:54 -0700 Subject: [PATCH] [MHLO] Don't crash trying to constant fold mhlo.convert on complex MLIR still doesn't have a complex attribute so this can't be implemented, so just bail out instead of trying to fold. PiperOrigin-RevId: 373128307 --- lib/utils/convert_op_folder.cc | 5 ++++ tests/convert.mlir | 50 ++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/lib/utils/convert_op_folder.cc b/lib/utils/convert_op_folder.cc index f7177ec..623a015 100644 --- a/lib/utils/convert_op_folder.cc +++ b/lib/utils/convert_op_folder.cc @@ -27,6 +27,11 @@ namespace hlo { mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, mlir::Type new_type) { auto old_type = getElementTypeOrSelf(elements); + // TODO(kramerb): Add support when MLIR can represent const complex tensors. + if (old_type.isa() || new_type.isa()) { + return {}; + } + size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth(); if (old_type.isa()) { diff --git a/tests/convert.mlir b/tests/convert.mlir index 246cf41..ff8b94a 100644 --- a/tests/convert.mlir +++ b/tests/convert.mlir @@ -245,3 +245,53 @@ func @const_high_rank_tensor() -> tensor<2x3xi32> { return %0 : tensor<2x3xi32> } +// ----- + +// CHECK-LABEL: func @const_int_complex +func @const_int_complex() -> tensor<2xcomplex> { + %cst = mhlo.constant dense<[0, 1]> : tensor<2xi1> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + +// ----- + +// CHECK-LABEL: func @const_float_complex +func @const_float_complex() -> tensor<2xcomplex> { + %cst = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor<2xf32>) -> tensor<2xcomplex> + return %0 : tensor<2xcomplex> +} + + +// ----- + +// CHECK-LABEL: func @const_complex_int +func @const_complex_int() -> tensor { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_complex_float +func @const_complex_float() -> tensor { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_complex_complex +func @const_complex_complex() -> tensor> { + %cst = mhlo.constant dense<(0.0, 1.0)> : tensor> + // CHECK: mhlo.convert + %0 = "mhlo.convert"(%cst) : (tensor>) -> tensor> + return %0 : tensor> +}