[MHLO] Don't crash trying to constant fold mhlo.convert on complex

MLIR still doesn't have a complex attribute so this can't be implemented, so
just bail out instead of trying to fold.

PiperOrigin-RevId: 373128307
This commit is contained in:
Benjamin Kramer 2021-05-11 05:14:54 -07:00 committed by TensorFlow MLIR Team
parent 3cce7017fc
commit 86b7eb434c
2 changed files with 55 additions and 0 deletions

View File

@ -27,6 +27,11 @@ namespace hlo {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) {
auto old_type = getElementTypeOrSelf(elements);
// TODO(kramerb): Add support when MLIR can represent const complex tensors.
if (old_type.isa<mlir::ComplexType>() || new_type.isa<mlir::ComplexType>()) {
return {};
}
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
if (old_type.isa<mlir::FloatType>()) {

View File

@ -245,3 +245,53 @@ func @const_high_rank_tensor() -> tensor<2x3xi32> {
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @const_int_complex
func @const_int_complex() -> tensor<2xcomplex<f32>> {
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
// CHECK: mhlo.convert
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xcomplex<f32>>
return %0 : tensor<2xcomplex<f32>>
}
// -----
// CHECK-LABEL: func @const_float_complex
func @const_float_complex() -> tensor<2xcomplex<f64>> {
%cst = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
// CHECK: mhlo.convert
%0 = "mhlo.convert"(%cst) : (tensor<2xf32>) -> tensor<2xcomplex<f64>>
return %0 : tensor<2xcomplex<f64>>
}
// -----
// CHECK-LABEL: func @const_complex_int
func @const_complex_int() -> tensor<i32> {
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
// CHECK: mhlo.convert
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<i32>
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_complex_float
func @const_complex_float() -> tensor<f32> {
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
// CHECK: mhlo.convert
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<f32>
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_complex_complex
func @const_complex_complex() -> tensor<complex<f64>> {
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
// CHECK: mhlo.convert
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<complex<f64>>
return %0 : tensor<complex<f64>>
}