// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// Elementwise binary ops on two constant operands. Each test expects the
// canonicalized function to contain the single folded constant named on its
// check line.

// Integer vector + integer vector.
// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
  %1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
  // CHECK: mhlo.constant dense<[6, 8, 10, 12]>
  %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Splat + splat; the expected result is again written as a splat.
// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<1> : tensor<4xi64>
  %1 = mhlo.constant dense<5> : tensor<4xi64>
  // CHECK: mhlo.constant dense<6>
  %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Same as add_fold, with f64 elements.
// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
  %0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
  %1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
  // CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
  %2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
  return %2 : tensor<4xf64>
}

// Splat 5 - splat 1.
// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<5> : tensor<4xi64>
  %1 = mhlo.constant dense<1> : tensor<4xi64>
  // CHECK: mhlo.constant dense<4>
  %2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Splat 5 * splat 3.
// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<5> : tensor<4xi64>
  %1 = mhlo.constant dense<3> : tensor<4xi64>
  // CHECK: mhlo.constant dense<15>
  %2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Splat 7 / splat 5 on integers; the expected quotient is 1.
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<7> : tensor<4xi64>
  %1 = mhlo.constant dense<5> : tensor<4xi64>
  // CHECK: mhlo.constant dense<1>
  %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2
: tensor<4xi64>
}

// Element-by-element f64 division of two constant vectors.
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
  %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
  %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
  // CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
  %2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
  return %2 : tensor<4xf64>
}

// Integer remainder of two constant vectors.
// CHECK-LABEL: remainder_fold_int
func @remainder_fold_int() -> tensor<4xi32> {
  %0 = mhlo.constant dense<[5, 66, 5, 1]> : tensor<4xi32>
  %1 = mhlo.constant dense<[3, 5, 1, 2]> : tensor<4xi32>
  // CHECK: mhlo.constant dense<[2, 1, 0, 1]>
  %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>)
  return %2 : tensor<4xi32>
}

// f32 remainder of two constant vectors.
// CHECK-LABEL: remainder_fold_float
func @remainder_fold_float() -> tensor<4xf32> {
  %0 = mhlo.constant dense<[7.0, 66.5, 5.0, 3.1]> : tensor<4xf32>
  %1 = mhlo.constant dense<[3.0, 5.0, 1.0, 2.6]> : tensor<4xf32>
  // CHECK: mhlo.constant dense<[1.000000e+00, 1.500000e+00, 0.000000e+00, 5.000000e-01]>
  %2 = "mhlo.remainder"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>)
  return %2 : tensor<4xf32>
}

// Maximum of splat 7 and splat 5.
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<7> : tensor<4xi64>
  %1 = mhlo.constant dense<5> : tensor<4xi64>
  // CHECK: mhlo.constant dense<7>
  %2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Element-by-element f64 maximum.
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
  %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
  %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
  // CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
  %2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
  return %2 : tensor<4xf64>
}

// Minimum of splat 7 and splat -5.
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
  %0 = mhlo.constant dense<7> : tensor<4xi64>
  %1 =
mhlo.constant dense<-5> : tensor<4xi64>
  // CHECK: mhlo.constant dense<-5>
  %2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
  return %2 : tensor<4xi64>
}

// Element-by-element f64 minimum of two constant vectors.
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
  %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
  %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
  // CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
  %2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
  return %2 : tensor<4xf64>
}

// A concatenate with a single operand is expected to be replaced by that
// operand.
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
  // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
  %0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>

  // CHECK: return [[ARG]]
  return %0 : tensor<4xi32>
}

// The zero-extent operand is expected to be dropped, leaving only the first
// argument.
// CHECK-LABEL: concatenate_remove_operand
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
  // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
  // CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>

  // CHECK: return [[ARG0]]
  return %0 : tensor<4xi32>
}

// Concatenating only zero-extent operands is expected to yield a constant;
// the three tests below cover i1, i32 and f32 element types.
// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
  // CHECK: mhlo.constant
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>

  return %0 : tensor<0xi1>
}

// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
  // CHECK: mhlo.constant
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>

  return %0 : tensor<0xi32>
}

// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
  // CHECK: mhlo.constant
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>

  return %0 : tensor<0xf32>
}

// Concatenation of constants is expected to fold to one constant. The value
// capture is written with a space before "=" so that the captured SSA name
// carries no trailing blank, matching concatenate_const_1D_float.
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
  // CHECK: [[VAL:%.+]] = mhlo.constant dense<[0, 1, 2, 3]>
  %0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
  %1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
  %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>

  // CHECK: return [[VAL]]
  return %2 : tensor<4xi32>
}

// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
  // CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
  %0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
  %1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
  %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>

  // CHECK: return [[VAL]]
  return %2 : tensor<4xf32>
}

