[MHLO] Don't crash trying to constant fold mhlo.convert on complex
MLIR still doesn't have a complex attribute so this can't be implemented, so just bail out instead of trying to fold. PiperOrigin-RevId: 373128307
This commit is contained in:
parent
3cce7017fc
commit
86b7eb434c
|
@ -27,6 +27,11 @@ namespace hlo {
|
||||||
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
mlir::Type new_type) {
|
mlir::Type new_type) {
|
||||||
auto old_type = getElementTypeOrSelf(elements);
|
auto old_type = getElementTypeOrSelf(elements);
|
||||||
|
// TODO(kramerb): Add support when MLIR can represent const complex tensors.
|
||||||
|
if (old_type.isa<mlir::ComplexType>() || new_type.isa<mlir::ComplexType>()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
|
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
|
||||||
|
|
||||||
if (old_type.isa<mlir::FloatType>()) {
|
if (old_type.isa<mlir::FloatType>()) {
|
||||||
|
|
|
@ -245,3 +245,53 @@ func @const_high_rank_tensor() -> tensor<2x3xi32> {
|
||||||
return %0 : tensor<2x3xi32>
|
return %0 : tensor<2x3xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_int_complex
|
||||||
|
func @const_int_complex() -> tensor<2xcomplex<f32>> {
|
||||||
|
%cst = mhlo.constant dense<[0, 1]> : tensor<2xi1>
|
||||||
|
// CHECK: mhlo.convert
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<2xi1>) -> tensor<2xcomplex<f32>>
|
||||||
|
return %0 : tensor<2xcomplex<f32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_float_complex
|
||||||
|
func @const_float_complex() -> tensor<2xcomplex<f64>> {
|
||||||
|
%cst = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
|
||||||
|
// CHECK: mhlo.convert
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<2xf32>) -> tensor<2xcomplex<f64>>
|
||||||
|
return %0 : tensor<2xcomplex<f64>>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_complex_int
|
||||||
|
func @const_complex_int() -> tensor<i32> {
|
||||||
|
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
|
||||||
|
// CHECK: mhlo.convert
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<i32>
|
||||||
|
return %0 : tensor<i32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_complex_float
|
||||||
|
func @const_complex_float() -> tensor<f32> {
|
||||||
|
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
|
||||||
|
// CHECK: mhlo.convert
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @const_complex_complex
|
||||||
|
func @const_complex_complex() -> tensor<complex<f64>> {
|
||||||
|
%cst = mhlo.constant dense<(0.0, 1.0)> : tensor<complex<f32>>
|
||||||
|
// CHECK: mhlo.convert
|
||||||
|
%0 = "mhlo.convert"(%cst) : (tensor<complex<f32>>) -> tensor<complex<f64>>
|
||||||
|
return %0 : tensor<complex<f64>>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue