diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir
new file mode 100644
index 0000000..90dd8f1
--- /dev/null
+++ b/tests/canonicalize.mlir
@@ -0,0 +1,457 @@
+// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
+
+// CHECK-LABEL: add_fold
+func @add_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
+  %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: add_scalar_fold
+func @add_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<1> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<5> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<6>
+  %2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: add_fold_float
+func @add_fold_float() -> tensor<4xf64> {
+  %0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
+  %1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
+  // CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
+  %2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
+  return %2 : tensor<4xf64>
+}
+
+// CHECK-LABEL: sub_scalar_fold
+func @sub_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<5> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<1> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<4>
+  %2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: multiply_scalar_fold
+func @multiply_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<5> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<3> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<15>
+  %2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: divide_scalar_fold
+func @divide_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<7> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<5> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<1>
+  %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: divide_fold_float
+func @divide_fold_float() -> tensor<4xf64> {
+  %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
+  %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
+  // CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
+  %2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
+  return %2 : tensor<4xf64>
+}
+
+// CHECK-LABEL: max_scalar_fold
+func @max_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<7> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<5> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<7>
+  %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: max_fold_float
+func @max_fold_float() -> tensor<4xf64> {
+  %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
+  %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
+  // CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
+  %2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
+  return %2 : tensor<4xf64>
+}
+
+// CHECK-LABEL: min_scalar_fold
+func @min_scalar_fold() -> tensor<4xi64> {
+  %0 = xla_hlo.constant dense<7> : tensor<4xi64>
+  %1 = xla_hlo.constant dense<-5> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<-5>
+  %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
+  return %2 : tensor<4xi64>
+}
+
+// CHECK-LABEL: min_fold_float
+func @min_fold_float() -> tensor<4xf64> {
+  %0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
+  %1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
+  // CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
+  %2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
+  return %2 : tensor<4xf64>
+}
+
+// CHECK-LABEL: concatenate_noop
+func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
+  %0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[ARG]]
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_remove_operand
+func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
+  // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
+  // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[ARG0]]
+  return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_bool
+func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
+
+  return %0 : tensor<0xi1>
+}
+
+// CHECK-LABEL: concatenate_empty_int
+func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
+
+  return %0 : tensor<0xi32>
+}
+
+// CHECK-LABEL: concatenate_empty_float
+func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
+  // CHECK: xla_hlo.constant
+  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
+
+  return %0 : tensor<0xf32>
+}
+
+// CHECK-LABEL: concatenate_const_1D
+func @concatenate_const_1D() -> tensor<4xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[0, 1, 2, 3]>
+  %0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
+  %1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<4xi32>
+}
+
+// CHECK-LABEL: concatenate_const_1D_float
+func @concatenate_const_1D_float() -> tensor<4xf32> {
+  // CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
+
+  %0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
+  %1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<4xf32>
+}
+
+// CHECK-LABEL: concatenate_const_2D_vertical
+func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
+  // CHECK-SAME: [0, 1], [2, 3]
+  // CHECK-SAME: ]>
+  %0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
+  %1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: concatenate_const_2D_horizontal
+func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
+  // CHECK: [[VAL:%.+]]= xla_hlo.constant dense<[
+  // CHECK-SAME: [0, 2], [1, 3]
+  // CHECK-SAME: ]>
+  %0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
+  %1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
+  %2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
+
+  // CHECK: return [[VAL]]
+  return %2 : tensor<2x2xi32>
+}
+
+// CHECK-LABEL: dynamic_slice_variable_start
+func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> {
+  // CHECK: "xla_hlo.dynamic-slice"
+  %1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32>
+  return %1 : tensor<1x4xi32>
+}
+
+// CHECK-LABEL: dynamic_slice_constant_start
+func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
+  // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
+  // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
+  // CHECK: return %[[RESULT]] : tensor<2xi32>
+  %0 = xla_hlo.constant dense<1> : tensor
+  %1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32>
+  return %1 : tensor<2xi32>
+}
+
+// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
+func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor {
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
+  // CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
+  // CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
+  // CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
+  // CHECK: return %[[RESULT]] : tensor
+  %0 = xla_hlo.constant dense<1> : tensor
+  %1 = xla_hlo.constant dense<0> : tensor
+  %2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor
+  return %2 : tensor
+}
+
+// CHECK-LABEL: slice_2D_noop
+// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
+func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
+  %0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
+
+  // CHECK-NEXT: return [[ARG]]
+  return %0 : tensor<2x2xi64>
+}
+
+// CHECK-LABEL: slice_1D_fold
+func @slice_1D_fold() -> tensor<2xi64> {
+  %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
+  // CHECK: xla_hlo.constant dense<[7, 9]>
+  %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
+  return %1 : tensor<2xi64>
+}
+
+// CHECK-LABEL: slice_1D_fp
+func @slice_1D_fp() -> tensor<2xf32> {
+  %0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
+  // CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
+  %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>,
start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>) + return %1 : tensor<2xf32> +} + +// CHECK-LABEL: slice_1D_strided_fold +func @slice_1D_strided_fold() -> tensor<2xi64> { + %0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: xla_hlo.constant dense<[7, 10]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>) + return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_2D_fold +func @slice_2D_fold() -> tensor<2x2xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [6, 7], + // CHECK-SAME: [10, 11] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_2D_fold_horizontal +func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + return %1 : tensor<1x4xi64> +} + +// CHECK-LABEL: slice_2D_fold_vertical +func @slice_2D_fold_vertical() -> tensor<4x1xi64> { + %0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: xla_hlo.constant dense<[ + // CHECK-SAME: [2], [6], [10], [14] + // CHECK-SAME: ]> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + return %1 : tensor<4x1xi64> +} + +// CHECK-LABEL: slice_concat_fold_first +func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg0 + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second +func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg1 + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second_with_slice +func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[SLICE:%.+]] = 
"xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32> + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<1x4xf32> +} + +// CHECK-LABEL: slice_concat_fold_middle +func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_two +func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { + // CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64} + %0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + + // CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + %1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>) + + // CHECK: return [[SLICE]] + return %1 : tensor<2x5xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_identity +func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // CHECK: return %arg0 + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + return %0 : tensor<2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts +func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + // CHECK: xla_hlo.broadcast_in_dim + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation +func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: xla_hlo.broadcast_in_dim + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + + +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic +func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { + // CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = 
"xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: return %[[RESULT]] : tensor<5x4xf32> + return %0 : tensor<5x4xf32> +} + +// CHECK-LABEL: @complex_expand_fold +func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) + %1 = "xla_hlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + // CHECK: return %arg0, %arg1 + return %1, %2 : tensor<4xf32>, tensor<4xf32> +} + +// CHECK-LABEL: @complex_collapse_fold +func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK: return %arg0 + return %2 : tensor<4xcomplex> +} + +// CHECK-LABEL: @dynamic_iota_is_static +func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: return [[RESULT]] + %0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: @iota_not_lowered_to_constant +func @iota_not_lowered_to_constant() -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "xla_hlo.iota" + // CHECK: return [[RESULT]] + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: @unary_einsum +func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor + // CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"} + %0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @fold_copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { + // CHECK: return [[ARG]] + %0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic +func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { + // CHECK: xla_hlo.reshape + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32> + return %0 : tensor<4x1xf32> +} + +// CHECK-LABEL: do_not_dce_while_with_outfeed +func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { + // CHECK: xla_hlo.while + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + // Side-effecting op outfeed present inside while. 
+ %2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !xla_hlo.token) -> !xla_hlo.token + "xla_hlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} + +// CHECK-LABEL: dce_while_without_side_effect +func @dce_while_without_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: xla_hlo.while + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "xla_hlo.create_token"() : () -> !xla_hlo.token + "xla_hlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + + return %arg0 : tensor +} diff --git a/tests/chlo_infer_shape_type_methods.mlir b/tests/chlo_infer_shape_type_methods.mlir new file mode 100644 index 0000000..f71f58f --- /dev/null +++ b/tests/chlo_infer_shape_type_methods.mlir @@ -0,0 +1,56 @@ +// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s + +// CHECK-LABEL: @broadcast_add +// Note that all broadcast_ops are expanded from the same template, so +// only test reification on an examplar op. +// CHECK-SAME: %[[ARG0:.+]]: tensor, +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xindex> { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] + // CHECK: return %[[EXTENTS]] + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> + return %1 : tensor<1xindex> +} + +// ----- +// CHECK-LABEL: @complex_ranked_components +func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor> { + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} + %1 = "xla_test.get_return_type_components"(%0) : (tensor>) -> tensor> + return %1 : tensor> +} + +// ----- +// CHECK-LABEL: @compare_ranked_components +func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1 +func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + +// ----- +// CHECK-LABEL: @broadcast_add_ranked_components_r1x2 +func @broadcast_add_ranked_components_r1x2(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + // TODO: Overly broad shapes are being returned. Tighten the calculation + // and update/extend these tests. 
+ // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} + %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor + return %1 : tensor +} + diff --git a/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tests/chlo_legalize_to_hlo_broadcasts.mlir new file mode 100644 index 0000000..b290dcb --- /dev/null +++ b/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -0,0 +1,239 @@ +// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s + +// Check the non-broadcast case for each registered op, then just check a +// representative op for detailed broadcast semantics. +// CHECK-LABEL: @addWithoutBroadcast +func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.add %arg0, %arg1 + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @dynamicBroadcast +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] + // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] + // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} + // CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] + // CHECK-NEXT: shape.assuming_yield %[[RESULT]] + // CHECK-NEXT: } + // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor + %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @dynamicBroadcastComplex +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] + // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] + // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor, tensor) -> tensor> + // CHECK-NEXT: shape.assuming_yield %[[RESULT]] + // CHECK-NEXT: } + // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + return %0 : tensor> +} + +// ----- +// CHECK-LABEL: @dynamicBroadcastCompare +// 
CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor +func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] + // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] + // CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]] + // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] + // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) + // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] + // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + // CHECK: shape.assuming_yield %[[RESULT]] + // CHECK-NEXT: } + // CHECK: return %[[FINAL_RESULT]] : tensor + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- +// Verifies that broadcast_dimensions validity checks are valid. +// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions +func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // CHECK: xla_hlo.add + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that broadcast_dimensions validity checks are valid. +// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions +func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { + // CHECK: xla_hlo.add + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that invalid broadcast dimensions are rejected. +func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} + // expected-error @+1 {{failed to legalize operation}} + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Verifies that invalid broadcast dimensions are rejected. +func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { + // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} + // expected-error @+1 {{failed to legalize operation}} + %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- +// Note that broadcast_add is used as a proxy for all of the template +// expansions. Tests below merely verify that the op has an expansion. 
+// CHECK-LABEL: @andWithoutBroadcast +func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.and %arg0, %arg1 + %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @atan2WithoutBroadcast +func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.atan2 %arg0, %arg1 + %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @compareWithoutBroadcast +func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { + // CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @complexWithoutBroadcast +func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { + // CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + +// ----- +// CHECK-LABEL: @divideWithoutBroadcast +func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.divide %arg0, %arg1 + %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @maximumWithoutBroadcast +func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.maximum %arg0, %arg1 + %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @minimumWithoutBroadcast +func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.minimum %arg0, %arg1 + %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @multiplyWithoutBroadcast +func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.multiply %arg0, %arg1 + %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @orWithoutBroadcast +func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.or %arg0, %arg1 + %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- +// CHECK-LABEL: @powerWithoutBroadcast +func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.power %arg0, %arg1 + %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @remainderWithoutBroadcast +func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.remainder %arg0, %arg1 + %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> 
tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_leftWithoutBroadcast +func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_left %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast +func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @shift_right_logicalWithoutBroadcast +func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.shift_right_logical %arg0, %arg1 + %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @subWithoutBroadcast +func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: xla_hlo.subtract %arg0, %arg1 + %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- +// CHECK-LABEL: @xorWithoutBroadcast +func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + // CHECK: xla_hlo.xor %arg0, %arg1 + %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} diff --git a/tests/concatenate.mlir b/tests/concatenate.mlir new file mode 100644 index 0000000..179616e --- /dev/null +++ b/tests/concatenate.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @single_operand +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<1x2xf32> +} \ No newline at end of file diff --git a/tests/convert.mlir b/tests/convert.mlir new file mode 100644 index 0000000..783fe8a --- /dev/null +++ b/tests/convert.mlir @@ -0,0 +1,225 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// ----- + +// CHECK-LABEL: func @same_type +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @same_type(%arg: tensor) -> tensor { + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[ARG]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_widening +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_widening(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_narrowing +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_narrowing(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @float_int +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @float_int(%arg: tensor) -> tensor { + // 
CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @int_float +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @int_float(%arg: tensor) -> tensor { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor) -> tensor + %0 = "xla_hlo.convert"(%arg) : (tensor) -> tensor + // CHECK-NEXT: return [[RES]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @high_rank_tensor +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32> + %0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32> + // CHECK-NEXT: return [[RES]] + return %0 : tensor<2x3xf32> +} + +// ----- + + +// CHECK-LABEL: func @const_same_type +func @const_same_type() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_float_int +func @const_float_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_float +func @const_int_float() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor + %cst = xla_hlo.constant dense<4> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_negative_int_float +func @const_negative_int_float() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor + %cst = xla_hlo.constant dense<-4> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_bf16 +func @const_int_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor + %cst = xla_hlo.constant dense<4> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_narrowing +func @const_int_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_int_widening +func @const_int_widening() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_negative_int_widening +func @const_negative_int_widening() -> tensor { + // 
CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor + %cst = xla_hlo.constant dense<-42> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_float_narrowing +func @const_float_narrowing() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor + %cst = xla_hlo.constant dense<4.2> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_f32_bf16 +func @const_f32_bf16() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor + %cst = xla_hlo.constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_f64 +func @const_bf16_f64() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor + %cst = xla_hlo.constant dense<4.2> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_bf16_int +func @const_bf16_int() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42.0> : tensor + %0 = "xla_hlo.convert"(%cst) : (tensor) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + + +// ----- + +// CHECK-LABEL: func @const_high_rank_tensor +func @const_high_rank_tensor() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> + %0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} + diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir new file mode 100644 index 0000000..b13dd27 --- /dev/null +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -0,0 +1,488 @@ +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s + +// BOTH-LABEL: func @attrs +func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.exponential"(%tensor_operand) + {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { + return %arg0 : tensor<4xf32> +} +// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) +// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// PRE-NEXT: return +// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] +// ESC-NOT: "xla_lhlo.copy" +// ESC-NEXT: return %[[ARG0]] + +// ----- + +// BOTH-LABEL: func @func_op_long +func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> + %2 = xla_hlo.add 
%arg0, %1 : tensor<4xf32> + %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> + %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> + %5 = xla_hlo.multiply %2, %4 : tensor<4xf32> + return %5 : tensor<4xf32> +} +// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) +// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> +// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> +// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> +// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> +// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> +// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> +//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> +// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> +// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> +// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> +// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> +// PRE-NEXT: return +// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> + +// ----- + +// BOTH-LABEL: func @fusion +func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, + %summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) { + // BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}}) + // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> + %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> + %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> + %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> + %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> + %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> + // BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) + tensor_store %tensor_result, %result : memref<2x2xf32> + // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> + // BOTH-NEXT: return + return +} + +// ----- + +// BOTH-LABEL: func @copy +func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.copy"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @exp +func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.exponential"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return 
+} + +// ----- + +// BOTH-LABEL: func @log +func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.log"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @select +func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, + %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_pred = tensor_load %pred : memref<2x2xi1> + %tensor_lhs = tensor_load %lhs : memref<2x2xf32> + %tensor_rhs = tensor_load %rhs : memref<2x2xf32> + %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) + : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @compare +func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { + %tensor_lhs = tensor_load %lhs : memref<2x2xf32> + %tensor_rhs = tensor_load %rhs : memref<2x2xf32> + %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) + {comparison_direction = "EQ"} + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> + // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + tensor_store %tensor_result, %result : memref<2x2xi1> + return +} + +// ----- + +// BOTH-LABEL: func @broadcast +func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { + %tensor_operand = tensor_load %operand : memref<5xf32> + %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) + {broadcast_dimensions = dense<1> : tensor<1xi64>} + : (tensor<5xf32>) -> tensor<10x5xf32> + // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + tensor_store %tensor_result, %result : memref<10x5xf32> + return +} + +// ----- + +func @external_func() -> tensor<3xi64> + +// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> + +// BOTH-LABEL: func @dyn_broadcast +func @dyn_broadcast(%operand: memref) { + // BOTH-SAME: (%[[OPERAND:.*]]: memref) + %tensor_operand = tensor_load %operand : memref + %shape = call @external_func() : () -> tensor<3xi64> + %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { + broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + } : (tensor, tensor<3xi64>) -> tensor + // BOTH: %[[SHAPE:.*]] = call @external_func() + // BOTH: %[[C0:.*]] = constant 0 : index + // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> + // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index + // BOTH: %[[C1:.*]] = constant 1 : index + // BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64> + // BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index + // BOTH: %[[C2:.*]] = constant 2 : index + // BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64> + // BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index + // BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]]) + + // BOTH: %[[C0_:.*]] = constant 0 : index + // BOTH: %[[C1_:.*]] = constant 1 : index + + // BOTH: %[[C1__:.*]] = constant 1 : index + // BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64> + // BOTH: %[[C0___:.*]] = constant 0 : index + // BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref + // BOTH: 
%[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index + // BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]] + // BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index + + // BOTH: %[[C2_:.*]] = constant 2 : index + // BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64> + // BOTH: %[[C1___:.*]] = constant 1 : index + // BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref + // BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index + // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] + // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index + + // BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast + // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) + // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] + // BOTH-SAME: : memref -> memref + + // BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { + // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> + // BOTH-SAME: } : (memref, memref) -> () + + // Do not store the value back to avoid the tensor-store being rewritten to + // a copy into the pre-allocated argument. + return +} + +// ----- + +// BOTH-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %result: memref<2x2xcomplex>) { + %tensor_real = tensor_load %real : memref<2x2xf32> + %tensor_imag = tensor_load %imag : memref<2x2xf32> + %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> + // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xcomplex> + return +} + +// ----- + +// BOTH-LABEL: func @real +func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xcomplex> + %tensor_result = "xla_hlo.real"(%tensor_operand) + : (tensor<2x2xcomplex>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @imag +func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xcomplex> + %tensor_result = "xla_hlo.imag"(%tensor_operand) + : (tensor<2x2xcomplex>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @iota +func @iota(%result: memref<10xi32>) { + %tensor_result = "xla_hlo.iota"() + {iota_dimension = 0 : i64} : () -> tensor<10xi32> + // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + tensor_store %tensor_result, %result : memref<10xi32> + return +} + +// ----- + +// BOTH-LABEL: func @abs +func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.abs"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @ceil +func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.ceil"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + 
return +} + +// ----- + +// BOTH-LABEL: func @convert +func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.convert"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH-NOT: tensor_store + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @cos +func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.cosine"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @neg +func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.negate"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @rsqrt +func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.rsqrt"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @sign +func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.sign"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @sqrt +func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.sqrt"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @tanh +func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_operand = tensor_load %operand : memref<2x2xf32> + %tensor_result = "xla_hlo.tanh"(%tensor_operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// BOTH-LABEL: func @remainder +func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { + %tensor_lhs = tensor_load %lhs : memref<2x2xf32> + %tensor_rhs = tensor_load %rhs : memref<2x2xf32> + %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + tensor_store %tensor_result, %result : memref<2x2xf32> + return +} + +// ----- + +// Dynamic shape binary element-wise operation. 
+// BOTH-LABEL: func @add_dyn
+func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
+  %result = "xla_hlo.add"(%lhs, %rhs)
+      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  // BOTH: %[[C0:.*]] = constant 0 : index
+  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
+  // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // BOTH: %[[C1:.*]] = constant 1 : index
+  // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
+  // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
+  // BOTH: %[[C0_:.*]] = constant 0 : index
+  // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // BOTH: %[[C1_:.*]] = constant 1 : index
+  // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// Dynamic shape unary element-wise operation.
+// BOTH-LABEL: func @tanh_dyn
+func @tanh_dyn(%arg0: tensor<?x?xf32>) {
+  %result = "xla_hlo.tanh"(%arg0)
+      : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // BOTH: %[[C0:.*]] = constant 0 : index
+  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
+  // BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
+  // BOTH: %[[C1:.*]] = constant 1 : index
+  // BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
+  // BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
+  // BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
+  // BOTH: %[[C0_:.*]] = constant 0 : index
+  // BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
+  // BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
+  // BOTH: %[[C1_:.*]] = constant 1 : index
+  // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
+  // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
+  // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
+  // BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
+  return
+}
+
+// -----
+
+// BOTH-LABEL: func @dot
+func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
+// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
+// BOTH-NEXT: %[[ALLOC:.*]] = alloc
+// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
+  %dot = "xla_hlo.dot"(%arg0, %arg0)
+      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
+// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
+// ESC: return %[[ALLOC]]
+  return %dot : tensor<1024x1024xf32>
+}
+
+// -----
+
+// BOTH-LABEL: func @conv
+func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
+  %c0 = constant 0 : index
+  // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
+  // BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
+  // BOTH-SAME: padding = dense<[
+  // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
+  // BOTH-SAME: rhs_dilation = dense<[1, 2]>
+  // BOTH-SAME: window_strides = dense<[2, 1]>
+  %out = "xla_hlo.convolution"(%filter, %input) {
+    batch_group_count = 1 : i64,
+    dimension_numbers = {
+      input_batch_dimension = 0 : i64,
+      input_feature_dimension = 3 : i64,
+      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
+      kernel_input_feature_dimension = 2 : i64,
+      kernel_output_feature_dimension = 3 : i64,
+
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, + output_batch_dimension = 0 : i64, + output_feature_dimension = 3 : i64, + output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> + }, + feature_group_count = 1 : i64, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 2]> : tensor<2xi64>, + window_strides = dense<[2, 1]> : tensor<2xi64> + } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> + return %out : tensor<3x5x5x4xf32> +} diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir new file mode 100644 index 0000000..b633f17 --- /dev/null +++ b/tests/hlo-legalize-to-linalg.mlir @@ -0,0 +1,559 @@ +// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s + +// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @float_add +func @float_add(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: ^{{[a-z0-9_]*}} + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 + // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]] + // CHECK: linalg.yield %[[RESULT]] + %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, + tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: integer_add +func @integer_add(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: addi + %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @float_mul +func @float_mul(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: mulf + %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, + tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @integer_mul +func @integer_mul(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: muli + %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @float_remainder +func @float_remainder(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: remf + %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, + tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @integer_remainder +func @integer_remainder(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: remi_signed + %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @float_rsqrt +func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { + %tensor_result = "xla_hlo.rsqrt"(%operand) + : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: linalg.generic + // CHECK: rsqrt + return %tensor_result : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_sub +func @float_sub(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: subf + %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, + tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @integer_sub +func 
@integer_sub(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: subi + %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @float_abs +func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: absf + %0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_exp +func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: exp + %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_log +func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: log + %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_ceil +func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: ceilf + %0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_neg +func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: negf + %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_tanh +func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: tanh + %0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @integer_and +func @integer_and(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: and + %0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @float_cmp +func @float_cmp(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { + %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @int_cmp +func @int_cmp(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { + %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} + : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @float_cos +func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: cos + %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @float_sin +func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: sin + %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> 
tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { + %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) + return %0 : tensor<2x4x8xf32> +} +// CHECK: return [[ARG]] : tensor<2x4x8xf32> + +// ----- + +// CHECK-LABEL: func @select +func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = "xla_hlo.select"(%pred, %lhs, %rhs) + : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) + return %0 : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @broadcast_scalar +func @broadcast_scalar(%arg: tensor) -> tensor<4x2x1xf32> { + %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor) -> tensor<4x2x1xf32> + return %0: tensor<4x2x1xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-LABEL: func @broadcast +func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { + %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> + return %0: tensor<4x2x1x4x?x16xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK-LABEL: func @broadcast_in_dim +func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { + %0 = "xla_hlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} + : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> + return %0 : tensor<7x10x6x4x5xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one +func @broadcast_in_dim_with_one_to_one( + %operand: tensor<1xf32>) -> tensor<1x5xf32> { + %0 = "xla_hlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[0]> : tensor<1xi64>} + : (tensor<1xf32>) -> tensor<1x5xf32> + return %0 : tensor<1x5xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()> 
+// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @broadcast_scalar +func @broadcast_scalar(%operand: tensor) -> tensor<7x10x6xf32> { + %0 = "xla_hlo.broadcast_in_dim"(%operand) + {broadcast_dimensions = dense<[]> : tensor<0xi64>} + : (tensor) -> tensor<7x10x6xf32> + return %0 : tensor<7x10x6xf32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @transpose +func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} + : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + return %0 : tensor<3x2x5x9xi32> +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + +// ----- + +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-LABEL: func @reshape_3D_2D +func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> + return %0 : tensor<12x42xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] + +// ----- + +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-LABEL: func @reshape_4D_2D +func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> + return %0 : tensor<12x42xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] + +// ----- + +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +// CHECK-LABEL: func @reshape_2D_4D +func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> + return %0 : tensor<12x1x42x1xi32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] + +// ----- + +// CHECK-LABEL: func @minf +func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = "xla_hlo.minimum"(%lhs, %rhs) + : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32): +// CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @maxi +func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = "xla_hlo.maximum"(%lhs, %rhs) + : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield 
%[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> +// CHECK-LABEL: func @add_scalar +func @add_scalar(%lhs: tensor, %rhs: tensor) -> tensor { + %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] +// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): +// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]] +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +func @reshape_collapse_single_dim + (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> + return %0 : tensor<1x784xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-LABEL: func @reshape_collapse_single_dim +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] + +// ----- + +func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> + return %0 : tensor<2x4x3xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-LABEL: func @reshape_collapse +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] + +// ----- + +func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> + return %0 : tensor<2x4x2xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> +// CHECK-LABEL: func @reshape_expand +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]] + +// ----- + +func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> + return %0 : tensor<1x4x2xf32> +} +// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @reshape_single_expand +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]] + +// ----- + +func @reshape_multiple_collapse + (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> { + %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> + return %0 : tensor<1x4x5x6xf32> +} +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-LABEL: func @reshape_multiple_collapse +// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] + +// ----- + +// CHECK-LABEL: func @convert_i32_to_f32 +func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i16_to_i32 +func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) 
: (tensor<2x2xi16>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_i16 +func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> + return %result : tensor<2x2xi16> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_f64 +func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> + return %result : tensor<2x2xf64> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f64 + +// ----- + +// CHECK-LABEL: func @convert_f64_to_f32 +func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> + return %result : tensor<2x2xf32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64): +// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { + %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> + return %result : tensor<2x2xi32> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { + %result = "xla_hlo.reverse"(%input) { + dimensions = dense<1> : tensor<1xi64> + } : (tensor<2x3xf32>) -> tensor<2x3xf32> + return %result : tensor<2x3xf32> +} +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] diff --git a/tests/inlining.mlir b/tests/inlining.mlir new file mode 100644 index 0000000..7b1bbf5 --- /dev/null +++ b/tests/inlining.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-hlo-opt %s -inline | FileCheck %s + +// Test case: Basic test of inlining into xla_hlo.while. 
+
+// CHECK-LABEL: func @caller
+// CHECK: "xla_hlo.while"{{.*}}( {
+// CHECK: }, {
+// CHECK: "xla_hlo.exponential"
+// CHECK: })
+// CHECK-LABEL: func @callee
+
+func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
+  %0 = "xla_hlo.while"(%arg0) ( {
+  ^entry(%unused: tensor<f32>):
+    "xla_hlo.return"(%pred) : (tensor<i1>) -> ()
+  }, {
+  ^entry(%0: tensor<f32>):
+    %1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
+    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+  } ) : (tensor<f32>) -> (tensor<f32>)
+  return %0 : tensor<f32>
+}
+
+
+func @callee(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
diff --git a/tests/legalize-control-flow.mlir b/tests/legalize-control-flow.mlir
new file mode 100644
index 0000000..4096b06
--- /dev/null
+++ b/tests/legalize-control-flow.mlir
@@ -0,0 +1,146 @@
+// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
+
+// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
+func @while(%arg0: tensor<i64>) -> tensor<i64> {
+  //CHECK: br ^bb1(%arg0 : tensor<i64>)
+  //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
+  //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
+  //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
+  //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
+  //CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
+  //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
+  //CHECK: br ^bb1([[VAL4]] : tensor<i64>)
+  //CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
+  %0 = "xla_hlo.while"(%arg0) ( {
+  ^bb0(%arg1: tensor<i64>):
+    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+  }, {
+  ^bb0(%arg1: tensor<i64>):
+    %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
+    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
+  }) : (tensor<i64>) -> tensor<i64>
+
+  // CHECK-NEXT: return [[VAL5]]
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @conditional
+func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
+  %cst = constant dense<1.000000e+01> : tensor<f32>
+
+  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+
+  // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
+  // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
+  %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
+
+  ^bb0(%arg1: tensor<f32>):
+    // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
+    // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
+    // CHECK: br ^bb3([[VAL3]] : tensor<f32>)
+    %2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+  }, {
+
+  ^bb0(%arg1: tensor<f32>):
+    // CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
+    // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
+    // CHECK: br ^bb3([[VAL5]] : tensor<f32>)
+    %2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+
+  // CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
+  // CHECK: return [[VAL6]] : tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK-LABEL: func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
+func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
+  // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
+  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
+  // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK: %2 = extract_element %1[] : tensor<i1>
+  // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
+  // CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
+  // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
+  // CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
+  // CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
+  // CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
+  // CHECK: ^[[EXIT]](%6: tensor<i64>):
+  // CHECK: return %6 : tensor<i64>
+  // CHECK: }
+  %0 = "xla_hlo.while"(%arg0) ( {
+  ^cond_entry(%arg1: tensor<i64>):
+    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+  }, {
+  ^body_entry(%arg1: tensor<i64>):
+    br ^body_succ(%arg1: tensor<i64>)
+  ^body_succ(%0: tensor<i64>):
+    %1 = xla_hlo.add %0, %0 : tensor<i64>
+    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
+  }) : (tensor<i64>) -> tensor<i64>
+
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
+func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
+  // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
+  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
+  // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
+  // CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
+  // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK: %3 = extract_element %2[] : tensor<i1>
+  // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
+  // CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
+  // CHECK: br ^[[COND_ENTRY]](%4 : tensor<i64>)
+  // CHECK: ^[[EXIT]](%5: tensor<i64>):
+  // CHECK: return %5 : tensor<i64>
+  // CHECK: }
+  %0 = "xla_hlo.while"(%arg0) ( {
+  ^cond_entry(%arg1: tensor<i64>):
+    br ^cond_succ(%arg1: tensor<i64>)
+  ^cond_succ(%0: tensor<i64>):
+    %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+  }, {
+  ^body_entry(%arg1: tensor<i64>):
+    "xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
+  }) : (tensor<i64>) -> tensor<i64>
+
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
+  // CHECK: %0 = extract_element %arg2[] : tensor<i1>
+  // CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor<f32>), ^[[ELSE_ENTRY:.+]](%arg1 : tensor<f32>)
+  // CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
+  // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
+  // CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
+  // CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
+  // CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
+  // CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
+  // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
+  // CHECK: br ^[[EXIT]](%5 : tensor<f32>)
+  // CHECK: ^[[EXIT]](%6: tensor<f32>):
+  // CHECK: return %6 : tensor<f32>
+  // CHECK: }
+  %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
+  ^then_entry(%arg2: tensor<f32>):
+    br ^then_succ(%arg2: tensor<f32>)
+  ^then_succ(%0: tensor<f32>):
+    %2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+  }, {
+  ^else_entry(%arg2: tensor<f32>):
+    %2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
diff --git a/tests/legalize-to-std.mlir b/tests/legalize-to-std.mlir
new file mode 100644
index 0000000..c4153b2
--- /dev/null
+++ b/tests/legalize-to-std.mlir
@@ -0,0 +1,195 @@
+// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
+
+// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
+func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>)
-> tensor<4xf32> { + // CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32> + %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32> + %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32> + %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32> + %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32> + %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: return %4 : tensor<4xf32> + return %4 : tensor<4xf32> +} + +// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { +func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + // CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32> + %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32> + %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32> + %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32> + %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32> + %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + // CHECK-NEXT: return %4 : tensor<4xi32> + return %4 : tensor<4xi32> +} + +// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) { +func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { + // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32> + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32> + %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32> + %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32> + %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32> + %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32> + %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> + return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// CHECK-LABEL: 
func @compare_float +func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) { + // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32> + %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32> + %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32> + %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32> + %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32> + %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32> + %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1> +} + +// CHECK-LABEL: func @int_constant +func @int_constant() -> (tensor, tensor<2x3xi32>, tensor<2x3xi32>) { + // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor + %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32> + %1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32> + %2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>) + // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xi32>, tensor<2x3xi32> + return %0, %1, %2: tensor, tensor<2x3xi32>, tensor<2x3xi32> +} + +// CHECK-LABEL: func @float_constant +func @float_constant() -> (tensor, tensor<2x3xf32>, tensor<2x3xf32>) { + // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor + %0 = "xla_hlo.constant"() {value = dense<0.0> : tensor} : () -> (tensor) + // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32> + %1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32> + %2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>) + // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor, tensor<2x3xf32>, tensor<2x3xf32> + return %0, %1, %2: tensor, tensor<2x3xf32>, tensor<2x3xf32> +} + +// Test Iota lowering to constant +// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> { +func @iota.const.1() -> tensor<4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<4xi32> + return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> { +func @iota.const.2() -> tensor<2x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> + return %0 : 
tensor<2x4xi32> +} + +// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> { +func @iota.const.3() -> tensor<2x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32> + return %0 : tensor<2x4xi32> +} + +// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> { +func @iota.const.4() -> tensor<2x3x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} + +// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> { +func @iota.const.5() -> tensor<2x3x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} + +// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> { +func @iota.const.6() -> tensor<2x3x4xi32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32> + %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32> + // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32> + return %0 : tensor<2x3x4xi32> +} + +// CHECK-LABEL: func @iota.const.f32 +func @iota.const.f32() -> tensor<4xf32> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: return %[[CST]] : tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: func @iota.const.f64 +func @iota.const.f64() -> tensor<4xf64> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64> + // CHECK-NEXT: return %[[CST]] : tensor<4xf64> + return %0 : tensor<4xf64> +} + +// CHECK-LABEL: func @iota.const.bf16 +func @iota.const.bf16() -> tensor<4xbf16> { + // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16> + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16> + // CHECK-NEXT: return %[[CST]] : tensor<4xbf16> + return %0 : tensor<4xbf16> +} + +// CHECK-LABEL: func @iota.const.complex.f32 +func @iota.const.complex.f32() -> tensor<4xcomplex> { + // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32> + // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32> + // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} + +// CHECK-LABEL: func @iota.const.complex.f64 +func @iota.const.complex.f64() -> tensor<4xcomplex> { + // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : 
tensor<4xf64> + // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64> + // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]]) + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex> + // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex> + return %0 : tensor<4xcomplex> +} diff --git a/tests/legalize_tanh_to_approximation.mlir b/tests/legalize_tanh_to_approximation.mlir new file mode 100644 index 0000000..eaa3fdc --- /dev/null +++ b/tests/legalize_tanh_to_approximation.mlir @@ -0,0 +1,125 @@ +// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s + +func @tanh_f64(%arg0 : f64) -> f64 { + %res = tanh %arg0 : f64 + return %res : f64 +} + +// CHECK-LABEL: @tanh_f64 +// CHECK: tanh + +// ----- + +func @tanh_f32(%arg0 : f32) -> f32 { + %res = tanh %arg0 : f32 + return %res : f32 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: func @tanh_f32( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { +// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 +// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 +// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32 +// CHECK: %[[VAL_16:.*]] = cmpf "olt", %[[VAL_15]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_17:.*]] = cmpf "ule", %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "uge", %[[VAL_18]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32 +// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32 +// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32 +// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32 +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32 +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32 +// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_41:.*]] = divf 
%[[VAL_34]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32 +// CHECK: return %[[VAL_42]] : f32 +// CHECK: } + +// ----- + +func @tanh_f16(%arg0 : f16) -> f16 { + %res = tanh %arg0 : f16 + return %res : f16 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK-LABEL: func @tanh_f16( +// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 { +// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32 +// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32 +// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32 +// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32 +// CHECK: %[[VAL_17:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_18:.*]] = cmpf "ule", %[[VAL_15]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "uge", %[[VAL_19]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32 +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32 +// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32 +// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32 +// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32 +// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32 +// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32 +// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16 +// CHECK: return %[[VAL_44]] : f16 +// CHECK: } + + diff --git a/tests/lhlo-copy-removal.mlir b/tests/lhlo-copy-removal.mlir new file mode 100644 index 0000000..3d3f802 --- /dev/null +++ b/tests/lhlo-copy-removal.mlir @@ -0,0 +1,93 @@ +// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s + +// CHECK-LABEL: func @remove_simple +func @remove_simple(%arg0: memref<2x2xf32>) { + %0 = 
alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @remove_without_dealloc +func @remove_without_dealloc(%arg0: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @replace_dependency +func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @keep_copies +func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { + // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_not_be_removed +func @must_not_be_removed(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_first +func @must_be_removed_first(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} + +// ----- + +// CHECK-LABEL: func @must_be_removed_second +func @must_be_removed_second(%arg0: memref<2x2xf32>, + %arg1: memref<2x2xf32>, + %arg2: memref<2x2xf32>) { + %0 = alloc() {temp = true} : memref<2x2xf32> + // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // 
CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + dealloc %0 : memref<2x2xf32> + "xla_lhlo.terminator"() : () -> () +} diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir new file mode 100644 index 0000000..6a67466 --- /dev/null +++ b/tests/lhlo-fuse-linalg.mlir @@ -0,0 +1,236 @@ +// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always +// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED +// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, + %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { + %temp_result = alloc() : memref<6x6xf32> + linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result { + ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): + %out = addf %summand_1_in, %summand_2_in : f32 + linalg.yield %out : f32 + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): + %out = mulf %temp_result_in, %multiplier_in : f32 + linalg.yield %out : f32 + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + dealloc %temp_result : memref<6x6xf32> + return +} +// CHECK-LABEL: func @fusion +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf + +// ----- + +func @fusion_of_three(%arg0: memref<100x10xf32>, + %arg1: memref<100xf32>, + %arg2: memref<100x10xf32>) { + %0 = alloc() : memref<100x10xf32> + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } %arg1, %0 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + linalg.yield %arg3 : f32 + }: memref<100xf32>, memref<100x10xf32> + %1 = alloc() : memref<100x10xf32> + linalg.generic { + args_in = 2 : i64, + args_out = 1 : i64, + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } %arg0, %0, %1 { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = subf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> + dealloc %0 : memref<100x10xf32> + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + 
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } %1, %arg2 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %2 = exp %arg3 : f32 + linalg.yield %2 : f32 + }: memref<100x10xf32>, memref<100x10xf32> + dealloc %1 : memref<100x10xf32> + return +} +// CHECK-LABEL: func @fusion +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: linalg.generic +// CHECK: subf +// CHECK: linalg.generic +// CHECK: exp + +// TILED-LABEL: func @fusion_of_three +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: linalg.generic +// TILED: subf +// TILED: linalg.generic +// TILED: exp + +// PLOOP-LABEL: func @fusion_of_three +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: linalg.generic +// PLOOP: subf +// PLOOP: linalg.generic +// PLOOP: exp + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, + %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { + %temp_result = alloc() : memref<6x6x6x6xf32> + linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result { + ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): + %out = addf %summand_1_in, %summand_2_in : f32 + linalg.yield %out : f32 + } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result { + ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): + %out = mulf %temp_result_in, %multiplier_in : f32 + linalg.yield %out : f32 + } : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32> + dealloc %temp_result : memref<6x6x6x6xf32> + return +} +// CHECK-LABEL: func @fusion_4d +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion_4d +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion_4d +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, + %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { + %temp_result = alloc() : memref<6x6xf32> + linalg.generic 
#pointwise_2d_trait %summand_1, %summand_2, %temp_result { + ^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32): + %out = addf %summand_1_in, %summand_2_in : f32 + linalg.yield %out : f32 + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + %result = alloc() : memref<6x6xf32> + linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result { + ^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32): + %out = mulf %temp_result_in, %multiplier_in : f32 + linalg.yield %out : f32 + } : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32> + dealloc %temp_result : memref<6x6xf32> + return %result : memref<6x6xf32> +} + +// CHECK-LABEL: func @fusion +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: addf +// CHECK: linalg.generic +// CHECK: mulf + +// TILED-LABEL: func @fusion +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-DAG: %[[C3:.*]] = constant 3 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED: scf.for {{.*}} step %[[C3]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: addf +// TILED: linalg.generic +// TILED: mulf + +// PLOOP-LABEL: func @fusion +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: addf +// PLOOP: linalg.generic +// PLOOP: mulf diff --git a/tests/lhlo-legalize-select-and-scatter.mlir b/tests/lhlo-legalize-select-and-scatter.mlir new file mode 100644 index 0000000..2aa6378 --- /dev/null +++ b/tests/lhlo-legalize-select-and-scatter.mlir @@ -0,0 +1,193 @@ +// GenericAtomicRMWOp should contain only ops with no side effects. +// Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt +// to the XLA LHLO dialect by using allocs/deallocs inside the GenericAtomicRMWOp +// body. Lowering to the STD dialect and a store-forwarding pass would be +// required to get rid of them. This is exactly what is done in the real MLIR +// GPU pipeline, but here we disable verification with `verify-each=0` to check +// the output IR. +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s + +func @select_and_scatter(%arg: memref<112x112xf32>, + %src: memref<56x56xf32>, + %init: memref<f32>, + %result: memref<112x112xf32>) { + "xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { + // select + ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>): + "xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : + (memref<f32>, memref<f32>, memref<i1>) -> () + "xla_lhlo.terminator"() : () -> () + }, { + // scatter + ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>): + "xla_lhlo.add"(%lhs, %rhs, %out) : + (memref<f32>, memref<f32>, memref<f32>) -> () + "xla_lhlo.terminator"() : () -> () + }) { + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + window_dimensions = dense<[3, 3]> : tensor<2xi64>, + window_strides = dense<[2, 2]> : tensor<2xi64> + } : (memref<112x112xf32>, + memref<56x56xf32>, + memref<f32>, memref<112x112xf32>) -> () + "xla_lhlo.terminator"() : () -> () +} +// CHECK-LABEL: func @select_and_scatter( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, +// CHECK-SAME: [[SRC_BUF:%.*]]: memref<56x56xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<112x112xf32>) { + +// Constants.
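+// The constants below are the loop bounds and strides (0, 1, 2, 3, 56, 112) and +// the initial selected value and flags (0.0, false, true); CHECK-DAG lets them +// match in any order.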
+// CHECK-DAG: [[C56:%.*]] = constant 56 : index +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C0_F32:%.*]] = constant 0.000000e+00 : f32 +// CHECK-DAG: [[CFALSE:%.*]] = constant false +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C2:%.*]] = constant 2 : index +// CHECK-DAG: [[C112:%.*]] = constant 112 : index +// CHECK-DAG: [[CTRUE:%.*]] = constant true + +// Parallel loop to initialize the output buffer. +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32> +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) { +// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] +// CHECK: scf.yield +// CHECK: } + +// Parallel loop over the source buffer to compute the scattered values. +// CHECK: scf.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { + +// Window loop w.r.t. first dim. +// CHECK: [[SEL_RES_I:%.*]]:4 +// CHECK-SAME: = scf.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: iter_args( +// CHECK-SAME: [[SEL_I_0:%.*]] = [[C0]], [[SEL_J_0:%.*]] = [[C0]], +// CHECK-SAME: [[SEL_VAL_0:%.*]] = [[C0_F32]], +// CHECK-SAME: [[SEL_INIT_0:%.*]] = [[CFALSE]] +// CHECK-SAME: ) -> (index, index, f32, i1) { + +// Window loop w.r.t. second dim. +// CHECK: [[SEL_RES_J:%.*]]:4 +// CHECK-SAME: = scf.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]] +// CHECK-SAME: iter_args( +// CHECK-SAME: [[SEL_I:%.*]] = [[SEL_I_0]], [[SEL_J:%.*]] = [[SEL_J_0]], +// CHECK-SAME: [[SEL_VAL:%.*]] = [[SEL_VAL_0]], +// CHECK-SAME: [[SEL_INIT:%.*]] = [[SEL_INIT_0]] +// CHECK-SAME: ) -> (index, index, f32, i1) { + +// Compute index I of the ARG buffer and check whether it is in the padding area. +// CHECK: [[START_I:%.*]] = muli [[II]], [[C2]] : index +// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[WIN_I]] : index +// CHECK: [[ARG_I_FITS:%.*]] = cmpi "ult", [[ARG_I]], [[C112]] : index + +// Compute index J of the ARG buffer and check whether it is in the padding area. +// CHECK: [[START_J:%.*]] = muli [[JJ]], [[C2]] : index +// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[WIN_J]] : index +// CHECK: [[ARG_J_FITS:%.*]] = cmpi "ult", [[ARG_J]], [[C112]] : index + +// Compute `INBOUNDS`, i.e. whether the ARG indices are inside the boundaries of +// the buffer or in the padding area. +// CHECK: [[INBOUNDS_1:%.*]] = and [[ARG_I_FITS]], [[ARG_J_FITS]] : i1 + +// If the ARG ivs are in the padding area, the 'select' function does not have to +// be applied; the currently selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are +// returned instead. +// CHECK: [[IF_INBOUNDS_RES:%.*]]:4 +// CHECK-SAME: = scf.if [[INBOUNDS_1]] -> (index, index, f32, i1) { + + + // INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true + + // CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]] + // CHECK: [[IF_INIT_RES:%.*]]:4 + // CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) { + + // INIT-THEN-BODY, i.e. INBOUNDS == true and INIT == true + + // The LHLO IR of the select block of lhlo.select_and_scatter is applied to the + // currently selected value (SEL_VAL) and the element of the ARG buffer to + // compute the boolean PRED, which decides whether the new value and ivs should + // replace the current ones. + + // Allocate buffers for the ARG element and the currently selected value so + // that the LHLO select code can be applied.
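+  // These temporary allocs exist only to feed the memref-based LHLO ops; the +  // real pipeline is expected to remove them via store forwarding (see the note +  // at the top of this file).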
+ // CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32> + // CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32> + // CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1> + // CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32> + // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32> + + // Compute PRED. + // CHECK: "xla_lhlo.compare"( + // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) + // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1> + + + // Depending on PRED, return either the ARG ivs & element or the currently + // selected ivs and value. + // CHECK: [[IF_PRED_RES:%.*]]:4 = scf.if [[PRED]] + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]] + // CHECK: } else { + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]] + // CHECK: } + + // INIT-THEN-BODY yield. + // CHECK: scf.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1, + // CHECK-SAME: [[IF_PRED_RES]]#2, [[IF_PRED_RES]]#3 + + // INIT-ELSE-BODY, i.e. if INBOUNDS == true and INIT == false, returns the ARG + // ivs and element without applying the select function. + // CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], + // CHECK-SAME: [[CTRUE]] : index, index, f32, i1 + // CHECK: } + + // INBOUNDS-THEN-BODY yield. + // CHECK: scf.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2, + // CHECK-SAME: [[IF_INIT_RES]]#3 : index, index, f32, i1 + // CHECK: } + + // INBOUNDS-ELSE-REGION, i.e. if INBOUNDS == false. + // We are in the padding area; return the current iter_args. + // CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], + // CHECK-SAME: [[SEL_INIT]] : index, index, f32, i1 + // CHECK: } + +// Window loop w.r.t. second dim yield. +// CHECK: scf.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1, +// CHECK-SAME: [[IF_INBOUNDS_RES]]#2, [[IF_INBOUNDS_RES]]#3 +// CHECK: } + +// Window loop w.r.t. first dim yield. +// CHECK: scf.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2, +// CHECK-SAME: [[SEL_RES_J]]#3 : index, index, f32, i1 +// CHECK: } + +// Use the selected ivs to load the element from the SRC buffer. +// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]] + +// The update of RESULT[SELECTED_I, SELECTED_J] has to be done atomically, +// because several other threads may select the same ivs if the windows overlap. +// CHECK: generic_atomic_rmw [[RESULT_BUF]]{{\[}}[[SEL_RES_I]]#0, +// CHECK-SAME: [[SEL_RES_I]]#1] : memref<112x112xf32> +// CHECK: ^bb0([[CUR_RES:%.*]]: f32): + +// Allocate buffers for the SRC element and the current result value so that the +// LHLO scatter code can be applied. +// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32> +// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32> +// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32> +// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32> +// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32> + +// Compute the scatter value. +// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : +// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> () +// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32> + +// Atomic RMW terminator that returns the updated value. +// CHECK: atomic_yield [[RES]] : f32 + +// Parallel loop over the source buffer yield. +// CHECK: scf.yield diff --git a/tests/lhlo-legalize-to-affine.mlir b/tests/lhlo-legalize-to-affine.mlir new file mode 100644 index 0000000..1068d1a --- /dev/null +++ b/tests/lhlo-legalize-to-affine.mlir @@ -0,0 +1,181 @@ +// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s + +// Smoke test.
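+// The 4-D `xla_lhlo.minimum` below is expected to lower to a perfect nest of +// four affine.for loops with a cmpf/select pair in the innermost body.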
+// CHECK-LABEL: func @min_op +func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, + %result: memref<4x3x2x1xf32>) -> () { + // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 { + // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 { + // CHECK-NEXT: affine.for %[[L:.*]] = 0 to 1 { + // CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf "olt", %[[LHS]], %[[RHS]] : f32 + // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 + // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> + // CHECK: return + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : + (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () + return +} + +// Add tests. +// CHECK-LABEL: func @float_add_op +func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: addf %{{.*}}, %{{.*}} : f32 + "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} +// CHECK-LABEL: func @int_add_op +func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: addi %{{.*}}, %{{.*}} : i32 + "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// And test. +// CHECK-LABEL: func @int_and_op +func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: and %{{.*}}, %{{.*}} : i32 + "xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Div tests. +// CHECK-LABEL: func @float_div_op +func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: divf %{{.*}}, %{{.*}} : f32 + "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} +// CHECK-LABEL: func @int_div_op +func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 + "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Max tests. +// CHECK-LABEL: func @float_max_op +func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 + "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} + +// CHECK-LABEL: func @int_max_op +func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 + "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Min tests. 
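+// Like the max tests above, minimum lowers to a compare ("olt" for floats, +// "slt" for signed integers) followed by a select.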
+// CHECK-LABEL: func @float_min_op +func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} + +// CHECK-LABEL: func @int_min_op +func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 + // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 + "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Mul tests. +// CHECK-LABEL: func @float_mul_op +func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: mulf %{{.*}}, %{{.*}} : f32 + "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} + +// CHECK-LABEL: func @int_mul_op +func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: muli %{{.*}}, %{{.*}} : i32 + "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Sub tests. +// CHECK-LABEL: func @float_sub_op +func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, + %result: memref<7xf32>) -> () { + // CHECK: subf %{{.*}}, %{{.*}} : f32 + "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () + return +} +// CHECK-LABEL: func @int_sub_op +func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, + %result: memref<7xi32>) -> () { + // CHECK: subi %{{.*}}, %{{.*}} : i32 + "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () + return +} + +// Dot tests. 
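+// `xla_lhlo.dot` is expected to lower to a three-deep affine loop nest that +// multiplies an LHS and an RHS element and accumulates into the result buffer.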
+// CHECK-LABEL: func @float_dot_op +func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: + memref<3x4xf32>, %result: memref<7x4xf32> ) -> () { + // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 7 { + // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 3 { + // CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[K]]] : memref<7x3xf32> + // CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[K]], %[[J]]] : memref<3x4xf32> + // CHECK-NEXT: %[[RESULT:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> + // CHECK-NEXT: %[[MULT:.*]] = mulf %[[LHS]], %[[RHS]] : f32 + // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 + // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> + // CHECK: return + "xla_lhlo.dot"(%lhs, %rhs, %result) : + (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () + return +} +// CHECK-LABEL: func @int_dot_op +func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: + memref<3x4xi32>, %result: memref<7x4xi32> ) -> () { + // CHECK-NEXT: affine.for %[[I:.*]] = 0 to 7 { + // CHECK-NEXT: affine.for %[[J:.*]] = 0 to 4 { + // CHECK-NEXT: affine.for %[[K:.*]] = 0 to 3 { + // CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[K]]] : memref<7x3xi32> + // CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[K]], %[[J]]] : memref<3x4xi32> + // CHECK-NEXT: %[[RESULT:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> + // CHECK-NEXT: %[[MULT:.*]] = muli %[[LHS]], %[[RHS]] : i32 + // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 + // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> + // CHECK: return + "xla_lhlo.dot"(%lhs, %rhs, %result) : + (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () + return +} diff --git a/tests/lhlo-legalize-to-gpu.mlir b/tests/lhlo-legalize-to-gpu.mlir new file mode 100644 index 0000000..e996581 --- /dev/null +++ b/tests/lhlo-legalize-to-gpu.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s + +func @reduce(%arg: memref<100x10xf32>, + %init: memref, + %result: memref<100xf32>) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} + : (memref<100x10xf32>, memref, memref<100xf32>) -> () + return +} + +// CHECK: func @reduce(%[[ARG0:.*]]: memref<100x10xf32>, %[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xf32>) { +// CHECK-DAG: %[[C100:.*]] = constant 100 : index +// CHECK-DAG: %[[C1:.*]] = constant 1 : index +// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) { +// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref +// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32> +// CHECK-DAG: %[[LB:.*]] = constant 0 : index +// CHECK-DAG: %[[UB:.*]] = constant 10 : index +// CHECK-DAG: %[[STEP:.*]] = constant 1 : index +// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { +// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref +// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref +// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () +// CHECK: } +// CHECK: gpu.terminator +// CHECK: } +// CHECK: 
return +// CHECK: } +// CHECK: } diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir new file mode 100644 index 0000000..8ebfb6b --- /dev/null +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -0,0 +1,724 @@ +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s + +// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @element_wise +func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.add"(%lhs, %rhs, %result) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @element_wise_with_dynamic_shape +func @element_wise_with_dynamic_shape(%lhs: memref, + %rhs: memref, + %result: memref) { + "xla_lhlo.add"(%lhs, %rhs, %result) + : (memref, memref, memref) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @element_wise_scalar +func @element_wise_scalar(%lhs: memref, %rhs: memref, + %result: memref) { + "xla_lhlo.add"(%lhs, %rhs, %result) + : (memref, memref, memref) -> () + return +} +// CHECK: %[[LHS:.*]] = load +// CHECK: %[[RHS:.*]] = load +// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] +// CHECK: store %[[RES]] +// CHECK-NEXT: return + +// ----- + +// CHECK-LABEL: func @minf +func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.minimum"(%lhs, %rhs, %result) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @maxi +func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi32>) { + "xla_lhlo.maximum"(%lhs, %rhs, %result) + : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @and +func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi32>) { + "xla_lhlo.and"(%lhs, %rhs, %result) + : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = and %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @exp +func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.exponential"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: 
^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = exp %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @log +func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @copy +func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { + "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32 + +// ----- + +// CHECK-LABEL: func @float_cmp +func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xi1>) { + "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @int_cmp +func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, + %result: memref<2x2xi1>) { + "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + +// CHECK-LABEL: func @select +func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, + %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.select"(%pred, %lhs, %rhs, %result) + : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @iota +func @iota(%out: memref<7x10xf32>) { + "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () + return +} +// CHECK: linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32): +// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 +// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @broadcast_scalar +func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref, memref<4x2x1xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = 
[#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> +// CHECK-LABEL: func @broadcast +func @broadcast(%operand: memref<4x?x16xf32>, + %result: memref<4x2x1x4x?x16xf32>) { + "xla_lhlo.broadcast"(%operand, %result) { + broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> + } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK-LABEL: func @dynamic_broadcast_in_dim +func @dynamic_broadcast_in_dim(%operand: memref, + %result: memref) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> + } : (memref, memref) -> () + return +} +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion +func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, + %result: memref<5x10xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[0]> : tensor<1xi64> + } : (memref<5xf32>, memref<5x10xf32>) -> () + return +} +// CHECK-NOT: linalg.reshape +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[REASSOCIATION:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-LABEL: func @static_broadcast_in_dim_expansion +func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, + %result: memref<5x10x100xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> + } : (memref<1x5xf32>, memref<5x10x100xf32>) -> () + return +} +// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]] +// CHECK-SAME: memref<1x5xf32> into memref<5xf32> +// CHECK: linalg.generic {{{.*}}indexing_maps = +// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @static_broadcast_in_dim_scalar +func @static_broadcast_in_dim_scalar(%operand: memref, + %result: memref<5x10xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[]> : tensor<0xi64> + } : (memref, memref<5x10xf32>) -> 
() + return +} +// CHECK-NOT: linalg.reshape +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[CONST]] : f32 + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one +func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, + %result: memref<1x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[0]> : tensor<1xi64> + } : (memref<1xf32>, memref<1x5xf32>) -> () + return +} +// CHECK-NOT: linalg.reshape +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 + +// ----- + +// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many +func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, + %result: memref<5x5xf32>) { + "xla_lhlo.broadcast_in_dim"(%operand, %result) { + broadcast_dimensions = dense<[1]> : tensor<1xi64> + } : (memref<1xf32>, memref<5x5xf32>) -> () + return +} +// CHECK-NOT: linalg.reshape +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]] +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] +// CHECK-NEXT: ^bb0(%{{.+}}: f32): +// CHECK-NEXT: linalg.yield %[[VALUE]] : f32 + +// ----- + +// CHECK-LABEL: func @constant +func @constant(%value: memref) { + "xla_lhlo.constant"(%value) { + value = dense<10> : tensor + } : (memref) -> () + return +} +// CHECK: %[[CONSTANT:.*]] = constant 10 : i32 +// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref + +// ----- + +// CHECK-LABEL: func @absf +func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @absi +func @absi(%input: memref<2x2xi32>, + %result: memref<2x2xi32>) { + "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} + +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[L1:.*]] = cmpi "sge", %[[OPERAND_IN]], %[[L0]] : i32 +// CHECK-NEXT: %[[L2:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[L1]], %[[OPERAND_IN]], %[[L2]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @ceil +func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = ceilf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_f32 +func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: 
^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i16_to_i32 +func @convert_i16_to_i32(%input: memref<2x2xi16>, + %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_i16 +func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i16): +// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_f64 +func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f64): +// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f64 + +// ----- + +// CHECK-LABEL: func @convert_f64_to_f32 +func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_i32_to_i32 +func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : i32 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_f32 +func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32 + +// ----- + +// CHECK-LABEL: func @convert_f32_to_i32 +func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { + "xla_lhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32): +// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @cos +func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = cos %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// 
CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.sine"(%input, %result) + : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @negf +func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = negf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @negi +func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32 +// CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @rem +func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "xla_lhlo.remainder"(%lhs, %rhs, %result) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = remf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @rsqrt +func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @sign +func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @sqrt +func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @tanh +func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + +// CHECK-LABEL: func @complex +func @complex(%real: memref<2x2xf32>, + %imag: memref<2x2xf32>, + %cplx: memref<2x2xcomplex>) { + "xla_lhlo.complex"(%real, %imag, 
%cplx) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex): +// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex +// CHECK-NEXT: linalg.yield %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @real +func @real(%cplx: memref<2x2xcomplex>, + %real: memref<2x2xf32>) { + "xla_lhlo.real"(%cplx, %real) + : (memref<2x2xcomplex>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[REAL_OUT:.*]]: f32): +// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: linalg.yield %[[REAL]] : f32 + +// ----- + +// CHECK-LABEL: func @imag +func @imag(%cplx: memref<2x2xcomplex>, + %imag: memref<2x2xf32>) { + "xla_lhlo.imag"(%cplx, %imag) + : (memref<2x2xcomplex>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex, %[[IMAG_OUT:.*]]: f32): +// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex +// CHECK-NEXT: linalg.yield %[[IMAG]] : f32 + +// ----- + +// CHECK: func @slice(%[[IN:.*]]: memref, %[[OUT:.*]]: memref) +func @slice(%operand: memref, %result: memref) { + "xla_lhlo.slice"(%operand, %result) { + start_indices = dense<[0,1]> : tensor<2xi64>, + limit_indices = dense<[2,3]> : tensor<2xi64>, + strides = dense<[1,1]> : tensor<2xi64> + } : (memref, memref) -> () + return +} +// CHECK: %[[L0:.*]] = constant 0 : index +// CHECK: %[[L2:.*]] = constant 2 : index +// CHECK: %[[L1:.*]] = constant 1 : index +// CHECK: %[[LHS:.*]] = linalg.range %[[L0]] : %[[L2]] : %[[L1]] +// CHECK: %[[R0:.*]] = constant 1 : index +// CHECK: %[[R2:.*]] = constant 3 : index +// CHECK: %[[R1:.*]] = constant 1 : index +// CHECK: %[[RHS:.*]] = linalg.range %[[R0]] : %[[R2]] : %[[R1]] +// CHECK: %[[RESULT:.*]] = linalg.slice %[[IN]][%[[LHS]], %[[RHS]]] +// CHECK: linalg.copy(%[[RESULT]], %[[OUT]]) + +// ----- + +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> +// CHECK-LABEL: func @reshape_3D_2D +func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x1x42xi32>, memref<12x42xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]] +// CHECK-NEXT: linalg.copy + +// ----- + +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> +// CHECK-LABEL: func @reshape_4D_2D +func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]] +// CHECK-NEXT: linalg.copy + +// ----- + +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +// CHECK-LABEL: func @reshape_2D_4D +func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { + "xla_lhlo.reshape"(%arg0, %arg1) + : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]] +// CHECK-NEXT: linalg.copy + +// ----- + +// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> +// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @reverse +func @reverse(%arg0: 
memref<2x3xf32>, %arg1: memref<2x3xf32>) { + "xla_lhlo.reverse"(%arg0, %arg1) { + dimensions = dense<1> : tensor<1xi64> + } : (memref<2x3xf32>, memref<2x3xf32>) -> () + return +} +// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] + + +// ----- + +func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) { + %c0 = constant 0 : index + %0 = alloc() : memref<3x5x5x4xf32> + // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) + // CHECK-SAME: dilations = [1, 2] + // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> + // CHECK-SAME: strides = [2, 1]} + // With all atributes explicitly specified. + "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + + // Dilation left unspecified, sets default dilation since linalg expects it. + // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) + // CHECK-SAME: dilations = [1, 1] + // Padding is not set if it's zero. + // CHECK-NOT: padding + "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + + "xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () + "xla_lhlo.terminator"() : () -> () +} diff --git a/tests/lhlo-legalize-to-llvm.mlir b/tests/lhlo-legalize-to-llvm.mlir new file mode 100644 index 0000000..a9759c0 --- /dev/null +++ b/tests/lhlo-legalize-to-llvm.mlir @@ -0,0 +1,65 @@ +// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s + +// CHECK-LABEL: func @static_memref_cast +func @static_memref_cast(%buf : memref<10x1x5xf32>) { + %0 = xla_lhlo.static_memref_cast %buf + : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> + return +} +// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]] +// CHECK: llvm.insertvalue +// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]] + +// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]] +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]] +// CHECK: 
%[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]] + +// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]] +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]] + +// ----- + +// CHECK-LABEL: func @dynamic_memref_cast +func @dynamic_memref_cast(%buf : memref) { + %size_X = constant 10 : index + %size_Y = constant 50 : index + %stride_X = constant 1 : index + %stride_Y = constant 0 : index + %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] + : memref -> memref + return +} +// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 +// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64 +// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 + +// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]] + +// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]] + +// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*"> +// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]] + +// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]] +// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]] diff --git a/tests/lhlo-legalize-to-parallel-loops.mlir b/tests/lhlo-legalize-to-parallel-loops.mlir new file mode 100644 index 0000000..a3d76ef --- /dev/null +++ b/tests/lhlo-legalize-to-parallel-loops.mlir @@ -0,0 +1,202 @@ +// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s + +func @reduce(%arg: memref<100x10x5xf32>, + %init: memref, + %result: memref<100x5xf32>) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: 
memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} + : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () + return +} +// CHECK-LABEL: func @reduce( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100x10x5xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<100x5xf32>) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C5:%.*]] = constant 5 : index +// CHECK-DAG: [[C10:%.*]] = constant 10 : index +// CHECK-DAG: [[C100:%.*]] = constant 100 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = +// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] +// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32> +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: scf.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] +// CHECK: scf.yield + +// ----- + +func @reduce_no_outer_loop(%arg: memref<100xf32>, + %init: memref, + %result: memref<1xf32>) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} + : (memref<100xf32>, memref, memref<1xf32>) -> () + return +} +// CHECK-LABEL: func @reduce_no_outer_loop( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100xf32>, +// CHECK-SAME: [[ELEM_TO_REDUCE_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<1xf32>) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C100:%.*]] = constant 100 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]]) +// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}} +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref +// CHECK: scf.reduce.return [[ACC_RESULT]] +// CHECK: } +// CHECK: scf.yield +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]] + +// ----- + +func @dynamic_reduce(%arg: memref, + %init: memref, + %result: memref) { + "xla_lhlo.reduce"(%arg, %init, %result) ( { + 
^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.add"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[1]> : tensor<1xi64>} + : (memref, memref, memref) -> () + return +} +// CHECK-LABEL: func @dynamic_reduce( +// CHECK-SAME: [[ARG_BUF:%.*]]: memref, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C2:%.*]] = constant 2 : index +// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref +// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref +// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] +// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) = +// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 { +// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]] +// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: scf.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]] +// CHECK: scf.yield + +// ----- + +func @reduce_window(%arg: memref<112x112xf32>, + %init: memref, + %result: memref<56x56xf32>) { + "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + ^bb0(%lhs: memref, %rhs: memref, %res: memref): + "xla_lhlo.maximum"(%lhs, %rhs, %res) + : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }) { + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + window_dimensions = dense<[3, 3]> : tensor<2xi64>, + window_strides = dense<[2, 2]> : tensor<2xi64> + } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () + return +} +// CHECK-LABEL: func @reduce_window( +// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>, +// CHECK-SAME: [[INIT_BUF:%.*]]: memref, +// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) { +// CHECK-DAG: [[C0:%.*]] = constant 0 : index +// CHECK-DAG: [[C1:%.*]] = constant 1 : index +// CHECK-DAG: [[C2:%.*]] = constant 2 : index +// CHECK-DAG: [[C3:%.*]] = constant 3 : index +// CHECK-DAG: [[C56:%.*]] = constant 56 : index +// CHECK-DAG: [[C112:%.*]] = constant 112 : index +// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref +// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) { +// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel +// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]]) +// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]]) +// CHECK-SAME: init ([[INIT]]) -> f32 { + +// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index +// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[IW]] : index +// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]] + +// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index +// CHECK: [[INDEX_J:%.*]] = addi 
[[START_J]], [[JW]] : index +// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]] +// CHECK: [[IN_BOUNDS_1:%.*]] = and [[INDEX_I_FITS]], [[INDEX_J_FITS]] + +// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) { +// CHECK: [[OPERAND_ELEM:%.*]] = +// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]] +// CHECK: scf.yield [[OPERAND_ELEM]] : f32 +// CHECK: } else { +// CHECK: scf.yield [[INIT]] : f32 +// CHECK: } + +// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 { +// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32): +// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_BUF:%.*]] = alloc() : memref +// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref +// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref +// CHECK: store [[ACC]], [[ACC_BUF]][] : memref +// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref +// CHECK: scf.reduce.return [[ACC_RESULT]] : f32 +// CHECK: } +// CHECK: scf.yield +// CHECK: } +// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]] +// CHECK: scf.yield +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/tests/lhlo_ops.mlir b/tests/lhlo_ops.mlir new file mode 100644 index 0000000..11cecde --- /dev/null +++ b/tests/lhlo_ops.mlir @@ -0,0 +1,1045 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// ----- + +// CHECK-LABEL: func @ceil +func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} + +// ----- + +func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + // expected-error@+1{{must be memref of floating-point values}} + "xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @cos +func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @cos +func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + return +} + +// ----- + +func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sin +func @sin(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { + "xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + return +} + +// ----- + +func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @add_memrefs +func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @abs_memref +func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + 
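+  // Unary LHLO ops write into a preallocated output buffer; here the
+  // elementwise absolute value of %in is stored to %out.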
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @convert_memref +func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { + "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () + return +} + +// ----- + +func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { + // expected-error@+1{{requires the same shape for all operands}} + "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @exp +func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { + "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @exp +func @exp(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { + "xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + return +} + +// ----- + +func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @log_memref +func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @log_memref +func @log_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { + "xla_lhlo.log"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + return +} + +// ----- + +func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @neg_memref +func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @rsqrt_memref +func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @rsqrt_memref +func @rsqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { + "xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + return +} + +// ----- + +func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sqrt_memref +func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sqrt_memref +func @sqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { + "xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + return +} + +// ----- + +func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sign_memref +func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.sign"(%in, %out) : 
(memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @tanh_memref +func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @tanh_memref +func @tanh_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { + "xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + return +} + +// ----- + +func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { + // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} + "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @add_memref +func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @div_memref +func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @max_memref +func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @min_memref +func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @mul_memref +func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sub_memref +func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @and_memref +func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { + "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @and_memref +func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { + "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + return +} + +// ----- + +func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} + "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @or_memref +func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { + "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + return +} + 
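+
+// As with xla_lhlo.and above, the or/xor cases below accept signless or
+// unsigned integer and i1 (pred) element types and are expected to reject
+// floating-point memrefs.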
+// ----- + +// CHECK-LABEL: func @or_memref +func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { + "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + return +} + +// ----- + +func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} + "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @xor_memref +func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { + "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @xor_memref +func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { + "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + return +} + +// ----- + +func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} + "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim_memref +func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { + "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref +func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi32>) -> () { + "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () + return +} + +// ----- + + +// CHECK-LABEL: func @reduce_memref +func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { + "xla_lhlo.reduce"(%input, %init, %out) ( { + ^bb0(%arg1: memref, %arg2: memref, %result: memref): + "xla_lhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @fusion_memref +func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { + "xla_lhlo.fusion"() ( { + %0 = tensor_load %input1 : memref<10xf32> + %1 = tensor_load %input2 : memref<10xf32> + %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %3 = tensor_load %input3 : memref<10xf32> + %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + tensor_store %4, %out : memref<10xf32> + "xla_lhlo.terminator"() : () -> () + } ) : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @case_memref +func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { + "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + ^bb0(%arg0: memref): + "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + 
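+    // Second branch: forward the selected operand to %out with xla_lhlo.copy.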
"xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + }, { + ^bb0(%arg0: memref): + "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "xla_lhlo.terminator"() : () -> () + } + ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} + : (memref, memref, memref, memref, memref) -> () + return +} + +// ----- + +func @static_memref_cast(%in: memref<10x1xf32>) { + %out = xla_lhlo.static_memref_cast %in + : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> + return +} +// CHECK-LABEL: func @static_memref_cast + +// ----- + +func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { + // expected-error @+1 {{operand must have static shape}} + %out = xla_lhlo.static_memref_cast %in + : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> + return +} + +// ----- + +func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { + // expected-error @+1 {{result must have static shape}} + %out = xla_lhlo.static_memref_cast %in + : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> + return +} + +// ----- + +func @dynamic_memref_cast(%in: memref) { + %size = constant 10 : index + %step = constant 1 : index + %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + : memref -> memref + return +} +// CHECK-LABEL: func @dynamic_memref_cast + +// ----- + +func @dynamic_memref_cast_incompatible_result_type(%in: memref) { + // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} + %size = constant 10 : index + %step = constant 1 : index + %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + : memref -> memref + return +} +// ----- + +// CHECK-LABEL: func @reshape_memref_cast( +func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, + %shape2: memref<2xi32>, %shape3: memref) { + // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, + // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref + + // CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]] + // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref + %dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1) + : (memref<*xf32>, memref<1xi32>) -> memref + + // CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]] + // CHECK-SAME: : (memref, memref<2xi32>) -> memref + %dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2) + : (memref, memref<2xi32>) -> memref + + // CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]] + // CHECK-SAME: : (memref, memref) -> memref<*xf32> + %new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3) + : (memref, memref) -> memref<*xf32> + return +} + +// ----- + +func @reshape_memref_cast_element_type_mismatch( + %buf: memref<*xf32>, %shape: memref<1xi32>) { + // expected-error @+1 {{element types of source and destination memref types should be the same}} + xla_lhlo.reshape_memref_cast %buf(%shape) + : (memref<*xf32>, memref<1xi32>) -> memref +} + +// ----- + +func @reshape_memref_cast_dst_ranked_shape_unranked( + %buf: memref<*xf32>, %shape: memref) { + // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} + xla_lhlo.reshape_memref_cast %buf(%shape) + : (memref<*xf32>, memref) -> memref + return +} + +// ----- + +func @reshape_memref_cast_dst_shape_rank_mismatch( + %buf: memref<*xf32>, %shape: memref<1xi32>) { + // expected-error @+1 {{length of shape operand differs from the result's memref rank}} + 
xla_lhlo.reshape_memref_cast %buf(%shape) + : (memref<*xf32>, memref<1xi32>) -> memref + return +} + +// ----- + +func @reshape_memref_cast_affine_map_is_not_identity( + %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, + %shape: memref<1xi32>) { + // expected-error @+1 {{operand memref type should have identity affine map}} + xla_lhlo.reshape_memref_cast %buf(%shape) + : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) + -> memref<8xf32> + return +} + +// ----- + +// CHECK-LABEL: func @atan2_memrefs +func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @atan2_memrefs +func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { + "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () + return +} + +// ----- + +func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @bitcast_convert_memrefs +func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () + return +} + +// ----- + +func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () { + // expected-error@+1{{requires the same shape for all operands}} + "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @clz_memrefs +func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @expm1_memrefs +func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @expm1_memrefs +func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { + "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @floor_memrefs +func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + // expected-error@+1{{must be memref of floating-point values}} + "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @imag_memrefs +func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + return +} + +// ----- + +func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error@+1{{must be memref of complex-type values}} + "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @real_memrefs +func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) 
-> () { + "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + return +} + +// ----- + +func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error@+1{{must be memref of complex-type values}} + "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @is_finite_memrefs +func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { + "xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @log1p_memrefs +func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @log1p_memrefs +func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { + "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + return +} + +// ----- + +func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { + // expected-error@+1{{must be memref of floating-point or complex-type values}} + "xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @not_memrefs +func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @not_memrefs +func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { + "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () + return +} + +// ----- + +func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} + "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @popcnt_memrefs +func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} + "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @reduce_precision_memrefs +func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @round_memrefs +func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + // expected-error@+1{{must be memref of floating-point values}} + "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @shift_left_memrefs +func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, 
memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} + "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @shift_right_arithmetic_memrefs +func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} + "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @shift_right_logical_memrefs +func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { + "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + return +} + +// ----- + +func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { + // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} + "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @all_reduce_memrefs +func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { + "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = xla_hlo.maximum %lhs, %rhs : tensor + "xla_hlo.return"(%max) : (tensor) -> () + }) + { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () + + "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = xla_hlo.maximum %lhs, %rhs : tensor + "xla_hlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_id = { handle = 5 : i64, type = 2 : i64 }, + constrain_layout = true, + use_global_device_ids = true + }: (memref<10xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @collective_permute_memrefs +func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { + "xla_lhlo.collective_permute"(%arg0, %arg_out) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (memref<128x32xf32>, memref<128x32xf32>) -> () + + "xla_lhlo.collective_permute"(%arg0, %arg_out) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_id = { handle = 5 : i64, type = 2 : i64 } + } : (memref<128x32xf32>, memref<128x32xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @fft_memrefs +func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex>) -> () { + "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @batch_norm_grad_memrefs +func 
@batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, + %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, + %grad_offset: memref<8xf32>) -> () { + "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, + memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @batch_norm_inference_memrefs +func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { + "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @batch_norm_training_memrefs +func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, + %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, + %batch_var: memref<8xf32>) -> () { + "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @cholesky_memrefs +func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () { + "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @infeed_memrefs +func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { + "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @outfeed_memrefs +func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { + "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @replica_id_memrefs +func @replica_id_memrefs(%arg_out: memref) -> () { + "xla_lhlo.replica_id"(%arg_out) : (memref) -> () + return +} + +// ----- + +// CHECK-LABEL: func @triangular_solve_memrefs +func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { + "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} + : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @while_memrefs +func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { + "xla_lhlo.while"(%arg0, %arg_out) ( + { ^bb0(%arg: memref, %cond: memref): "xla_lhlo.terminator"() : () -> () }, + { ^bb0(%arg: memref, %body_out: memref): "xla_lhlo.terminator"() : () -> () } + ) : (memref, memref) -> () + return +} + +// ----- + +// CHECK-LABEL: func @while_memrefs +func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>) -> () { + "xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, 
%cond: memref): "xla_lhlo.terminator"() : () -> () }, + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () } + ) : (memref, memref<5xf32>, memref, memref<5xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @bitcast_memrefs +func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { + "xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @scatter_memrefs +func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>, + %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { + "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ + ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors + %add = xla_hlo.add %lhs, %rhs : tensor + "xla_hlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = { + update_window_dims = dense<[1]> : tensor<1xi64>, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + index_vector_dim = 1 : i64 + }, + indices_are_sorted = true, + unique_indices = true + } : (memref<200x100x300xf32>, memref<10x2xi32>, memref<10x300xf32>, memref<200x100x300xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @map_memrefs +func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { + "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + ^bb0(%a: tensor, %b: tensor): + %c = xla_hlo.add %a, %b : tensor + "xla_hlo.return"(%c) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> () + return +} + +// ----- + +func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () { + // expected-error@+1{{requires the same shape for all operands}} + "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + ^bb0(%a: tensor, %b: tensor): + %c = xla_hlo.add %a, %b : tensor + "xla_hlo.return"(%c) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @rng_get_and_update_state_memrefs +func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { + "xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sort_memrefs +func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, + %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { + "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): + %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () + return +} + +// ----- + +// CHECK-LABEL: func @sort_memrefs +func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, + %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { + "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): + %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () + return +} + +// ----- + +// CHECK-LABEL: func 
@sort_memrefs +func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, + %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { + "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): + %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> () + return +} diff --git a/tests/lower-complex.mlir b/tests/lower-complex.mlir new file mode 100644 index 0000000..696e225 --- /dev/null +++ b/tests/lower-complex.mlir @@ -0,0 +1,224 @@ +// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s + +// CHECK-LABEL: @add +func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 + %4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @add_unranked +func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3 + %4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @sub +func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 + %4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @sub_unranked +func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = 
xla_hlo.subtract %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3 + %4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL0]], [[VAL1]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @mul +func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] + %4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32> + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @mul_unranked +func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3 + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]] + %4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32> + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @div +func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: 
[[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. + // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] + %4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex>, tensor<2xcomplex>) -> (tensor<2xcomplex>) + + %5 = "xla_hlo.real"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<2xf32>, tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @div_unranked +func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + %3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3) + + // Compute the numerator's real component: + // numerator.real = lhs.real * rhs.real lhs.imag * rhs.imag + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]] + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]] + + // Compute the real valued denominator as rhs * con(rhs): + // denominator = rhs.real * rhs.real + rhs.imag * rhs.imag + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2 + // CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]] + // CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]] + + // Compute the numerator's imaginary component: + // numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag + // CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2 + // CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]] + // CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]] + + // Divide the numerator by the real valued denominator. 
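+  // Taken together, this is the usual identity
+  //   (a + bi) / (c + di) = ((a*c + b*d) + (b*c - a*d)i) / (c*c + d*d),
+  // with the negate of rhs.imag above folding the conjugate into the products.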
+ // CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]] + // CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]] + %4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex>, tensor<*xcomplex>) -> (tensor<*xcomplex>) + + %5 = "xla_hlo.real"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL10]], [[VAL11]] + return %5, %6 : tensor<*xf32>, tensor<*xf32> +} + +// CHECK-LABEL: @abs +func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0 + // CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1 + // CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]]) + %1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]] + return %2 : tensor<2xf32> +} + +// CHECK-LABEL: @exp +func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] + %1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex>) -> (tensor<2xcomplex>) + %2 = "xla_hlo.real"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + %3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex>) -> (tensor<2xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<2xf32>, tensor<2xf32> +} + +// CHECK-LABEL: @exp_unranked +func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex>) + + // CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0) + // CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1) + // CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1) + // CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]] + // CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]] + %1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex>) -> (tensor<*xcomplex>) + %2 = "xla_hlo.real"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + %3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex>) -> (tensor<*xf32>) + + // CHECK: return [[VAL3]], [[VAL4]] + return %2, %3 : tensor<*xf32>, tensor<*xf32> +} diff --git a/tests/lower-general-dot.mlir b/tests/lower-general-dot.mlir new file mode 100644 index 0000000..b54a0aa --- /dev/null +++ b/tests/lower-general-dot.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @testDebatch1 +func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> { + // CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32> + // CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + // CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32> + %0 = "xla_hlo.dot_general"(%arg0, %arg1) 
{dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32> + + return %0 : tensor<1x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @testDebatch2 +func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> { + // CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> + // CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32> + // CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32> + // CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32> + // CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32> + + %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32> + return %0 : tensor<3x1x1xf32> +} + +// ----- + +// CHECK-LABEL: @testBatchPassthrough +func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> { + // CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1) + %0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32> + return %0 : tensor<3x2x1xf32> +} + diff --git a/tests/materialize-broadcasts.mlir b/tests/materialize-broadcasts.mlir new file mode 100644 index 0000000..bfe1fe3 --- /dev/null +++ b/tests/materialize-broadcasts.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s + +// CHECK-LABEL: @clampBroadcast +// CHECK-SAME: (%[[MIN:.+]]: tensor, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor) +func @clampBroadcast(%min: tensor, %value: tensor<4xf32>, %max: tensor) -> tensor<4xf32> { + // CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor) -> tensor<4xf32> + // CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor, tensor<4xf32>, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> +} diff --git a/tests/ops.mlir b/tests/ops.mlir new file mode 100644 index 0000000..727e747 --- /dev/null +++ b/tests/ops.mlir @@ -0,0 +1,1132 @@ +// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s + +// Tests for types, ops with custom constraints, verifiers, 
printer or parser +// methods. + +// CHECK-LABEL: func @token_type() -> !xla_hlo.token +func @token_type() -> !xla_hlo.token + +// ----- + +// expected-error@+1 {{unknown xla_hlo type: foobar}} +func @invalid_type() -> !xla_hlo.foobar + +// ----- + +// CHECK-LABEL: func @alltoall +func @alltoall(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + return %0 : tensor<16x4xf32> +} + +// ----- + +// CHECK-LABEL: func @alltoall_unranked_input +func @alltoall_unranked_input(%data: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @alltoall_invalid_split_dim_size(%data: tensor<4x16xf32>) -> tensor<16x4xf32> { +// expected-error@+1 {{split dimension has size 16, expected to be a multiple of split_count 5}} + %0 = "xla_hlo.all_to_all"(%data) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 5 : i64, + replica_groups = dense<[[0, 1, 2, 3, 4]]> : tensor<1x5xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + return %0 : tensor<16x4xf32> +} + +// ----- + +// CHECK-LABEL: func @broadcast +func @broadcast(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_bad_sizes_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{broadcast_sizes has rank 2 instead of rank 1}} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[[1, 2]]> : tensor<1x2xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_bad_result_rank(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{result rank (3) does not match operand rank (1) plus size of broadcast_sizes (1)}} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_bad_first_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{result has shape [1, 3] instead of [2, 3]}} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x3xi32> + return %0 : tensor<1x3xi32> +} + +// ----- + +func @broadcast_bad_second_part_result_shape(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{result has shape [2, 1] instead of [2, 3]}} + %0 = "xla_hlo.broadcast"(%arg0) {broadcast_sizes = dense<[2]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<2x1xi32> + return %0 : tensor<2x1xi32> +} + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim +func @broadcast_in_dim(%arg0: tensor<1x2xi32>) -> tensor<1x2x2xi32> { + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x2xi32> + return %0 : tensor<1x2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @broadcast_in_dim_zero_rank +func @broadcast_in_dim_zero_rank(%arg0: tensor) -> tensor<1x2x3xi32> { + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[]> : 
tensor<0xi64>} : (tensor) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @dynamic_broadcast_in_dim +func @dynamic_broadcast_in_dim(%arg0: tensor, %shape: tensor<3xi64>) -> tensor { + %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xi64>) -> tensor + return %0 : tensor +} + +// ----- + +func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}} + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_in_dim_bad_dimension_size(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{broadcast_dimensions size (1) does not match operand rank (2)}} + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_in_dim_bad_rank_decrease(%arg0: tensor<1x2x3xi32>) -> tensor<3xi32> { + // expected-error@+1 {{result rank (1) is less than operand rank (3)}} + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0,1,2]> : tensor<3xi64>} : (tensor<1x2x3xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + +func @broadcast_in_dim_dimension_values_too_large(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{broadcast_dimensions contains invalid value 9 for result result with rank 3}} + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[9, 2]> : tensor<2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @broadcast_in_dim_bad_shape_mismatch(%arg0: tensor<3xi32>) -> tensor<1x2x3xi32> { + // expected-error@+1 {{size of operand dimension 0 (3) is not equal to 1 or size of result dimension 1 (2)}} + %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<3xi32>) -> tensor<1x2x3xi32> + return %0 : tensor<1x2x3xi32> +} + +// ----- + +func @case_mismatch_num_args(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects branch regions to have single argument, but found 2 for branch 1}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor, %arg1: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_num_results(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.copy"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1, %arg0) : (tensor, tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = 
"xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_arg_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{expects operand 2 to be of type 'tensor', but found 'tensor'}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2.0> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_mismatch_return_type(%index: tensor, %operand_1: tensor, %operand_2: tensor, %operand_3: tensor) -> tensor { + // expected-error@+1 {{branch 1 returned values do not match op result types}} + %0 = "xla_hlo.case"(%index, %operand_1, %operand_2, %operand_3) ( { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.negate"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = xla_hlo.constant dense<2> : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg0: tensor): + %1 = "xla_hlo.floor"(%arg0) : (tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + } + ) : (tensor, tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @case_empty_region(%index: tensor, %operand_1: tensor) -> () { + // expected-error@+1 {{cannot have empty regions}} + "xla_hlo.case"(%index, %operand_1) ( {} ) : (tensor, tensor) -> tensor + return +} + +// ----- + +// CHECK-LABEL: func @comp_eq +func @comp_eq(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + return %0 : tensor<3xi1> +} + +// ----- + +func @comp_bad_direction(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi1> { + // expected-error@+1 {{'comparison_direction' failed to satisfy constraint}} + %0 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "FOOBAR"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + return %0 : tensor<3xi1> +} + +// ----- + +func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + // expected-error@+1 {{duplicate sources not allowed}} + %0 = "xla_hlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [0, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + return %0 : tensor<128x32xf32> +} + +// ----- + +func @collective_permute_duplicate_targets(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + // expected-error@+1 {{duplicate targets not allowed}} + %0 = "xla_hlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 1]]> : tensor<3x2xi64> + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + return %0 : tensor<128x32xf32> +} + +// ----- + +func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + // expected-error@+1 {{expect source_target_pairs attribute to be of rank 2, but got rank 1}} + %0 = "xla_hlo.collective_permute"(%arg0) { + source_target_pairs = dense<[0, 1]> : tensor<2xi64> + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + return %0 : tensor<128x32xf32> +} + +// ----- 
+ +func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + // expected-error@+1 {{expect source_target_pairs attribute of shape (N, 2), but got (2, 3)}} + %0 = "xla_hlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi64> + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + return %0 : tensor<128x32xf32> +} + +// ----- + +// CHECK-LABEL: @concat_1D +func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + +func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> { + // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}} + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + +// CHECK-LABEL: @concat_1D_unranked +func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// ----- + +func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> { + // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32> + return %0 : tensor<3xi32> +} + +// ----- + +func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> { + // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}} + %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + +// CHECK-LABEL: func @clamp +func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { + %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @clamp_scalar +func @clamp_scalar(%arg0: tensor<1xi32>, %arg1: tensor) -> tensor<1xi32> { + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg1) : (tensor, tensor<1xi32>, tensor) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @clamp_invalid_clamp_element_type(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> tensor<1xi32> { + // expected-error@+1 {{'xla_hlo.clamp' op requires the same element type for all operands and results}} + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<1xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +func @clamp_invalid_clamp_shape(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<1xi32> { + // expected-error@+1 {{min shape [2] is not scalar and does not match operand shape [1]}} + %0 = "xla_hlo.clamp"(%arg1, %arg0, %arg0) : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @dot_vector +func @dot_vector(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) -> tensor { + %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor + return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @dot_matrix +func @dot_matrix(%arg0: tensor<2x2xi32>, 
%arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = "xla_hlo.dot"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %0: tensor<2x2xi32> +} + +// ----- + +// CHECK-LABEL: func @dot_precision_config +func @dot_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["HIGH", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %0: tensor<2x2xi32> +} + +// ----- + +func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) -> tensor<2x2xi32> { + // expected-error@+1 {{'precision_config' failed to satisfy constraint}} + %0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["FOO", "HIGHEST"]} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + return %0: tensor<2x2xi32> +} + +// ----- + +func @infeed_invalid_number_of_results(%token: !xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> { + // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} + %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, !xla_hlo.token, tensor> + return %0 : tuple>, !xla_hlo.token, tensor> +} + +// ----- + +func @infeed_non_token_second_result(%token: !xla_hlo.token) -> tuple>, tensor> { + // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} + %0 = "xla_hlo.infeed"(%token) {infeed_config = "foobar"} : (!xla_hlo.token) -> tuple>, tensor> + return %0 : tuple>, tensor> +} + +// ----- + +func @iota_scalar() -> tensor { + // expected-error@+1 {{does not support scalars}} + %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor + return %0 : tensor +} + +// ----- + +func @iota_invalid_iota_dimension() -> tensor<4xi32> { + // expected-error@+1 {{iota dimension cannot go beyond the output rank or be negative}} + %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + +func @map_mismatched_args(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // expected-error@+1 {{expects number of operands to match the arity of map computation, but got: 2 and 1}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg: tensor): + %1 = xla_hlo.add %arg, %arg {name = "add"} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +func @map_non_scalar_computation_operand(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{computation arguments must be 0-rank tensor, but got: arg #1 of type 'tensor<5xf32>'}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor<5xf32>): + %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @map_mismatch_operand_and_computation_args(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{element type of operands and computation arguments must match, but got: 'f32' and 'i32'}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, 
tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @map_invalid_number_of_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{computation must return single output, but got: 0}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor + "xla_hlo.return"() : () -> () + }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @main_non_scalar_computation_output(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{computation must return 0-rank tensor, but got: 'tensor<5xf32>'}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.constant {value = dense<2.0> : tensor} : tensor<5xf32> + "xla_hlo.return"(%1) : (tensor<5xf32>) -> () + }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @mismatch_computation_output_type(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{element type of result and computation output must match, but got: 'f32' and 'i32'}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.constant {value = dense<2> : tensor} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @map_invalid_dimension_numbers(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{requires monotonically increasing dimension numbers, but got: dense<[1, 0]> : tensor<2xi64>}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +func @map_mismatch_arguments_and_dimensions(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> { + // expected-error@+1 {{applied to a subset of dimensions currently not supported: operand dimensions = 2, requested map dimensions size = 3}} + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<4x5xf32>, tensor<4x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> +} + +// ----- + +// CHECK-LABEL: func @map_unranked +func @map_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.map"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = xla_hlo.add %arg2, %arg3 {name = "add"} : tensor + "xla_hlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @recv_invalid_number_of_results(%token: !xla_hlo.token) -> tuple, tensor, !xla_hlo.token> { + // expected-error@+1 {{result is expected to be a tuple of size 2, but got 3}} + %0 = "xla_hlo.recv"(%token) { + channel_id = { + handle = 5 : i64, + type = 3 : i64 // Host to device 
channel + }, + is_host_transfer = true + } : (!xla_hlo.token) -> tuple, tensor, !xla_hlo.token> + return %0 : tuple, tensor, !xla_hlo.token> +} + +// ----- + +func @recv_non_token_second_result(%token: !xla_hlo.token) -> tuple, tensor> { + // expected-error@+1 {{second element of result tuple is expected to be of token type, but got 'tensor'}} + %0 = "xla_hlo.recv"(%token) { + channel_id = { + handle = 5 : i64, + type = 3 : i64 // Host to device channel + }, + is_host_transfer = true + } : (!xla_hlo.token) -> tuple, tensor> + return %0 : tuple, tensor> +} + +// ----- + +func @rng_uniform_invalid_type(%mu: tensor>, %sigma: tensor) -> tensor<2x3x5xf32> { + %shape = xla_hlo.constant dense<[2, 3, 5]> : tensor<3xi64> + // expected-error@+1 {{but got 'tensor>'}} + %0 = "xla_hlo.rng_uniform"(%mu, %sigma, %shape) : (tensor>, tensor, tensor<3xi64>) -> tensor<2x3x5xf32> + return %0 : tensor<2x3x5xf32> +} + +// ----- + +// CHECK-LABEL: func @select +func @select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @select_scalar_pred +func @select_scalar_pred(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @select_cast_compatible_types +func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<*xi32>, %arg2: tensor<2x3xi32>) -> tensor<*xi32> { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<*xi32>, tensor<2x3xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// ----- + +func @select_cast_compatible_types(%arg0: tensor, %arg1: tensor<2x?xi32>, %arg2: tensor) -> tensor { + // TODO(lucyfox): Update once this is supported. 
+ // expected-error@+1 {{currently unsupported operand types: 'tensor<2x?xi32>' and 'tensor'}} + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x?xi32>, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @select_scalar_x_y +func @select_scalar_x_y(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +func @select_bad_pred_type(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + // expected-error@+1 {{must be tensor of pred (AKA boolean or 1-bit integer) values}} + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi32>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +func @select_bad_shape_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + // expected-error@+1 {{incompatible operand types: 'tensor<2x4xi32>' and 'tensor<2x3xi32>'}} + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x4xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +func @select_bad_element_type_mismatch(%arg0: tensor<3xi1>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + // expected-error@+1 {{incompatible operand types: 'tensor<2x3xf32>' and 'tensor<2x3xi32>'}} + %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<2x3xf32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @slice +func @slice(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_indices_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xi32> { + // expected-error@+1 {{failed to verify that all of {start_indices, limit_indices, strides} have same type}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 2, 3]> : tensor<3xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @slice_operand_result_mismatch(%arg0: tensor<3x4xi32>) -> tensor<1x4xf32> { + // expected-error@+1 {{requires the same element type for all operands and results}} + %0 = "xla_hlo.slice"(%arg0) {start_indices = dense<[1, 0]> : tensor<2xi64>, limit_indices = dense<[2, 4]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} : (tensor<3x4xi32>) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- + +// CHECK-LABEL: func @dynamic_slice +func @dynamic_slice(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @dynamic_slice_mismatch_indices(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // expected-error@+1 {{has mismatched number of slice sizes (1) and number of start indices (2)}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +// 
CHECK-LABEL: @dynamic_slice_different_indice_element_type +func @dynamic_slice_different_indice_element_type(%arg0: tensor<3x4xi32>, %arg1: tensor) -> tensor<1x4xi32> { + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[4]> : tensor<1xi64>} : (tensor<3x4xi32>, tensor) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +func @dynamic_slice_mismatch_element_types(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xf32> { + // expected-error@+1 {{failed to verify that all of {operand, result} have same element type}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xf32> + return %0 : tensor<1x4xf32> +} + +// ----- + +func @dynamic_slice_invalid_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // expected-error@+1 {{operand #1 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + %0 = "xla_hlo.dynamic-slice"(%arg0, %arg1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<2xi64>) -> tensor<1x4xi32> + return %0 : tensor<1x4xi32> +} + +// ----- + +// CHECK-LABEL: @dynamic_update_slice +func @dynamic_update_slice(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start1: tensor, %start2: tensor) -> tensor<3x4xi64> { + %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start1, %start2) : (tensor<3x4xi64>, tensor<2xi64>, tensor, tensor) -> tensor<3x4xi64> + return %0 : tensor<3x4xi64> +} + +// ----- + +func @dynamic_update_slice_invalid_start(%input: tensor<3x4xi64>, %update: tensor<2xi64>, %start: tensor<2xi64>) -> tensor<3x4xi64> { + // expected-error@+1 {{operand #2 must be a 0-dim integer tensor of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values, but got 'tensor<2xi64>'}} + %0 = "xla_hlo.dynamic-update-slice"(%input, %update, %start) : (tensor<3x4xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<3x4xi64> + return %0 : tensor<3x4xi64> +} + +// ----- + +// CHECK-LABEL: func @transpose +func @transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + return %0: tensor<2x1x4x3xi32> +} + +// ----- + +func @transpose_ranked(%arg0: tensor) -> tensor { + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor) -> tensor + return %0: tensor +} + +// ----- + +func @transpose_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<*xi32>) -> tensor<*xi32> + return %0: tensor<*xi32> +} + +// ----- + +func @transpose_bad_permutations_rank(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { + // expected-error@+1 {{permutation has rank 2 instead of rank 1}} + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[[1]]> : tensor<1x1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + return %0: tensor<2x1x4x3xi32> +} + +// ----- + +func @transpose_bad_permutations_size(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { + // expected-error@+1 {{operand rank (4) does not match permutation size (1)}} + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1]> : tensor<1xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + return %0: tensor<2x1x4x3xi32> +} + +// ----- + +func @transpose_operand_result_rank_mismatch(%arg0: tensor<1x2x3x4xi32>) -> 
tensor<2xi32> { + // expected-error@+1 {{result rank (1) does not match permutation size (4)}} + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2xi32> + return %0: tensor<2xi32> +} + +// ----- + +func @transpose_operand_result_permutation_mismatch(%arg0: tensor<1x?x3x?xi32>) -> tensor { + // expected-error@+1 {{result type tensor is incompatible with the expected type tensor}} + %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x?x3x?xi32>) -> tensor + return %0: tensor +} + +// ----- + +func @triangular_solve_unranked(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + +// ----- + +func @triangular_solve_rank_less_than_2(%arg0: tensor<4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { + // expected-error@+1 {{operand 'a' must have rank >= 2, but got 'tensor<4xf32>'}} + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> +} + +// ----- + +func @triangular_solve_unequal_minor_dims_a(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { + // expected-error@+1 {{two minor dimensions of operand 'a' must have equal size, but got 'tensor<4x3xf32>'}} + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x3xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> +} + +// ----- + +func @triangular_solve_unequal_rank(%arg0: tensor<10x4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x3xf32> { + // expected-error@+1 {{operands must have equal rank, but got 'tensor<10x4x4xf32>' and 'tensor<4x3xf32>'}} + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x4x4xf32>, tensor<4x3xf32>) -> tensor<4x3xf32> + return %0 : tensor<4x3xf32> +} + +// ----- + +func @triangular_solve_mismatch_shared_dim(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> { + // expected-error@+1 {{shared dimension of operands 'a' and 'b' does not match, but got 'tensor<4x4xf32>' and 'tensor<3x4xf32>'}} + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> +} + +// ----- + +func @triangular_solve_mismatch_leading_dims(%arg0: tensor<10x5x4x4xf32>, %arg1: tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> { + // expected-error@+1 {{leading batch dimensions of the operands must be same, but got 'tensor<10x5x4x4xf32>' and 'tensor<10x6x4x3xf32>'}} + %0 = "xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<10x5x4x4xf32>, tensor<10x6x4x3xf32>) -> tensor<10x6x4x3xf32> + return %0 : tensor<10x6x4x3xf32> +} + +// ----- + +func @triangular_solve_mismatch_result_and_b_type(%arg0: tensor<4x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<4x4xf32> { + // expected-error@+1 {{result and operand 'b' must have same shape, but got 'tensor<4x4xf32>' and 'tensor<4x3xf32>'}} + %0 = 
"xla_hlo.triangular_solve"(%arg0, %arg1) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (tensor<4x4xf32>, tensor<4x3xf32>) -> tensor<4x4xf32> + return %0 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @tuple +func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> { + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + return %0: tuple, tensor<1x2xf32>> +} + +// ----- + +func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { + // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor> + return %0 : tuple, tensor, tensor> +} + +// ----- + +func @tuple_type_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor> { + // expected-error@+1 {{has return type tuple, tensor>, but expected tuple, tensor>}} + %0 = "xla_hlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor> + return %0 : tuple, tensor> +} + +// ----- + +func @get_tuple_element(%arg0: tuple, tensor>) -> tensor { + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +func @get_tuple_element_token(%arg0: tuple, !xla_hlo.token>) -> !xla_hlo.token { + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !xla_hlo.token>) -> !xla_hlo.token + return %0 : !xla_hlo.token +} + +// ----- + +func @get_tuple_element_bad_type(%arg0: tuple, tensor>) -> tensor { + // expected-error@+1 {{has return type tensor, but expected tensor}} + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +func @get_tuple_element_index_out_of_bounds(%arg0: tuple, tensor>) -> tensor { + // expected-error@+1 {{index 2 is out of bounds of operand with size 2}} + %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @and_i32_type +func @and_i32_type(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = "xla_hlo.and"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- +// CHECK-LABEL: func @or_i1_type +func @or_i1_type(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + return %0 : tensor<4xi1> +} + +// ----- + +func @or_invalid_f32_type(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // expected-error@+1 {{but got 'tensor<4xf32>'}} + %0 = "xla_hlo.or"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// ----- + +func @floor_invalid_i32_type(%arg0: tensor<4xi32>) -> tensor<4xi32> { + // expected-error@+1 {{must be tensor of floating-point values, but got 'tensor<4xi32>'}} + %0 = "xla_hlo.floor"(%arg0) : (tensor<4xi32>) -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- + +// Verifiers HLO constant op custom printing and parsing. 
+// CHECK-LABEL: func @constants +func @constants() -> () { + // CHECK: xla_hlo.constant dense<0> : tensor + %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor) + + // CHECK: xla_hlo.constant {extra_attr = 3 : i32} dense<0> : tensor + %1 = "xla_hlo.constant"() {extra_attr = 3 : i32, value = dense<0> : tensor} : () -> (tensor) + return +} + +// ----- + +func @constant_invalid() -> () { + // expected-error@+1 {{op failed to verify that all of {value, output} have same type}} + %0 = "xla_hlo.constant"() {value = dense<0> : tensor} : () -> (tensor<*xi32>) + return +} + +// ----- + +func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // CHECK: xla_hlo.sort + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_no_operands() { + // expected-error @+1 {{op requires at least one input}} + %0 = "xla_hlo.sort"() ( { + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + %7 = "xla_hlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> + return +} + +// ----- + +func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op requires all inputs to have the same dimensions}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 10 : i64, is_stable = 
true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op comparator block should have 4 arguments}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { + // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} + %0 = "xla_hlo.sort"(%input0, %input1) ( { + ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): + %7 = "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%7) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + return +} + +// ----- + +// CHECK: func @dequantize +func @dequantize(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { + %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + return %0 : tensor<16x64xbf16> +} + +// ----- + +func @dequantize_wrong_shape(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { + // expected-error @+1 {{mismatched dimensions.}} + %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = true} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + return %0 : tensor<16x64xbf16> +} + +// ----- + +func @dequantize_wrong_size(%arg: tensor<16x16xi32>) -> tensor<16x16xbf16> { + // expected-error @+1 {{last dimension of output should be 4x of the input.}} + %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "MIN_COMBINED", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x16xbf16> + return %0 : tensor<16x16xbf16> +} + +// ----- + +func @dequantize_wrong_mode(%arg: tensor<16x16xi32>) -> tensor<16x64xbf16> { + // expected-error @+1 {{Dequantization mode. 
Only MIN_COMBINED is supported.}} + %0 = "xla_hlo.dequantize"(%arg) {min_range = -0.1 : f32, max_range = 0.1 : f32, mode = "hello", transpose_output = false} : (tensor<16x16xi32>) -> tensor<16x64xbf16> + return %0 : tensor<16x64xbf16> +} + +// ----- + +func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> { + // expected-error @+1 {{number of output elements (9) doesn't match expected number of elements (8)}} + %0 = "xla_hlo.reshape"(%operand) : (tensor<2x4xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of batching dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<2> : tensor<1xi64>, + rhs_batching_dimensions = dense<[]> : tensor<0xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +} + +// ----- + +func @dot_general(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{lhs and rhs should have the same number of contracting dimensions}} + %0 = "xla_hlo.dot_general"(%arg0, %arg1) { dot_dimension_numbers = { + lhs_batching_dimensions = dense<0> : tensor<1xi64>, + lhs_contracting_dimensions = dense<[]> : tensor<0xi64>, + rhs_batching_dimensions = dense<0> : tensor<1xi64>, + rhs_contracting_dimensions = dense<1> : tensor<1xi64> + }} : (tensor, tensor) -> tensor + return +} + +// ----- + +func @compatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + return %0 : tensor +} + +// ----- + +func @incompatible_shapes(%arg0: tensor, %shape: tensor<2xindex>) -> tensor { + // expected-error @+1 {{output should have a rank equal to the number of elements in output_shape}} + %0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor<2xindex>) -> tensor + return %0 : tensor +} diff --git a/tests/reduce.mlir b/tests/reduce.mlir new file mode 100644 index 0000000..4566b63 --- /dev/null +++ b/tests/reduce.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @noop +// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) +// CHECK: return %[[ARG0]] +func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = xla_hlo.constant dense<0.000000e+00> : tensor + %2 = "xla_hlo.reduce"(%arg0, %0) ( { + ^bb0(%arg1: tensor, %arg2: tensor): + %4 = xla_hlo.add %arg1, %arg2 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} diff --git a/tests/reshape.mlir b/tests/reshape.mlir new file mode 100644 index 0000000..c9e6c5a --- /dev/null +++ b/tests/reshape.mlir @@ -0,0 +1,149 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @const_fold_collapse_to_scalar +func @const_fold_collapse_to_scalar() -> tensor { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor + %cst = xla_hlo.constant dense<42> : tensor<1x1xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor + // CHECK-NEXT: return [[CST]] + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @const_fold_collapse_to_tensor +func @const_fold_collapse_to_tensor() -> tensor<2xi32> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : 
tensor<2xi32> + %cst = xla_hlo.constant dense<42> : tensor<1x2xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_expand +func @const_fold_expand() -> tensor<1xi32> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32> + %cst = xla_hlo.constant dense<42> : tensor + %0 = "xla_hlo.reshape"(%cst) : (tensor) -> tensor<1xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_nontrivial +func @const_fold_nontrivial() -> tensor<16xi64> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> + %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> + %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xi64> +} + +// ----- + +// CHECK-LABEL: func @const_fold_flatten +func @const_fold_flatten() -> tensor<16xi64> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64> + %cst = xla_hlo.constant dense<42> : tensor<4x4xi64> + %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xi64> +} + +// ----- + +// CHECK-LABEL: func @const_fold_6 +func @const_fold_6() -> tensor<6xi32> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<6xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_same_shape +func @const_fold_same_shape() -> tensor<2x3xi32> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6] + // CHECK-SAME: ]> : tensor<2x3xi32> + %cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> + %0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @const_fold_float +func @const_fold_float() -> tensor<16xf64> { + // CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64> + %cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64> + %0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64> + // CHECK-NEXT: return [[CST]] + return %0 : tensor<16xf64> +} + +// ----- + +// CHECK-LABEL: func @non_const_same_shape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: return [[ARG]] + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) { + // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32> + // CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_unused_parent +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_unused_parent(%arg : 
tensor<2x3xi32>) -> tensor<6xi32> { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32> + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32> + // CHECK-NEXT: return [[RES]] + return %1 : tensor<6xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> { + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return [[ARG]] + return %1 : tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @non_const_many_chained_reshapes +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> { + // CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> + %0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32> + %1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32> + %2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32> + %3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32> + %4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32> + // CHECK-NEXT: return [[RES]] + return %4 : tensor<1x2x4x3xi32> +} diff --git a/tests/reverse.mlir b/tests/reverse.mlir new file mode 100644 index 0000000..9a1c113 --- /dev/null +++ b/tests/reverse.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @noop +// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) +func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: return %[[ARG0]] + return %0 : tensor<1x2xf32> +} diff --git a/tests/sink-constants-to-control-flow.mlir b/tests/sink-constants-to-control-flow.mlir new file mode 100644 index 0000000..35682a5 --- /dev/null +++ b/tests/sink-constants-to-control-flow.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s + +// Tests sinking constants to a while loop. + +// CHECK-LABEL: func @sink_const_to_while +func @sink_const_to_while(%arg0: tensor) -> tensor { + // CHECK-NEXT: xla_hlo.while + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.while"(%arg0) ( { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1A:.+]]: tensor + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + // CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]]) + %1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + "xla_hlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + // CHECK: %[[ARG1B:.+]]: tensor + // CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + // CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]] + %2 = xla_hlo.add %arg1, %arg1 : tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]] + %3 = xla_hlo.add %c1, %2 : tensor + // CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]] + %4 = xla_hlo.add %c1, %3 : tensor + "xla_hlo.return"(%4) : (tensor) -> () + }) : (tensor) -> tensor + return %0 : tensor +} + +// Tests sinking constants to a conditional op. 
+ +// CHECK-LABEL: func @sink_const_to_conditional +func @sink_const_to_conditional(%arg0: tensor) -> tensor { + %c0 = xla_hlo.constant dense<1> : tensor + %c1 = xla_hlo.constant dense<2> : tensor + %0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor + %1 = "xla_hlo.tuple"(%arg0) : (tensor) -> tuple> + // CHECK: xla_hlo.if + %2 = "xla_hlo.if"(%0, %1, %1) ( { + ^bb0(%arg1: tuple>): + // CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor + %3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]], + %4 = xla_hlo.add %c0, %3 : tensor + %5 = "xla_hlo.tuple"(%4) : (tensor) -> tuple> + "xla_hlo.return"(%5) : (tuple>) -> () + }, { + ^bb0(%arg1: tuple>): + // CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor + %6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple>) -> tensor + // CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], + %7 = xla_hlo.add %c1, %6 : tensor + %8 = "xla_hlo.tuple"(%7) : (tensor) -> tuple> + "xla_hlo.return"(%8) : (tuple>) -> () + }) : (tensor, tuple>, tuple>) -> tuple> + %9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tensor + return %9 : tensor +} diff --git a/tests/transpose.mlir b/tests/transpose.mlir new file mode 100644 index 0000000..ce11a2a --- /dev/null +++ b/tests/transpose.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @remove_noop +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> + // CHECK-NEXT: return [[ARG]] + return %0 : tensor<2x3x9x5xi32> +} + +// ----- + +// CHECK-LABEL: func @keep_real_transpose +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { + // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> + return %0 : tensor<3x2x5x9xi32> +} + +// ----- + +// CHECK-LABEL: func @keep_same_shape_real_transpose +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { + // CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) + %0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> + return %0 : tensor<4x4xi32> +} diff --git a/tests/tuple.mlir b/tests/tuple.mlir new file mode 100644 index 0000000..bf68009 --- /dev/null +++ b/tests/tuple.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s + +// CHECK-LABEL: func @fold_access +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func @fold_access(%arg : tensor) -> tensor { + // CHECK-NEXT: return [[ARG]] + %tuple = "xla_hlo.tuple"(%arg) : (tensor) -> tuple> + %element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple>) -> tensor + return %element : tensor +} diff --git a/tests/unfuse_batch_norm.mlir b/tests/unfuse_batch_norm.mlir new file mode 100644 index 0000000..cefceeb --- /dev/null +++ b/tests/unfuse_batch_norm.mlir @@ -0,0 +1,135 @@ +// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s + +// CHECK-LABEL: @batchNormInference_2D_inner_features +// CHECK-SAME: 
%[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func @batchNormInference_2D_inner_features( + %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<4x256xf32>) { + // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : + (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: return %[[RESULT]] + return %0 : tensor<4x256xf32> +} + +// ----- +// CHECK-LABEL: @batchNormInference_4D_middle_features +// Just validate that one of the broadcasts happens correctly and rely on +// the verifier to enforce the rest. 
+// CHECK-SAME: %[[X:[^:]+]] +// CHECK-SAME: %[[SCALE:[^:]+]] +// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> +func @batchNormInference_4D_middle_features( + %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<3x4x256x6xf32>) { + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : + (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<3x4x256x6xf32> + return %0 : tensor<3x4x256x6xf32> +} + +// ----- +// CHECK-LABEL: @batchNormInference_f64 +// Validate that epsilon is properly promoted to f64 +// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +func @batchNormInference_f64( + %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, + %mean: tensor<256xf64>, %variance: tensor<256xf64>) + -> (tensor<4x256xf64>) { + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.0 : f32, feature_index = 1 : i64} : + (tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>, + tensor<256xf64>) -> tensor<4x256xf64> + return %0 : tensor<4x256xf64> +} + +// ----- +// CHECK-LABEL: @batchNormInference_f16 +// Validate that epsilon is properly promoted to f64 +// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor +func @batchNormInference_f16( + %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, + %mean: tensor<256xf16>, %variance: tensor<256xf16>) + -> (tensor<4x256xf16>) { + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.0 : f32, feature_index = 1 : i64} : + (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, + tensor<256xf16>) -> tensor<4x256xf16> + return %0 : tensor<4x256xf16> +} + +// ----- +// Validate that epsilon is properly promoted to f64 +func @batchNormInference_f16_overflow( + %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, + %mean: tensor<256xf16>, %variance: tensor<256xf16>) + -> (tensor<4x256xf16>) { + // expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}} + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 0.00000001 : f32, feature_index = 1 : i64} : + (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, + tensor<256xf16>) -> tensor<4x256xf16> + return %0 : tensor<4x256xf16> +} + +// ----- +// CHECK-LABEL: @batchNormInference_dynamic_shape +// Validate that dynamic shapes are handled properly. 
+// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func @batchNormInference_dynamic_shape( + %x: tensor, %scale: tensor, %offset: tensor, + %mean: tensor, %variance: tensor) + -> tensor { + // CHECK-DAG: %[[C0:.*]] = constant 0 : index + // CHECK-DAG: %[[C1:.*]] = constant 1 : index + // CHECK-DAG: %[[C2:.*]] = constant 2 : index + // CHECK-DAG: %[[C3:.*]] = constant 3 : index + // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor + // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor + // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> + // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor + // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor) -> tensor + // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor + // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor + // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor + // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor + // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<4xindex>) -> tensor + // CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor + // CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor + // CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor + // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor + %0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 0.001 : f32, feature_index = 1 : i64} : + (tensor, tensor, tensor, tensor, + tensor) -> tensor + return %0 : tensor +} diff --git a/tests/xla-hlo-fusion.mlir b/tests/xla-hlo-fusion.mlir new file mode 100644 index 0000000..7061bc2 --- /dev/null +++ b/tests/xla-hlo-fusion.mlir @@ -0,0 +1,97 @@ +// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s + +// CHECK-LABEL: func @multi_outputs_same +func @multi_outputs_same(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "xla_hlo.add"(%1, %1) : (tensor, tensor) -> tensor + // CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion" + // CHECK-NEXT: 
+  // CHECK-NEXT: xla_hlo.subtract
+  // CHECK-NEXT: xla_hlo.add
+  // CHECK-NEXT: xla_hlo.return
+  return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_outputs_same_2
+func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  %4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
+  // CHECK-NEXT: xla_hlo.abs
+  // CHECK-NEXT: xla_hlo.abs
+  // CHECK-NEXT: xla_hlo.add
+  // CHECK-NEXT: xla_hlo.abs
+  // CHECK-NEXT: xla_hlo.abs
+  // CHECK-NEXT: xla_hlo.return
+  return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_outputs_not_sure_same
+func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK-NOT: xla_hlo.fusion
+  %1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce
+func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  // CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
+  // CHECK-NEXT: xla_hlo.add
+  // CHECK-NEXT: xla_hlo.subtract
+  // CHECK-NEXT: xla_hlo.return
+  // Currently we do not support fusing arguments and ops without a direct
+  // producer-consumer relationship. Thus the reduce op should not be fused
+  // with the above two ops.
+
+  %2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  %3 = "xla_hlo.reduce"(%arg0, %2) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%4) : (tensor<f32>) -> ()
+  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  %4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  // The above two ops should not be fused since a reduce op cannot be
+  // fused with its consumer.
+  // CHECK-NOT: xla_hlo.fusion
+
+  return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @reduce_2
+func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
+  %0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+  %1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  %2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
+  %3 = "xla_hlo.reduce"(%1, %2) ( {
+  ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+    %4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "xla_hlo.return"(%4) : (tensor<f32>) -> ()
+  }) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
+  // CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
+  // CHECK-NEXT: xla_hlo.add
+  // CHECK-NEXT: xla_hlo.subtract
+  // CHECK-NEXT: xla_hlo.constant
+  // CHECK-NEXT: xla_hlo.reduce
+  // CHECK: xla_hlo.return
+
+  // The following op should not be fused with the above ops since a reduce op
+  // cannot be fused with its consumer.
+  // CHECK-NOT: xla_hlo.fusion
+  %4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
+  return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
+}
diff --git a/tests/xla-transform-unranked-hlo.mlir b/tests/xla-transform-unranked-hlo.mlir
new file mode 100644
index 0000000..eb98789
--- /dev/null
+++ b/tests/xla-transform-unranked-hlo.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
+
+// Check the validity of expected IR.
+// CHECK-LABEL: @sqr_transform_result
+func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
+
+  // Flatten operand shape.
+  %shape = shape.shape_of %a : tensor<*xf32>
+  %num_elements = shape.num_elements %shape
+  %num_elements_as_index = shape.size_to_index %num_elements
+  %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
+  %flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
+      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+
+  // Apply operation.
+  %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
+
+  // Restore original shape.
+  %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
+  %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
+      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+
+  return %b : tensor<*xf32>
+}
+
+// -----
+
+// Check transformation of unranked code.
+// CHECK-LABEL: @sqrt
+// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
+func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32>
+  // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
+  // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
+  // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK-NEXT: return %[[B]] : tensor<*xf32>
+  %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
+  return %b : tensor<*xf32>
+}
+
+// -----
+
+// Not transformed when ranked.
+// CHECK-LABEL: @sqrt_ranked
+// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
+func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
+  // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+  // CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
+  %b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
+  return %b : tensor<3x?xf32>
+}
+
+// -----
+
+// Not transformed when statically shaped.
+// CHECK-LABEL: @sqrt_static
+// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
+func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
+  // CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  // CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
+  %b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+  return %b : tensor<2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @add_unranked
+// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
+  // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
+  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
+  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
+  // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
+  // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
+  // CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
+  // CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
+  // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
+  // CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
+  // CHECK: return %[[RESULT]] : tensor<*xf32>
+  %result = xla_hlo.add %a, %b : tensor<*xf32>
+  return %result : tensor<*xf32>
+}