// Two 1x2 rows stacked along dimension 0.
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
  // CHECK: [[VAL:%.+]] = mhlo.constant dense<[
  // CHECK-SAME: [0, 1], [2, 3]
  // CHECK-SAME: ]>
  %0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
  %1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
  %2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

  // CHECK: return [[VAL]]
  return %2 : tensor<2x2xi32>
}

// Two 2x1 columns joined along dimension 1.
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
  // CHECK: [[VAL:%.+]] = mhlo.constant dense<[
  // CHECK-SAME: [0, 2], [1, 3]
  // CHECK-SAME: ]>
  %0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
  %1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
  %2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>

  // CHECK: return [[VAL]]
  return %2 : tensor<2x2xi32>
}

// CHECK-LABEL: constant_like_constant
func
@constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> {
  // CHECK: mhlo.constant dense<3.200000e+00>
  %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32>
  return %0 : tensor<3x4xf32>
}

// With an unranked operand the check line expects chlo.constant_like to
// still be present after canonicalization.
// CHECK-LABEL: constant_like_constant_dynamic
func @constant_like_constant_dynamic(%arg0: tensor<*xi32>) -> tensor<*xf32> {
  // CHECK: chlo.constant_like
  %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<*xi32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// NOTE(review): in the three dynamic-slice tests below several types read as
// a bare "tensor" with no <...> parameters, which is not parseable IR. This
// looks like an extraction artifact (angle-bracket contents dropped), not
// intended input -- restore the parameters from the authoritative copy of
// this test before running it. The tokens are reproduced here unchanged.

// Start indices are function arguments, so the op is expected to remain.
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> {
  // CHECK: "mhlo.dynamic-slice"
  %1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32>
  return %1 : tensor<1x4xi32>
}

// Constant start index: expected to become a static mhlo.slice. The attribute
// lines use the -SAME directive so that FileCheck really evaluates them; the
// earlier "-DAG-SAME" spelling is not a directive FileCheck recognises, so
// those lines were skipped. They are listed in the order the attributes are
// printed elsewhere in this file (limit_indices, start_indices, strides).
// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
  // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
  // CHECK-SAME: limit_indices = dense<3> : tensor<1xi64>
  // CHECK-SAME: start_indices = dense<1> : tensor<1xi64>
  // CHECK-SAME: strides = dense<1> : tensor<1xi64>}
  // CHECK: return %[[RESULT]] : tensor<2xi32>
  %0 = mhlo.constant dense<1> : tensor
  %1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor) -> tensor<2xi32>
  return %1 : tensor<2xi32>
}

// Same rewrite with two constant start indices; attribute lines use the
// -SAME directive for the reason given on the previous test.
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor {
  // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
  // CHECK-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
  // CHECK-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
  // CHECK-SAME: strides = dense<1> : tensor<2xi64>
  // CHECK: return %[[RESULT]] : tensor
  %0 = mhlo.constant dense<1> : tensor
  %1 = mhlo.constant dense<0> : tensor
  %2 = "mhlo.dynamic-slice"(%arg0, %0, %1)
{slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor, tensor, tensor) -> tensor
  return %2 : tensor
}
// NOTE(review): the types in the fragment above read as a bare "tensor" with
// no <...> parameters; that looks like an extraction artifact -- confirm
// against the authoritative copy of this test.

// A slice covering the whole operand is expected to be replaced by the
// operand itself.
// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
  %0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)

  // CHECK-NEXT: return [[ARG]]
  return %0 : tensor<2x2xi64>
}

// Slices of constants are expected to fold to the selected elements.

// Elements [1, 3) of a 4-element i64 vector.
// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
  %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
  // CHECK: mhlo.constant dense<[7, 9]>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
  return %1 : tensor<2xi64>
}

// Same slice with f32 elements.
// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
  %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
  // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
  return %1 : tensor<2xf32>
}

// Stride 2 starting at index 1.
// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
  %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
  // CHECK: mhlo.constant dense<[7, 10]>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
  return %1 : tensor<2xi64>
}

// 2x2 window out of a 4x4 constant.
// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
  %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
  // CHECK-NEXT: mhlo.constant dense<[
  // CHECK-SAME: [6, 7],
  // CHECK-SAME: [10, 11]
  // CHECK-SAME: ]>
  %1 = "mhlo.slice"(%0) { limit_indices =
dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
  return %1 : tensor<2x2xi64>
}

// First row of the 4x4 constant.
// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
  %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
  // CHECK-NEXT: mhlo.constant dense<[
  // CHECK-SAME: [0, 1, 2, 3]
  // CHECK-SAME: ]>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
  return %1 : tensor<1x4xi64>
}

// Column 2 of the 4x4 constant.
// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
  %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
  // CHECK-NEXT: mhlo.constant dense<[
  // CHECK-SAME: [2], [6], [10], [14]
  // CHECK-SAME: ]>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
  return %1 : tensor<4x1xi64>
}

// Unranked operand: the slice is expected to be printed back unchanged.
// CHECK-LABEL: slice_unknown_shape
func @slice_unknown_shape(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

// Slicing exactly the first operand of a concatenate is expected to return
// that operand directly.
// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
  %1 = "mhlo.slice"(%0) {
limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)

  // CHECK: return %arg0
  return %1 : tensor<1x5xf32>
}

// Slicing exactly the second operand of a concatenate is expected to return
// that operand directly.
// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)

  // CHECK: return %arg1
  return %1 : tensor<1x5xf32>
}

// The slice lies inside the second operand but is narrower than it, so a
// slice of %arg1 with rebased indices is expected.
// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
  %0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
  // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)

  // CHECK: return [[SLICE]]
  return %1 : tensor<1x4xf32>
}

// The slice selects one row from the middle operand of a three-way
// concatenate; a slice of %arg1 alone is expected.
// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
  %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
  // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  %1 = "mhlo.slice"(%0) {
limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)

  // CHECK: return [[SLICE]]
  return %1 : tensor<1x5xf32>
}

// The slice spans the last two operands; a concatenate of just those two
// followed by a rebased slice is expected.
// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
  // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
  %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>

  // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
  %1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)

  // CHECK: return [[SLICE]]
  return %1 : tensor<2x5xf32>
}

// Same shape in and out with dimensions [0, 1, 2]: expected to be replaced
// by the operand.
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
  // CHECK: return %arg0
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
  return %0 : tensor<2x3x4xf32>
}

// Operand is 1x2 and result is 2x2, so the op is expected to remain.
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
  // CHECK: mhlo.broadcast_in_dim
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

// Dimensions [1, 0] permute the operand, so the op is expected to remain.
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: mhlo.broadcast_in_dim
  %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions
= dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

// NOTE(review): in several tests below types read as a bare "tensor" or
// "complex" with no <...> parameters, which is not parseable IR. This looks
// like an extraction artifact (angle-bracket contents dropped) -- restore the
// parameters from the authoritative copy of this test before running it. The
// tokens are reproduced here unchanged.

// The result type is fully static, so a static broadcast_in_dim is expected.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
  // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
  %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
  // CHECK: return %[[RESULT]] : tensor<5x4xf32>
  return %0 : tensor<5x4xf32>
}

// Broadcasting a value to its own shape_of result is expected to return the
// value itself.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_1
func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor) -> tensor {
  // CHECK-SAME: %[[ARG:.*]]: tensor
  %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex>
  %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor
  // CHECK: return %[[ARG]] : tensor
  return %2 : tensor
}

// Same as above, with the shape routed through shape.to_extent_tensor.
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2
func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> tensor {
  // CHECK-SAME: %[[ARG:.*]]: tensor
  %0 = shape.shape_of %arg0 : tensor -> !shape.shape
  %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex>
  %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor
  // CHECK: return %[[ARG]] : tensor
  return %2 : tensor
}

// Broadcast of a splat constant is expected to fold to a splat constant of
// the result type. The check lines for these two tests follow the function.
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
  %cst = mhlo.constant dense<0.000000e+00> : tensor
  %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor) -> tensor<1x64x224x224xf32>
  return %b : tensor<1x64x224x224xf32>
}
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32>
// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32>

// CHECK-LABEL: func @broadcast_in_dim_constant_fold
func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> {
  %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32>
  %b = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32>
  return %b : tensor<1x64x4x4xf32>
}
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32>

// real/imag of a freshly built complex are expected to return the original
// operands.
// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
  %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>)
  %1 = "mhlo.real"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>)
  %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>)
  // CHECK: return %arg0, %arg1
  return %1, %2 : tensor<4xf32>, tensor<4xf32>
}

// Rebuilding a complex from its own real and imaginary parts is expected to
// return the original value.
// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> {
  %0 = "mhlo.real"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>)
  %1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>)
  %2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex>
  // CHECK: return %arg0
  return %2 : tensor<4xcomplex>
}

// dynamic_iota with a static result type is expected to become mhlo.iota.
// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
  // CHECK: [[RESULT:%.*]] = "mhlo.iota"
  // CHECK: return [[RESULT]]
  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
  return %0 : tensor<4xi32>
}

// Iota along the static dimension 0: expected as a 1-D static iota plus a
// dynamic broadcast.
// CHECK-LABEL: @dynamic_iota_broadcast
func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
  // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>

  // CHECK: return [[BROADCAST]]
  return %0 : tensor<5x?xi32>
}

// Iota along the dynamic dimension 1: the expected output slices that extent
// out of the shape operand, builds a 1-D dynamic_iota and broadcasts it.
// CHECK-LABEL: @dynamic_iota_broadcast_second
func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
  // CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
  // CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
  // CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
  // CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor
  // CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor<2xindex>) -> tensor<5x?xi32>
  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>

  // CHECK: return [[BROADCAST]]
  return %0 : tensor<5x?xi32>
}

// Iota dimension has extent 1, so a zero constant plus broadcast is expected.
// CHECK-LABEL: @dynamic_iota_constant
func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> {
  // CHECK: [[IOTA:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
  // CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xi32>, tensor<2xindex>) -> tensor<1x?xi32>
  %0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<1x?xi32>

  // CHECK: return [[BROADCAST]]
  return %0 : tensor<1x?xi32>
}

// A one-element iota is expected to fold to a zero constant.
// CHECK-LABEL: @iota_constant
func @iota_constant() -> tensor<1xi32> {
  // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32>
  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1xi32>

  // CHECK: return [[CONST]] : tensor<1xi32>
  return %0 : tensor<1xi32>
}

// Iota dimension has extent 1 in a 2-D result: a splat zero constant is
// expected.
// CHECK-LABEL: @iota_constant_multi
func @iota_constant_multi() -> tensor<1x4xi32> {
  // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32>
  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<1x4xi32>

  // CHECK: return [[CONST]] : tensor<1x4xi32>
  return %0 : tensor<1x4xi32>
}

// A 4-element 1-D iota is expected to stay an iota.
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
  // CHECK: [[RESULT:%.*]] = "mhlo.iota"
  // CHECK: return [[RESULT]]
  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
  return %0 : tensor<4xi32>
}

// 2-D iota along dimension 0: expected as a 1-D iota plus broadcast_in_dim.
// CHECK-LABEL: @iota_broadcast
func @iota_broadcast() -> tensor<5x4xi32> {
  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
  // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>

  return %0 : tensor<5x4xi32>
}

// 2-D iota along dimension 1. The label names this function in full; it was
// previously a copy of the label above and matched here only as a prefix.
// CHECK-LABEL: @iota_broadcast_second
func @iota_broadcast_second() -> tensor<5x4xi32> {
  // CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
  // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>

  return %0 : tensor<5x4xi32>
}

// unary_einsum is expected to become a binary einsum whose first operand is
// a constant one and whose config gains a leading ",".
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor
  // CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
  %0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

// mhlo.copy is expected to be replaced by its operand.
// CHECK-LABEL: func @fold_copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
  // CHECK: return [[ARG]]
  %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}

// The result type is static, so a static mhlo.reshape is expected.
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
  // CHECK: mhlo.reshape
  %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
  return %0 : tensor<4x1xf32>
}

// NOTE(review): from here on most types read as a bare "tensor" and the
// expected constants as a bare "dense", with no <...> parameters. That is
// not parseable IR, and the "dense : tensor" check lines cannot match real
// output. It looks like an extraction artifact (angle-bracket contents
// dropped) -- restore the parameters from the authoritative copy of this
// test before running it. The tokens are reproduced here unchanged.

// The while result is unused, but its body contains an outfeed, so the loop
// is expected to be kept.
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor {
  // CHECK: mhlo.while
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor):
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor
    "mhlo.return"(%1) : (tensor) -> ()
  }, {
  ^bb0(%arg1: tensor):
    %1 = "mhlo.create_token"() : () -> !mhlo.token
    // Side-effecting op outfeed present inside while.
    %2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !mhlo.token) -> !mhlo.token
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) : (tensor) -> tensor

  return %arg0 : tensor
}

// Same loop without the outfeed: it is expected to be removed.
// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor) -> tensor {
  // CHECK-NOT: mhlo.while
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor):
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor
    "mhlo.return"(%1) : (tensor) -> ()
  }, {
  ^bb0(%arg1: tensor):
    %1 = "mhlo.create_token"() : () -> !mhlo.token
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) : (tensor) -> tensor

  return %arg0 : tensor
}

// Comparing a value with itself: each direction is expected to fold to a
// constant.

// CHECK-LABEL: fold_compare_same_eq
func @fold_compare_same_eq(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// CHECK-LABEL: fold_compare_same_le
func @fold_compare_same_le(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// CHECK-LABEL: fold_compare_same_ge
func @fold_compare_same_ge(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// CHECK-LABEL: fold_compare_same_ne
func @fold_compare_same_ne(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// CHECK-LABEL: fold_compare_same_lt
func @fold_compare_same_lt(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// CHECK-LABEL: fold_compare_same_gt
func @fold_compare_same_gt(%arg0: tensor) -> tensor {
  // CHECK: %0 = mhlo.constant dense : tensor
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor, tensor) -> tensor
  return %0 : tensor
}

// Comparing two constants: expected to fold to a constant.

// CHECK-LABEL: fold_compare_false_eq
func @fold_compare_false_eq() -> tensor {
  %0 = mhlo.constant dense<0> : tensor
  %1 = mhlo.constant dense<1> : tensor
  // CHECK: %0 = mhlo.constant dense : tensor
  %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor
  return %2 : tensor
}

// CHECK-LABEL: fold_compare_true_eq
func @fold_compare_true_eq() -> tensor {
  %0 = mhlo.constant dense<1> : tensor
  %1 = mhlo.constant dense<1> : tensor
  // CHECK: %0 = mhlo.constant dense : tensor
  %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor
  return %2 : tensor
}

// CHECK-LABEL: fold_compare_false_eq_float
func @fold_compare_false_eq_float() -> tensor {
  %0 = mhlo.constant dense<0.> : tensor
  %1 = mhlo.constant dense<1.> : tensor
  // CHECK: %0 = mhlo.constant dense : tensor
  %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor
  return %2 : tensor
}

// CHECK-LABEL: fold_compare_true_eq_float
func @fold_compare_true_eq_float() -> tensor {
  %0 = mhlo.constant dense<1.> : tensor
  %1 = mhlo.constant dense<1.> : tensor
// CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "EQ"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_ne func @fold_compare_false_ne() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<1> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_ne func @fold_compare_true_ne() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<0> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_ne_float func @fold_compare_false_ne_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_ne_float func @fold_compare_true_ne_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "NE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_lt func @fold_compare_false_lt() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<1> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_lt func @fold_compare_true_lt() -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<1> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %2 
: tensor } // CHECK-LABEL: fold_compare_false_lt_float func @fold_compare_false_lt_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_lt_float func @fold_compare_true_lt_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_le func @fold_compare_false_le() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<0> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_le func @fold_compare_true_le() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<1> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_le_float func @fold_compare_false_le_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor %1 = mhlo.constant dense<0.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_le_float func @fold_compare_true_le_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "LE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_gt func @fold_compare_false_gt() -> tensor { %0 = mhlo.constant 
dense<0> : tensor %1 = mhlo.constant dense<0> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_gt func @fold_compare_true_gt() -> tensor { %0 = mhlo.constant dense<1> : tensor %1 = mhlo.constant dense<0> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_gt_float func @fold_compare_false_gt_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor %1 = mhlo.constant dense<0.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_gt_float func @fold_compare_true_gt_float() -> tensor { %0 = mhlo.constant dense<1.> : tensor %1 = mhlo.constant dense<0.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_ge func @fold_compare_false_ge() -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<1> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_ge func @fold_compare_true_ge() -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<0> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_false_ge_float func @fold_compare_false_ge_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor %1 = mhlo.constant dense<1.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) 
{comparison_direction = "GE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: fold_compare_true_ge_float func @fold_compare_true_ge_float() -> tensor { %0 = mhlo.constant dense<0.> : tensor %1 = mhlo.constant dense<0.> : tensor // CHECK: %0 = mhlo.constant dense : tensor %2 = "mhlo.compare"(%0, %1) {comparison_direction = "GE"} : (tensor, tensor) -> tensor return %2 : tensor } // CHECK-LABEL: unpack_repack_same_tuple // CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token %2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> // CHECK: return [[ARG0]] return %3 : tuple, !mhlo.token, tensor> } // CHECK-LABEL: unpack_repack_same_tuple_single_element // CHECK-SAME: ([[ARG0:%.*]]: tuple>) func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> // CHECK: return [[ARG0]] return %3 : tuple> } // CHECK-LABEL: func @erase_dead_lhlo_constant func @erase_dead_lhlo_constant() { %M = alloc() : memref<256x1024xf32> // CHECK-NEXT: return "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () dealloc %M : memref<256x1024xf32> return } // A negative test for dead lhlo constant op erasure. 
// NOTE(review): several types in this section read as a bare `tensor` with no
// `<...>` contents (select operands, scatter region arguments, lmhlo.constant
// values, ...). That looks like lost angle-bracket text rather than intended
// IR - confirm against version control. Those tokens are left untouched here.
//
// Expectations in this section bind SSA names through FileCheck captures
// instead of spelling out printer-assigned names such as %0 / %arg0, so the
// tests do not depend on how the printer numbers values.

// Neither lmhlo.constant may be erased: the first writes into a function
// argument and the second writes into the buffer that is returned.
// CHECK-LABEL: func @erase_dead_lhlo_constant_negative
func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> {
  // CHECK-NEXT: lmhlo.constant
  "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<4xf32>) -> ()
  // CHECK-NEXT: alloc
  // CHECK-NEXT: lmhlo.constant
  %N = alloc() : memref<256x1024xf32>
  "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> ()
  return %N : memref<256x1024xf32>
}

// get_dimension_size on a statically shaped operand folds to a constant
// (dimension 2 of 1x128x512 is 512).
// CHECK-LABEL: func @fold_get_dimension_size
func @fold_get_dimension_size(%I : tensor<1x128x512xf32>) -> tensor {
  %size = "mhlo.get_dimension_size"(%I) {dimension = 2 : i32} : (tensor<1x128x512xf32>) -> tensor
  return %size : tensor
  // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor
  // CHECK-NEXT: return %[[C]]
}

// select(pred, x, x) folds to x.
// CHECK-LABEL: func @fold_select_same
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor {
  %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor
  // CHECK: return %[[ARG0]]
  return %1 : tensor
}

// select with a constant dense<1> predicate folds to its first value operand.
// CHECK-LABEL: func @fold_select_first
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor {
  %0 = mhlo.constant dense<1> : tensor
  %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor
  // CHECK: return %[[ARG0]]
  return %1 : tensor
}

// select with a constant dense<0> predicate folds to its second value operand.
// CHECK-LABEL: func @fold_select_second
// CHECK-SAME: ({{.*}}, %[[ARG1:[a-zA-Z0-9_]+]]:
func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor {
  %0 = mhlo.constant dense<0> : tensor
  %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor
  // CHECK: return %[[ARG1]]
  return %1 : tensor
}

// Splat dense<1> predicate on a vector select also folds to the first operand.
// CHECK-LABEL: func @fold_select_vector
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> {
  %0 = mhlo.constant dense<1> : tensor<4xi1>
  %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xf32>
}

// A gather with constant start indices [1, 2] mapped to operand dims [0, 2]
// becomes a static slice starting at [1, 0, 2].
// CHECK-LABEL: gather_to_slice
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> {
  %0 = constant dense<[1, 2]> : tensor<2xi32>
  %1 = "mhlo.gather"(%arg0, %0) {
    dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>,
                         index_vector_dim = 0 : i64,
                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
                         start_index_map = dense<[0, 2]> : tensor<2xi64>},
    indices_are_sorted = false,
    slice_sizes = dense<[3, 6, 5]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6x5xf32>
  return %1 : tensor<3x6x5xf32>
  // CHECK: %[[RET:.*]] = "mhlo.slice"(%[[ARG0]]) {limit_indices = dense<[4, 6, 7]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x5xf32>
  // CHECK: return %[[RET]] : tensor<3x6x5xf32>
}

// Same rewrite with a single scalar start index mapped to operand dim 2.
// CHECK-LABEL: gather_scalar_index_to_slice
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @gather_scalar_index_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<5x6x4xf32> {
  %0 = constant dense<1> : tensor
  %1 = "mhlo.gather"(%arg0, %0) {
    dimension_numbers = {collapsed_slice_dims = dense<> : tensor<0xi64>,
                         index_vector_dim = 0 : i64,
                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
                         start_index_map = dense<[2]> : tensor<1xi64>},
    indices_are_sorted = false,
    slice_sizes = dense<[5, 6, 4]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor) -> tensor<5x6x4xf32>
  return %1 : tensor<5x6x4xf32>
  // CHECK: %[[RET:.*]] = "mhlo.slice"(%[[ARG0]]) {limit_indices = dense<[5, 6, 5]> : tensor<3xi64>, start_indices = dense<[0, 0, 1]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<5x6x4xf32>
  // CHECK: return %[[RET]] : tensor<5x6x4xf32>
}

// With a collapsed slice dim the gather becomes a slice followed by a reshape
// that drops the unit dimension.
// CHECK-LABEL: gather_to_slice_reshape
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @gather_to_slice_reshape(%arg0: tensor<5x6x7xf32>) -> tensor<3x6xf32> {
  %0 = constant dense<[1, 2]> : tensor<2xi32>
  %1 = "mhlo.gather"(%arg0, %0) {
    dimension_numbers = {collapsed_slice_dims = dense<[2]> : tensor<1xi64>,
                         index_vector_dim = 0 : i64,
                         offset_dims = dense<[0, 1, 2]> : tensor<3xi64>,
                         start_index_map = dense<[0, 2]> : tensor<2xi64>},
    indices_are_sorted = false,
    slice_sizes = dense<[3, 6, 1]> : tensor<3xi64>} : (tensor<5x6x7xf32>, tensor<2xi32>) -> tensor<3x6xf32>
  return %1 : tensor<3x6xf32>
  // CHECK: %[[V0:.*]] = "mhlo.slice"(%[[ARG0]]) {limit_indices = dense<[4, 6, 3]> : tensor<3xi64>, start_indices = dense<[1, 0, 2]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} : (tensor<5x6x7xf32>) -> tensor<3x6x1xf32>
  // CHECK: %[[V1:.*]] = "mhlo.reshape"(%[[V0]]) : (tensor<3x6x1xf32>) -> tensor<3x6xf32>
  // CHECK: return %[[V1]] : tensor<3x6xf32>
}

// and(x, x) folds to x.
// CHECK-LABEL: func @fold_and_same
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %0 : tensor<4xi32>
}

// and(all-ones, x) folds to x.
// CHECK-LABEL: func @fold_and_ones
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<-1> : tensor<4xi32>
  %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xi32>
}

// and(zeros, x) folds to the zero constant.
// CHECK-LABEL: func @fold_and_zeros
func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<0> : tensor<4xi32>
  %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<4xi32>
  // CHECK: return %[[ZERO]]
  return %1 : tensor<4xi32>
}

// Negative case: and with an arbitrary constant (7) is not folded.
// CHECK-LABEL: func @fold_and_constant
func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<7> : tensor<4xi32>
  // CHECK: mhlo.and
  %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// Element-wise constant folding: [0,1,6,3] & [7,3,7,2] = [0,1,6,2].
// CHECK-LABEL: func @fold_and_constants
func @fold_and_constants() -> tensor<4xi32> {
  %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
  %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
  %2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[RES:.*]] = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32>
  // CHECK: return %[[RES]]
  return %2 : tensor<4xi32>
}

// or(x, x) folds to x.
// CHECK-LABEL: func @fold_or_same
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %0 : tensor<4xi32>
}

// or(all-ones, x) folds to the all-ones constant.
// CHECK-LABEL: func @fold_or_ones
func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<-1> : tensor<4xi32>
  %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[ONES:.*]] = mhlo.constant dense<-1> : tensor<4xi32>
  // CHECK: return %[[ONES]]
  return %1 : tensor<4xi32>
}

// or(zeros, x) folds to x.
// CHECK-LABEL: func @fold_or_zeros
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<0> : tensor<4xi32>
  %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xi32>
}

// Negative case: or with an arbitrary constant (7) is not folded.
// CHECK-LABEL: func @fold_or_constant
func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<7> : tensor<4xi32>
  // CHECK: mhlo.or
  %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// or(x, zeros) folds to x (constant on the right-hand side).
// CHECK-LABEL: func @fold_or_zeros_right
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<0> : tensor<4xi32>
  %1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xi32>
}

// Element-wise constant folding: [0,1,6,3] | [7,3,7,2] = [7,3,7,3].
// CHECK-LABEL: func @fold_or_zeros_constants
func @fold_or_zeros_constants() -> tensor<4xi32> {
  %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
  %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
  %2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[RES:.*]] = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32>
  // CHECK: return %[[RES]]
  return %2 : tensor<4xi32>
}

// xor(x, x) folds to a zero constant.
// CHECK-LABEL: func @fold_xor_same
func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<4xi32>
  // CHECK: return %[[ZERO]]
  return %0 : tensor<4xi32>
}

// Negative case: xor(all-ones, x) is kept.
// CHECK-LABEL: func @fold_xor_ones_left
func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<-1> : tensor<4xi32>
  // CHECK: mhlo.xor
  %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// Negative case: xor(x, all-ones) is kept.
// CHECK-LABEL: func @fold_xor_ones_right
func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<-1> : tensor<4xi32>
  // CHECK: mhlo.xor
  %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// xor(zeros, x) folds to x.
// CHECK-LABEL: func @fold_xor_zeros_left
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<0> : tensor<4xi32>
  %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xi32>
}

// xor(x, zeros) folds to x.
// CHECK-LABEL: func @fold_xor_zeros_right
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9_]+]]:
func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> {
  %0 = mhlo.constant dense<0> : tensor<4xi32>
  %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: return %[[ARG0]]
  return %1 : tensor<4xi32>
}

// Element-wise constant folding: [0,1,6,3] ^ [7,3,7,2] = [7,2,1,1].
// CHECK-LABEL: func @fold_xor_zeros_constants
func @fold_xor_zeros_constants() -> tensor<4xi32> {
  %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32>
  %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32>
  %2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK: %[[RES:.*]] = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32>
  // CHECK: return %[[RES]]
  return %2 : tensor<4xi32>
}

// negate of an integer constant folds element-wise.
// CHECK-LABEL: func @fold_negate_int
func @fold_negate_int() -> tensor<4xi32> {
  %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32>
  // CHECK: mhlo.constant dense<[0, -1, -6, 3]>
  %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32>
  return %1 : tensor<4xi32>
}

// negate of a float constant folds element-wise; note the expected -0.0 for
// the zero element.
// CHECK-LABEL: func @fold_negate_float
func @fold_negate_float() -> tensor<4xf32> {
  %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32>
  // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]>
  %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> tensor<4xf32>
  return %1 : tensor<4xf32>
}

// sqrt of a splat f32 constant folds to a constant; no sqrt op remains.
// CHECK-LABEL: func @fold_sqrt_f32_constants
func @fold_sqrt_f32_constants() -> tensor<4xf32> {
  %0 = mhlo.constant dense<1.0> : tensor<4xf32>
  %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32>
  // CHECK-NOT: mhlo.sqrt
  return %1 : tensor<4xf32>
}

// sqrt of a non-splat f64 constant folds element-wise.
// CHECK-LABEL: func @fold_sqrt_f64_constants
func @fold_sqrt_f64_constants() -> tensor<4xf64> {
  %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64>
  %1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64>
  // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64>
  // CHECK-NOT: mhlo.sqrt
  return %1 : tensor<4xf64>
}

// Negative case: sqrt of a negative constant is left in place.
// CHECK-LABEL: func @not_fold_sqrt_neg_constants
func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
  %0 = mhlo.constant dense<-1.0> : tensor<4xf32>
  %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32>
  // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32>
  // CHECK: mhlo.sqrt
  return %1 : tensor<4xf32>
}

// ---- mhlo.scatter with all-constant inputs ----------------------------------
// Each test below supplies constant operand, indices and updates and states the
// folded constant it expects.

// Row-wise overwrite: rows 0 and 2 are replaced by the update rows.
// CHECK-LABEL: @tensor_flow_scatter_v1_update
func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[0, 2]> : tensor<2xi32>
  %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// Column-wise overwrite: columns 0 and 2 are replaced.
// CHECK-LABEL: @tensor_flow_scatter_v2_update
func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[0, 2]> : tensor<2xi32>
  %2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<1> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
        update_window_dims = dense<[0]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// Row-wise accumulate: the update computation adds into rows 0 and 2.
// CHECK-LABEL: @tensor_flow_scatter_add
func @tensor_flow_scatter_add() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[0, 2]> : tensor<2xi32>
  %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor)
    "mhlo.return"(%4) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// Both indices target row 1, so both update rows accumulate into it
// (4+10+70, 5+20+80, 6+30+90).
// CHECK-LABEL: @tensor_flow_scatter_repeated
func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[1, 1]> : tensor<2xi32>
  %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor)
    "mhlo.return"(%4) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// Two batch dimensions of scatter indices, accumulating into columns.
// CHECK-LABEL: @tensor_flow_scatter_multiple_batch
func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32>
  %2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor)
    "mhlo.return"(%4) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 2 : i64,
        inserted_window_dims = dense<1> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// N-d scatter: each index row addresses operand dims (0, 1) and overwrites the
// innermost vector at [0, 0] and [1, 0].
// CHECK-LABEL: @tensor_flow_scatter_nd
func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> {
  %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
  %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
  %2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
  return %3 : tensor<3x3x2xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3]
  // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6]
  // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
  // CHECK-SAME: ]> : tensor<3x3x2xi32>
}

// Same data as the N-d test but with index_vector_dim = 0, so index vectors
// are read down the columns of %1: (0, 1) and (0, 0).
// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector
func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> {
  %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
  %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
  %2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 0 : i64,
        inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
  return %3 : tensor<3x3x2xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3]
  // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6]
  // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
  // CHECK-SAME: ]> : tensor<3x3x2xi32>
}

// Batched 1x1 window updates with no inserted window dims; index vectors are
// the columns of %1: (2, 1) and (1, 1).
// CHECK-LABEL: @scatter_batch_dus
func @scatter_batch_dus() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32>
  %2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 0 : i64,
        inserted_window_dims = dense<> : tensor<0xi64>,
        scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
        update_window_dims = dense<[1, 2]> : tensor<2xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: mhlo.constant dense<[
  // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9]
  // CHECK-SAME: ]> : tensor<3x3xi32>
}

// No update window dims: every update element is a scalar added at its index
// (index 1 receives 20 and 40).
// CHECK-LABEL: @scatter_no_update_window_dim
func @scatter_no_update_window_dim() -> tensor<3xi32> {
  %0 = constant dense<[0, 1, 2]> : tensor<3xi32>
  %1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32>
  %2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor)
    "mhlo.return"(%4) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 2 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<> : tensor<0xi64>
      },
      unique_indices = false
  } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32>
  return %3 : tensor<3xi32>
  // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32>
}

// Negative case: an index of -1 is present, and the scatter op is expected to
// remain next to its untouched operand constant.
// CHECK-LABEL: @scatter_negative_index
func @scatter_negative_index() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[0, -1]> : tensor<2xi32>
  %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: constant dense<[
  // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
  // CHECK-SAME: ]> : tensor<3x3xi32>
  // CHECK: "mhlo.scatter"
}

// Negative case: index 5 is outside the 3-row operand, and the scatter op is
// expected to remain.
// CHECK-LABEL: @scatter_out_of_bound
func @scatter_out_of_bound() -> tensor<3x3xi32> {
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
  %1 = constant dense<[1, 5]> : tensor<2xi32>
  %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
  %3 = "mhlo.scatter"(%0, %1, %2) ( {
  ^bb0(%arg0: tensor, %arg1: tensor):
    "mhlo.return"(%arg1) : (tensor) -> ()
  }) {indices_are_sorted = false,
      scatter_dimension_numbers = {
        index_vector_dim = 1 : i64,
        inserted_window_dims = dense<0> : tensor<1xi64>,
        scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
        update_window_dims = dense<[1]> : tensor<1xi64>
      },
      unique_indices = false
  } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
  return %3 : tensor<3x3xi32>
  // CHECK: constant dense<[
  // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
  // CHECK-SAME: ]> : tensor<3x3xi32>
  // CHECK: "mhlo.scatter"
}