Move HLO tests to `third_party/tensorflow/compiler/mlir/hlo/`

Also add a localized `mlir-hlo-opt` binary for testing tensorflow/compiler/mlir/hlo/...;
this directory is intended to be self-contained and to depend only on MLIR.
(A minimal usage sketch follows the commit metadata below.)

PiperOrigin-RevId: 319878984
Author: Mehdi Amini, 2020-07-06 23:28:26 +00:00 (committed by Mehdi Amini)
parent 31dc1b21eb
commit fa057cc0bc
33 changed files with 7293 additions and 0 deletions
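As a rough illustration of how the relocated tests drive the new binary, a minimal test file
would look like the sketch below. The RUN line and dialect ops mirror the tests added in this
commit; the function @example_no_fold itself is made up for illustration and is not part of
the change.

// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @example_no_fold
func @example_no_fold(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  // Neither operand is a constant, so canonicalization leaves the add in place.
  // CHECK: xla_hlo.add %arg0, %arg0
  %0 = xla_hlo.add %arg0, %arg0 : tensor<4xf32>
  return %0 : tensor<4xf32>
}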

tests/canonicalize.mlir (new file, 457 lines)
@@ -0,0 +1,457 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = xla_hlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[6, 8, 10, 12]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<1> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<6>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "xla_hlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<1> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<4>
%2 = "xla_hlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<5> : tensor<4xi64>
%1 = xla_hlo.constant dense<3> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<15>
%2 = "xla_hlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<1>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "xla_hlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<7>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "xla_hlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = xla_hlo.constant dense<7> : tensor<4xi64>
%1 = xla_hlo.constant dense<-5> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<-5>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = xla_hlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = xla_hlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: xla_hlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "xla_hlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "xla_hlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_remove_operand
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
return %0 : tensor<0xi1>
}
// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
return %0 : tensor<0xi32>
}
// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: xla_hlo.constant
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
return %0 : tensor<0xf32>
}
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0, 1, 2, 3]>
%0 = xla_hlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = xla_hlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
%0 = xla_hlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = xla_hlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
}
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = xla_hlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]] = xla_hlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = xla_hlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = xla_hlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "xla_hlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "xla_hlo.dynamic-slice"
%1 = "xla_hlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = "xla_hlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "xla_hlo.slice"(%arg0)
// CHECK-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = xla_hlo.constant dense<1> : tensor<i64>
%1 = xla_hlo.constant dense<0> : tensor<i64>
%2 = "xla_hlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}
// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "xla_hlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 9]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = xla_hlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: xla_hlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}
// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = xla_hlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: xla_hlo.constant dense<[7, 10]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}
// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = xla_hlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: xla_hlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}
// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
}
// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "xla_hlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "xla_hlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "xla_hlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "xla_hlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: xla_hlo.broadcast_in_dim
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "xla_hlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "xla_hlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "xla_hlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "xla_hlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "xla_hlo.iota"
// CHECK: return [[RESULT]]
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = xla_hlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "xla_hlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "xla_hlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @fold_copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "xla_hlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: xla_hlo.reshape
%0 = "xla_hlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
// The side-effecting outfeed op inside the while body prevents the while from being DCE'd.
%2 = "xla_hlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !xla_hlo.token) -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
}
// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: xla_hlo.while
%0 = "xla_hlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.create_token"() : () -> !xla_hlo.token
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
}


@@ -0,0 +1,56 @@
// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so
// only test reification on an exemplar op.
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>,
// CHECK-SAME: %[[ARG1:.+]]: tensor<?xf32>
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
}
// -----
// CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>>
}
// -----
// CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
}
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x3xf32>) -> tensor<?x3xf32>
return %1 : tensor<?x3xf32>
}


@@ -0,0 +1,239 @@
// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @dynamicBroadcast
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
// CHECK-LABEL: @dynamicBroadcastComplex
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
return %0 : tensor<?x?xcomplex<f32>>
}
// -----
// CHECK-LABEL: @dynamicBroadcastCompare
// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xf32>
func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}
// -----
// Verifies that valid broadcast_dimensions are accepted.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// Verifies that valid broadcast_dimensions are accepted.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// Verifies that invalid broadcast dimensions are rejected.
func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// Verifies that invalid broadcast dimensions are rejected.
func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// -----
// Note that broadcast_add is used as a proxy for all of the template
// expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// -----
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// -----
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
// -----
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// -----
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// -----
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

tests/concatenate.mlir (new file, 9 lines)
@@ -0,0 +1,9 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<1x2xf32>
}

tests/convert.mlir (new file, 225 lines)
@@ -0,0 +1,225 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// -----
// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i16>
}
// -----
// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "xla_hlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "xla_hlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "xla_hlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<2x3xf32>
}
// -----
// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_negative_int_float
func @const_negative_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<-4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_int_bf16
func @const_int_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = xla_hlo.constant dense<4> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
// -----
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i16>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i16>
}
// -----
// CHECK-LABEL: func @const_int_narrowing
func @const_int_narrowing() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<i64>
%0 = "xla_hlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<-42> : tensor<i64>
%cst = xla_hlo.constant dense<-42> : tensor<i32>
%0 = "xla_hlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_float_narrowing
func @const_float_narrowing() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = xla_hlo.constant dense<4.2> : tensor<f64>
%0 = "xla_hlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_f32_bf16
func @const_f32_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = xla_hlo.constant dense<42.0> : tensor<f32>
%0 = "xla_hlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
// -----
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.187500e+00> : tensor<f64>
%cst = xla_hlo.constant dense<4.2> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f64>
}
// -----
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i64>
%cst = xla_hlo.constant dense<42.0> : tensor<bf16>
%0 = "xla_hlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_high_rank_tensor
func @const_high_rank_tensor() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "xla_hlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}


@@ -0,0 +1,488 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// -----
// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
// BOTH-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
return
}
// -----
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_pred = tensor_load %pred : memref<2x2xi1>
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @compare
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
// -----
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
// -----
func @external_func() -> tensor<3xi64>
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// BOTH-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func()
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// BOTH: %[[C2:.*]] = constant 2 : index
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[C1__:.*]] = constant 1 : index
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// BOTH: %[[C0___:.*]] = constant 0 : index
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[C2_:.*]] = constant 2 : index
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// BOTH: %[[C1___:.*]] = constant 1 : index
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back, to avoid the tensor_store being rewritten into a
// copy into the pre-allocated output argument.
return
}
// -----
// BOTH-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
%tensor_real = tensor_load %real : memref<2x2xf32>
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
// -----
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
// -----
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// BOTH-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// Dynamic shape binary element-wise operation.
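// The lowering materializes the result shape at runtime from the operand (dim + tensor_from_elements),
// allocates a buffer of that shape, and passes it as the xla_lhlo.add output.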
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// Dynamic shape unary element-wise operation.
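// Same runtime shape materialization as above, with the allocated buffer passed as the xla_lhlo.tanh output.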
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// BOTH-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
// -----
// BOTH-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
window_strides = dense<[2, 1]> : tensor<2xi64>
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
return %out : tensor<3x5x5x4xf32>
}


@ -0,0 +1,559 @@
// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_add
func @float_add(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: integer_add
func @integer_add(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: addi
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_mul
func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: mulf
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @integer_mul
func @integer_mul(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: muli
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_remainder
func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: remf
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @integer_remainder
func @integer_remainder(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: remi_signed
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%tensor_result = "xla_hlo.rsqrt"(%operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: linalg.generic
// CHECK: rsqrt
return %tensor_result : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_sub
func @float_sub(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: subf
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @integer_sub
func @integer_sub(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: subi
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_abs
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: absf
%0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_exp
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: exp
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_log
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: log
%0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_ceil
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ceilf
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_neg
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: negf
%0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_tanh
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: tanh
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @integer_and
func @integer_and(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: and
%0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @float_cos
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: cos
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @float_sin
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: sin
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
return %0 : tensor<2x4x8xf32>
}
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
// -----
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.select"(%pred, %lhs, %rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
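// broadcast_in_dim: each operand dimension is indexed by the output dimension listed in
// broadcast_dimensions; the size-1 operand dimension is pinned to constant 0 in the operand map.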
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
return %0 : tensor<7x10x6x4x5xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(
%operand: tensor<1xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<1xf32>) -> tensor<1x5xf32>
return %0 : tensor<1x5xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[]> : tensor<0xi64>}
: (tensor<f32>) -> tensor<7x10x6xf32>
return %0 : tensor<7x10x6xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d0, d3, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
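// Reshapes lower to linalg.tensor_reshape; the reassociation maps group the higher-rank dimensions
// that collapse into (or expand from) each lower-rank dimension.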
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.minimum"(%lhs, %rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32):
// CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = "xla_hlo.maximum"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
func @reshape_collapse_single_dim
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_collapse_single_dim
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// -----
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK-LABEL: func @reshape_collapse
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// -----
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK-LABEL: func @reshape_expand
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
// -----
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @reshape_single_expand
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
// -----
func @reshape_multiple_collapse
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>
// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-LABEL: func @reshape_multiple_collapse
// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]]
// -----
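// xla_hlo.convert lowers to the std cast matching the element types: sitofp, zexti, trunci,
// fpext, fptrunc, or fptosi in the cases below.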
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16):
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
// -----
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64
// CHECK-NEXT: linalg.yield %[[RESULT]] : f64
// -----
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64):
// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "xla_hlo.reverse"(%input) {
dimensions = dense<1> : tensor<1xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

28
tests/inlining.mlir Normal file
@ -0,0 +1,28 @@
// RUN: mlir-hlo-opt %s -inline | FileCheck %s
// Test case: Basic test of inlining into xla_hlo.while.
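// After inlining, the call to @callee is gone and its xla_hlo.exponential body appears directly
// inside the body region of the while.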
// CHECK-LABEL: func @caller
// CHECK: "xla_hlo.while"{{.*}}( {
// CHECK: }, {
// CHECK: "xla_hlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee
func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
%0 = "xla_hlo.while"(%arg0) ( {
^entry(%unused: tensor<f32>):
"xla_hlo.return"(%pred) : (tensor<i1>) -> ()
}, {
^entry(%0: tensor<f32>):
%1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
"xla_hlo.return"(%1) : (tensor<f32>) -> ()
} ) : (tensor<f32>) -> (tensor<f32>)
return %0 : tensor<f32>
}
func @callee(%arg0: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

@ -0,0 +1,146 @@
// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
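// xla_hlo.while lowers to a CFG loop: the condition region becomes a block that extracts the i1
// from the tensor and cond_br's to the body or the exit block; the body branches back to the condition.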
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> {
//CHECK: br ^bb1(%arg0 : tensor<i64>)
//CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
//CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
//CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
//CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
//CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
//CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
//CHECK: br ^bb1([[VAL4]] : tensor<i64>)
//CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
%0 = "xla_hlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return [[VAL5]]
return %0 : tensor<i64>
}
// CHECK-LABEL: func @conditional
func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
%cst = constant dense<1.000000e+01> : tensor<f32>
// CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
%0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
// CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
%1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
// CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL3]] : tensor<f32>)
%2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}, {
^bb0(%arg1: tensor<f32>):
// CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
// CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^bb3([[VAL5]] : tensor<f32>)
%2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
// CHECK: return [[VAL6]] : tensor<f32>
return %1 : tensor<f32>
}
// CHECK-LABEL: func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %2 = extract_element %1[] : tensor<i1>
// CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
// CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
// CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
// CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
// CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
// CHECK: ^[[EXIT]](%6: tensor<i64>):
// CHECK: return %6 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
%1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
br ^body_succ(%arg1: tensor<i64>)
^body_succ(%0: tensor<i64>):
%1 = xla_hlo.add %0, %0 : tensor<i64>
"xla_hlo.return"(%1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// CHECK-LABEL: func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
// CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
// CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
// CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
// CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
// CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
// CHECK: br ^[[COND_ENTRY]](%4 : tensor<i64>)
// CHECK: ^[[EXIT]](%5: tensor<i64>):
// CHECK: return %5 : tensor<i64>
// CHECK: }
%0 = "xla_hlo.while"(%arg0) ( {
^cond_entry(%arg1: tensor<i64>):
br ^cond_succ(%arg1: tensor<i64>)
^cond_succ(%0: tensor<i64>):
%1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^body_entry(%arg1: tensor<i64>):
"xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// CHECK-LABEL: func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
// CHECK: %0 = extract_element %arg2[] : tensor<i1>
// CHECK: cond_br %0, ^[[THEN_ENTRY:.+]](%arg0 : tensor<f32>), ^[[ELSE_ENTRY:.+]](%arg1 : tensor<f32>)
// CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
// CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
// CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
// CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
// CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
// CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
// CHECK: br ^[[EXIT]](%5 : tensor<f32>)
// CHECK: ^[[EXIT]](%6: tensor<f32>):
// CHECK: return %6 : tensor<f32>
// CHECK: }
%1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
^then_entry(%arg2: tensor<f32>):
br ^then_succ(%arg2: tensor<f32>)
^then_succ(%0: tensor<f32>):
%2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}, {
^else_entry(%arg2: tensor<f32>):
%2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

195
tests/legalize-to-std.mlir Normal file
@ -0,0 +1,195 @@
// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %4 : tensor<4xf32>
return %4 : tensor<4xf32>
}
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: return %4 : tensor<4xi32>
return %4 : tensor<4xi32>
}
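// Integer comparisons lower to signed cmpi predicates; float comparisons lower to ordered cmpf
// predicates (unordered "une" for NE).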
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
// Test Iota lowering to constant
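// With a static result shape the iota values are known at compile time, so xla_hlo.iota folds to a
// std constant counting along iota_dimension; complex iotas become a real constant paired with a
// zero imaginary constant via xla_hlo.complex.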
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
// CHECK-LABEL: func @iota.const.complex.f32
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: func @iota.const.complex.f64
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}

@ -0,0 +1,125 @@
// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 {
%res = tanh %arg0 : f64
return %res : f64
}
// CHECK-LABEL: @tanh_f64
// CHECK: tanh
// -----
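// The f32 case is rewritten to a rational polynomial approximation: the input is clamped to roughly
// [-7.9, 7.9], numerator and denominator polynomials are evaluated in the squared input, and inputs
// smaller than 4.0e-04 in magnitude pass through unchanged.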
func @tanh_f32(%arg0 : f32) -> f32 {
%res = tanh %arg0 : f32
return %res : f32
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @tanh_f32(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 {
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_15:.*]] = absf %[[VAL_0]] : f32
// CHECK: %[[VAL_16:.*]] = cmpf "olt", %[[VAL_15]], %[[VAL_1]] : f32
// CHECK: %[[VAL_17:.*]] = cmpf "ule", %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_18:.*]] = select %[[VAL_17]], %[[VAL_0]], %[[VAL_2]] : f32
// CHECK: %[[VAL_19:.*]] = cmpf "uge", %[[VAL_18]], %[[VAL_3]] : f32
// CHECK: %[[VAL_20:.*]] = select %[[VAL_19]], %[[VAL_18]], %[[VAL_3]] : f32
// CHECK: %[[VAL_21:.*]] = mulf %[[VAL_20]], %[[VAL_20]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_4]] : f32
// CHECK: %[[VAL_23:.*]] = addf %[[VAL_22]], %[[VAL_5]] : f32
// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_21]], %[[VAL_23]] : f32
// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32
// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_21]], %[[VAL_25]] : f32
// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : f32
// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_21]], %[[VAL_27]] : f32
// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32
// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_21]], %[[VAL_29]] : f32
// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32
// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_21]], %[[VAL_31]] : f32
// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32
// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_20]], %[[VAL_33]] : f32
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_11]] : f32
// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_12]] : f32
// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_21]], %[[VAL_36]] : f32
// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32
// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_21]], %[[VAL_38]] : f32
// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32
// CHECK: %[[VAL_41:.*]] = divf %[[VAL_34]], %[[VAL_40]] : f32
// CHECK: %[[VAL_42:.*]] = select %[[VAL_16]], %[[VAL_0]], %[[VAL_41]] : f32
// CHECK: return %[[VAL_42]] : f32
// CHECK: }
// -----
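// The f16 case reuses the f32 approximation: the input is extended with fpext, approximated in f32,
// and truncated back with fptrunc.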
func @tanh_f16(%arg0 : f16) -> f16 {
%res = tanh %arg0 : f16
return %res : f16
}
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @tanh_f16(
// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 {
// CHECK: %[[VAL_1:.*]] = constant 4.000000e-04 : f32
// CHECK: %[[VAL_2:.*]] = constant 7.90531111 : f32
// CHECK: %[[VAL_3:.*]] = constant -7.90531111 : f32
// CHECK: %[[VAL_4:.*]] = constant -2.76076837E-16 : f32
// CHECK: %[[VAL_5:.*]] = constant 2.00018794E-13 : f32
// CHECK: %[[VAL_6:.*]] = constant -8.60467184E-11 : f32
// CHECK: %[[VAL_7:.*]] = constant 5.12229725E-8 : f32
// CHECK: %[[VAL_8:.*]] = constant 1.48572235E-5 : f32
// CHECK: %[[VAL_9:.*]] = constant 6.37261954E-4 : f32
// CHECK: %[[VAL_10:.*]] = constant 0.00489352457 : f32
// CHECK: %[[VAL_11:.*]] = constant 1.19825836E-6 : f32
// CHECK: %[[VAL_12:.*]] = constant 1.18534706E-4 : f32
// CHECK: %[[VAL_13:.*]] = constant 0.00226843474 : f32
// CHECK: %[[VAL_14:.*]] = constant 0.00489352504 : f32
// CHECK: %[[VAL_15:.*]] = fpext %[[VAL_0]] : f16 to f32
// CHECK: %[[VAL_16:.*]] = absf %[[VAL_15]] : f32
// CHECK: %[[VAL_17:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_1]] : f32
// CHECK: %[[VAL_18:.*]] = cmpf "ule", %[[VAL_15]], %[[VAL_2]] : f32
// CHECK: %[[VAL_19:.*]] = select %[[VAL_18]], %[[VAL_15]], %[[VAL_2]] : f32
// CHECK: %[[VAL_20:.*]] = cmpf "uge", %[[VAL_19]], %[[VAL_3]] : f32
// CHECK: %[[VAL_21:.*]] = select %[[VAL_20]], %[[VAL_19]], %[[VAL_3]] : f32
// CHECK: %[[VAL_22:.*]] = mulf %[[VAL_21]], %[[VAL_21]] : f32
// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_4]] : f32
// CHECK: %[[VAL_24:.*]] = addf %[[VAL_23]], %[[VAL_5]] : f32
// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_22]], %[[VAL_24]] : f32
// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32
// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_22]], %[[VAL_26]] : f32
// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32
// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_22]], %[[VAL_28]] : f32
// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32
// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_22]], %[[VAL_30]] : f32
// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32
// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_22]], %[[VAL_32]] : f32
// CHECK: %[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32
// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_21]], %[[VAL_34]] : f32
// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_11]] : f32
// CHECK: %[[VAL_37:.*]] = addf %[[VAL_36]], %[[VAL_12]] : f32
// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_22]], %[[VAL_37]] : f32
// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32
// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_22]], %[[VAL_39]] : f32
// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32
// CHECK: %[[VAL_42:.*]] = divf %[[VAL_35]], %[[VAL_41]] : f32
// CHECK: %[[VAL_43:.*]] = select %[[VAL_17]], %[[VAL_15]], %[[VAL_42]] : f32
// CHECK: %[[VAL_44:.*]] = fptrunc %[[VAL_43]] : f32 to f16
// CHECK: return %[[VAL_44]] : f16
// CHECK: }

@ -0,0 +1,93 @@
// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
}
// -----
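// The copy below must stay: forwarding the temporary to %arg2 would let the exponential that also
// writes %arg2 clobber the value the copy is meant to preserve.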
// CHECK-LABEL: func @must_not_be_removed
func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}
// -----
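// In the next two cases the temporary is only read by the copy, so the producing exponential writes
// %arg2 directly and the copy disappears, regardless of where the unrelated exponential sits.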
// CHECK-LABEL: func @must_be_removed_first
func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @must_be_removed_second
func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
}

236
tests/lhlo-fuse-linalg.mlir Normal file
@ -0,0 +1,236 @@
// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
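// The producer linalg.generic is fused into the tiled loop nest of its consumer: the default run
// tiles by 1 (scf.for with step %c1), tile-sizes=2,3 tiles by 2 and 3, and use-parallel-loops emits
// a single scf.parallel.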
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
return
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) {
%0 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %arg1, %0 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32>
%1 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %arg0, %0, %1 {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
dealloc %0 : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
} %1, %arg2 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%2 = exp %arg3 : f32
linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>
dealloc %1 : memref<100x10xf32>
return
}
// CHECK-LABEL: func @fusion_of_three
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: linalg.generic
// CHECK: subf
// CHECK: linalg.generic
// CHECK: exp
// TILED-LABEL: func @fusion_of_three
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: linalg.generic
// TILED: subf
// TILED: linalg.generic
// TILED: exp
// PLOOP-LABEL: func @fusion_of_three
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: linalg.generic
// PLOOP: subf
// PLOOP: linalg.generic
// PLOOP: exp
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
dealloc %temp_result : memref<6x6x6x6xf32>
return
}
// CHECK-LABEL: func @fusion_4d
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion_4d
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion_4d
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
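// Variant of the add-then-multiply fusion test where the result buffer is
// allocated inside the function and returned.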
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
%result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32>
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

@ -0,0 +1,193 @@
// GenericAtomicRMWOp should contain only ops with no side effects.
// Unfortunately, the legalization pattern for SelectAndScatterOp has to adapt
// to the XLA LHLO dialect by using allocs/deallocs inside the body of
// GenericAtomicRMWOp. Lowering to the STD dialect and a store-forwarding pass
// would be required to get rid of them. This is exactly what is done in the
// real MLIR GPU pipeline, but here we disable verification with
// `verify-each=0` to check the output IR.
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s
func @select_and_scatter(%arg: memref<112x112xf32>,
%src: memref<56x56xf32>,
%init: memref<f32>,
%result: memref<112x112xf32>) {
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
// select
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
(memref<f32>, memref<f32>, memref<i1>) -> ()
"xla_lhlo.terminator"() : () -> ()
}, {
// scatter
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %out) :
(memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
window_strides = dense<[2, 2]> : tensor<2xi64>
} : (memref<112x112xf32>,
memref<56x56xf32>,
memref<f32>, memref<112x112xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}
// CHECK-LABEL: func @select_and_scatter(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[SRC_BUF:%.*]]: memref<56x56xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<112x112xf32>) {
// Constants.
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK-DAG: [[CFALSE:%.*]] = constant false
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK-DAG: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C112]], [[C112]]) step ([[C1]], [[C1]]) {
// CHECK: store [[INIT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// Parallel loop over source buffer to compute scattered values.
// CHECK: scf.parallel ([[II:%.*]], [[JJ:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
// Window loop w.r.t. first dim.
// CHECK: [[SEL_RES_I:%.*]]:4
// CHECK-SAME: = scf.for [[WIN_I:%.*]] = [[C0]] to [[C3]] step [[C1]]
// CHECK-SAME: iter_args(
// CHECK-SAME: [[SEL_I_0:%.*]] = [[C0]], [[SEL_J_0:%.*]] = [[C0]],
// CHECK-SAME: [[SEL_VAL_0:%.*]] = [[C0_F32]],
// CHECK-SAME: [[SEL_INIT_0:%.*]] = [[CFALSE]]
// CHECK-SAME: ) -> (index, index, f32, i1) {
// Window loop w.r.t. second dim.
// CHECK: [[SEL_RES_J:%.*]]:4
// CHECK-SAME: = scf.for [[WIN_J:%.*]] = [[C0]] to [[C3]] step [[C1]]
// CHECK-SAME: iter_args(
// CHECK-SAME: [[SEL_I:%.*]] = [[SEL_I_0]], [[SEL_J:%.*]] = [[SEL_J_0]],
// CHECK-SAME: [[SEL_VAL:%.*]] = [[SEL_VAL_0]],
// CHECK-SAME: [[SEL_INIT:%.*]] = [[SEL_INIT_0]]
// CHECK-SAME: ) -> (index, index, f32, i1) {
// Compute index I of the ARG buffer and check whether it is in the padding area.
// CHECK: [[START_I:%.*]] = muli [[II]], [[C2]] : index
// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[WIN_I]] : index
// CHECK: [[ARG_I_FITS:%.*]] = cmpi "ult", [[ARG_I]], [[C112]] : index
// Compute index J of the ARG buffer and check whether it is in the padding area.
// CHECK: [[START_J:%.*]] = muli [[JJ]], [[C2]] : index
// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[WIN_J]] : index
// CHECK: [[ARG_J_FITS:%.*]] = cmpi "ult", [[ARG_J]], [[C112]] : index
// Update `INBOUNDS`, i.e. whether the ARG indices are inside the boundaries of
// the buffer or in the padding area.
// CHECK: [[INBOUNDS_1:%.*]] = and [[ARG_I_FITS]], [[ARG_J_FITS]] : i1
// If the ARG ivs are in the padding area, the 'select' function does not have
// to be applied; the currently selected ivs (SEL_I, SEL_J) and value (SEL_VAL)
// are returned in that case.
// CHECK: [[IF_INBOUNDS_RES:%.*]]:4
// CHECK-SAME: = scf.if [[INBOUNDS_1]] -> (index, index, f32, i1) {
// INBOUNDS-THEN-BODY, i.e. if INBOUNDS == true
// CHECK: [[ARG_ELEM:%.*]] = load [[ARG_BUF]]{{\[}}[[ARG_I]], [[ARG_J]]]
// CHECK: [[IF_INIT_RES:%.*]]:4
// CHECK-SAME: = scf.if [[SEL_INIT]] -> (index, index, f32, i1) {
// INIT-THEN-BODY, i.e. INBOUNDS == true and INIT == true
// The LHLO IR of the select block of lhlo.select_and_scatter is applied to the
// currently selected value (SEL_VAL) and the element of the ARG buffer to
// compute the boolean PRED, i.e. whether the new value and ivs should replace
// the current ones.
// Allocate buffers for the ARG element and the currently selected value to
// adapt the LHLO code.
// CHECK: [[ARG_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[SEL_VAL_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[PRED_BUF:%.*]] = alloc() : memref<i1>
// CHECK: store [[ARG_ELEM]], [[ARG_ELEM_BUF]][] : memref<f32>
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED.
// CHECK: "xla_lhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
// Depending on PRED, return either the ARG ivs & element or the currently selected ivs and value.
// CHECK: [[IF_PRED_RES:%.*]]:4 = scf.if [[PRED]]
// CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]], [[CTRUE]]
// CHECK: } else {
// CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]], [[SEL_INIT]]
// CHECK: }
// INIT-THEN-BODY yield.
// CHECK: scf.yield [[IF_PRED_RES]]#0, [[IF_PRED_RES]]#1,
// CHECK-SAME: [[IF_PRED_RES]]#2, [[IF_PRED_RES]]#3
// INIT-ELSE-BODY, i.e. if INBOUNDS == TRUE and INIT == FALSE, returns the ARG
// ivs and element without computing the select function.
// CHECK: scf.yield [[ARG_I]], [[ARG_J]], [[ARG_ELEM]],
// CHECK-SAME: [[CTRUE]] : index, index, f32, i1
// CHECK: }
// INBOUNDS-THEN-BODY yield.
// CHECK: scf.yield [[IF_INIT_RES]]#0, [[IF_INIT_RES]]#1, [[IF_INIT_RES]]#2,
// CHECK-SAME: [[IF_INIT_RES]]#3 : index, index, f32, i1
// CHECK: }
// INBOUNDS-ELSE-REGION, i.e. if INBOUNDS == FALSE
// We are in the pad area, return current iter_args.
// CHECK: scf.yield [[SEL_I]], [[SEL_J]], [[SEL_VAL]],
// CHECK-SAME: [[SEL_INIT]] : index, index, f32, i1
// CHECK: }
// Window loop w.r.t. second dim yield.
// CHECK: scf.yield [[IF_INBOUNDS_RES]]#0, [[IF_INBOUNDS_RES]]#1,
// CHECK-SAME: [[IF_INBOUNDS_RES]]#2, [[IF_INBOUNDS_RES]]#3
// CHECK: }
// Window loop w.r.t. first dim yield.
// CHECK: scf.yield [[SEL_RES_J]]#0, [[SEL_RES_J]]#1, [[SEL_RES_J]]#2,
// CHECK-SAME: [[SEL_RES_J]]#3 : index, index, f32, i1
// CHECK: }
// Use selected ivs to load element from the SRC buffer.
// CHECK: [[SRC_ELEM:%.*]] = load [[SRC_BUF]]{{\[}}[[II]], [[JJ]]]
// The update of RESULT[SELECTED_I, SELECTED_J] should be done atomically,
// because several threads may select the same ivs if the windows overlap.
// CHECK: generic_atomic_rmw [[RESULT_BUF]]{{\[}}[[SEL_RES_I]]#0,
// CHECK-SAME: [[SEL_RES_I]]#1] : memref<112x112xf32>
// CHECK: ^bb0([[CUR_RES:%.*]]: f32):
// Allocate buffers for the SRC element and the current RESULT value to adapt the LHLO code.
// CHECK: [[SRC_ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[CUR_RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[RES_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[SRC_ELEM]], [[SRC_ELEM_BUF]][] : memref<f32>
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value.
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
// Atomic RMW terminator that returns updated value.
// CHECK: atomic_yield [[RES]] : f32
// Parallel loop over source buffer yield
// CHECK: scf.yield

@ -0,0 +1,181 @@
// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
// Smoke test.
// CHECK-LABEL: func @min_op
func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
%result: memref<4x3x2x1xf32>) -> () {
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 4 {
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 3 {
// CHECK-NEXT: affine.for %[[K:.*]] = 0 to 2 {
// CHECK-NEXT: affine.for %[[L:.*]] = 0 to 1 {
// CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK-NEXT: %[[MIN_PREDICATE:.*]] = cmpf "olt", %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK: return
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
return
}
// Add tests.
// CHECK-LABEL: func @float_add_op
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: addf %{{.*}}, %{{.*}} : f32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_add_op
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: addi %{{.*}}, %{{.*}} : i32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// And test.
// CHECK-LABEL: func @int_and_op
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: and %{{.*}}, %{{.*}} : i32
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Div tests.
// CHECK-LABEL: func @float_div_op
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: divf %{{.*}}, %{{.*}} : f32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_div_op
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Max tests.
// CHECK-LABEL: func @float_max_op
func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_max_op
func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Min tests.
// CHECK-LABEL: func @float_min_op
func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_min_op
func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Mul tests.
// CHECK-LABEL: func @float_mul_op
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: mulf %{{.*}}, %{{.*}} : f32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_mul_op
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: muli %{{.*}}, %{{.*}} : i32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Sub tests.
// CHECK-LABEL: func @float_sub_op
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: subf %{{.*}}, %{{.*}} : f32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
// CHECK-LABEL: func @int_sub_op
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: subi %{{.*}}, %{{.*}} : i32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
// Dot tests.
// CHECK-LABEL: func @float_dot_op
func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
memref<3x4xf32>, %result: memref<7x4xf32> ) -> () {
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 7 {
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 4 {
// CHECK-NEXT: affine.for %[[K:.*]] = 0 to 3 {
// CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[K]]] : memref<7x3xf32>
// CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[K]], %[[J]]] : memref<3x4xf32>
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK-NEXT: %[[MULT:.*]] = mulf %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return
}
// CHECK-LABEL: func @int_dot_op
func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
memref<3x4xi32>, %result: memref<7x4xi32> ) -> () {
// CHECK-NEXT: affine.for %[[I:.*]] = 0 to 7 {
// CHECK-NEXT: affine.for %[[J:.*]] = 0 to 4 {
// CHECK-NEXT: affine.for %[[K:.*]] = 0 to 3 {
// CHECK-NEXT: %[[LHS:.*]] = affine.load %{{.*}}[%[[I]], %[[K]]] : memref<7x3xi32>
// CHECK-NEXT: %[[RHS:.*]] = affine.load %{{.*}}[%[[K]], %[[J]]] : memref<3x4xi32>
// CHECK-NEXT: %[[RESULT:.*]] = affine.load %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK-NEXT: %[[MULT:.*]] = muli %[[LHS]], %[[RHS]] : i32
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return
}

@ -0,0 +1,34 @@
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s
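// Lowers xla_lhlo.reduce to a gpu.launch in which each thread accumulates one
// output element with a sequential scf.for loop over the reduced dimension.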
func @reduce(%arg: memref<100x10xf32>,
%init: memref<f32>,
%result: memref<100xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
}
// CHECK: func @reduce(%[[ARG0:.*]]: memref<100x10xf32>, %[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<100xf32>) {
// CHECK-DAG: %[[C100:.*]] = constant 100 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK: gpu.launch blocks({{.*}}, {{.*}}, {{.*}}) in ({{.*}} = %[[C1]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) threads(%[[IDX:.*]], {{.*}}, {{.*}}) in ({{.*}} = %[[C100]], {{.*}} = %[[C1]], {{.*}} = %[[C1]]) {
// CHECK: %[[ACC:.*]] = load %[[ARG1]][] : memref<f32>
// CHECK: store %[[ACC]], %[[ARG2]][%[[IDX:.*]]] : memref<100xf32>
// CHECK-DAG: %[[LB:.*]] = constant 0 : index
// CHECK-DAG: %[[UB:.*]] = constant 10 : index
// CHECK-DAG: %[[STEP:.*]] = constant 1 : index
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: }
// CHECK: gpu.terminator
// CHECK: }
// CHECK: return
// CHECK: }
// CHECK: }

@ -0,0 +1,724 @@
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s
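// Element-wise LHLO ops are lowered to linalg.generic regions on memrefs;
// iota uses linalg.indexed_generic and reshape uses linalg.reshape.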
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @element_wise_with_dynamic_shape
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
%rhs: memref<?x?xf32>,
%result: memref<?x?xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = addf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @element_wise_scalar
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
%result: memref<f32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
// CHECK: %[[LHS:.*]] = load
// CHECK: %[[RHS:.*]] = load
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK: store %[[RES]]
// CHECK-NEXT: return
// -----
// CHECK-LABEL: func @minf
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.minimum"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[CMP:.*]] = cmpf "olt", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @maxi
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.maximum"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32):
// CHECK-NEXT: %[[CMP:.*]] = cmpi "sgt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @and
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.and"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = and %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = exp %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = log %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @copy
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
// -----
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = cmpf "oeq", %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: i32, %[[RHS_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i1):
// CHECK-NEXT: %[[RESULT:.*]] = cmpi "slt", %[[LHS_IN]], %[[RHS_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
// -----
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.select"(%pred, %lhs, %rhs, %result)
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[PRED_IN:.*]]: i1, %[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = select %[[PRED_IN]], %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[FLOAT_CAST]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<4x?x16xf32>,
%result: memref<4x2x1x4x?x16xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[REASSOCIATION:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
}
// CHECK: %[[RESHAPED_ARG:.*]] = linalg.reshape %{{.*}}#[[REASSOCIATION]]]
// CHECK-SAME: memref<1x5xf32> into memref<5xf32>
// CHECK: linalg.generic {{{.*}}indexing_maps =
// CHECK-SAME: [#[[OPERAND_MAP]], #[[RESULT_MAP]]]{{.*}} %[[RESHAPED_ARG]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[RESULT_MAP_0:.*]] = affine_map<(d0, d1) -> ()>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP_0]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[CONST:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[CONST]] : f32
// -----
// CHECK-DAG: #[[OPERAND_MAP:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
// -----
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
}
// CHECK-NOT: linalg.reshape
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[VALUE:.*]] = load %{{.*}}[[C0]]
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[VALUE]] : f32
// -----
// CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) {
value = dense<10> : tensor<i32>
} : (memref<i32>) -> ()
return
}
// CHECK: %[[CONSTANT:.*]] = constant 10 : i32
// CHECK: store %[[CONSTANT]], %{{.*}}[] : memref<i32>
// -----
// CHECK-LABEL: func @absf
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = absf %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[L1:.*]] = cmpi "sge", %[[OPERAND_IN]], %[[L0]] : i32
// CHECK-NEXT: %[[L2:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[L1]], %[[OPERAND_IN]], %[[L2]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = ceilf %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = sitofp %[[OPERAND_IN]] : i32 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: memref<2x2xi16>,
%result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i16):
// CHECK-NEXT: %[[RESULT:.*]] = trunci %[[OPERAND_IN]] : i32 to i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
// -----
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f64):
// CHECK-NEXT: %[[RESULT:.*]] = fpext %[[OPERAND_IN]] : f32 to f64
// CHECK-NEXT: linalg.yield %[[RESULT]] : f64
// -----
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f64, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = fptrunc %[[OPERAND_IN]] : f64 to f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @convert_i32_to_i32
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i32):
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : i32
// -----
// CHECK-LABEL: func @convert_f32_to_f32
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND_IN]] : f32
// -----
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result)
: (memref<2x2xf32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i32):
// CHECK-NEXT: %[[RESULT:.*]] = fptosi %[[OPERAND_IN]] : f32 to i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = cos %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = sin %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = negf %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[L0:.*]] = constant 0 : i32
// CHECK-NEXT: %[[RESULT:.*]] = subi %[[L0]], %[[OPERAND_IN]] : i32
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
// -----
// CHECK-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.remainder"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = remf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @rsqrt
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = rsqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @sign
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = sqrt %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @tanh
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[RESULT:.*]] = tanh %[[OPERAND_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[RE:.*]]: f32, %[[IM:.*]]: f32, %[[CP:.*]]: complex<f32>):
// CHECK-NEXT: %[[RESULT:.*]] = create_complex %[[RE]], %[[IM]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[RESULT]] : complex<f32>
// -----
// CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[REAL_OUT:.*]]: f32):
// CHECK-NEXT: %[[REAL:.*]] = re %[[CPLX_IN:.*]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[REAL]] : f32
// -----
// CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[CPLX_IN:.*]]: complex<f32>, %[[IMAG_OUT:.*]]: f32):
// CHECK-NEXT: %[[IMAG:.*]] = im %[[CPLX_IN:.*]] : complex<f32>
// CHECK-NEXT: linalg.yield %[[IMAG]] : f32
// -----
// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
"xla_lhlo.slice"(%operand, %result) {
start_indices = dense<[0,1]> : tensor<2xi64>,
limit_indices = dense<[2,3]> : tensor<2xi64>,
strides = dense<[1,1]> : tensor<2xi64>
} : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// CHECK: %[[L0:.*]] = constant 0 : index
// CHECK: %[[L2:.*]] = constant 2 : index
// CHECK: %[[L1:.*]] = constant 1 : index
// CHECK: %[[LHS:.*]] = linalg.range %[[L0]] : %[[L2]] : %[[L1]]
// CHECK: %[[R0:.*]] = constant 1 : index
// CHECK: %[[R2:.*]] = constant 3 : index
// CHECK: %[[R1:.*]] = constant 1 : index
// CHECK: %[[RHS:.*]] = linalg.range %[[R0]] : %[[R2]] : %[[R1]]
// CHECK: %[[RESULT:.*]] = linalg.slice %[[IN]][%[[LHS]], %[[RHS]]]
// CHECK: linalg.copy(%[[RESULT]], %[[OUT]])
// -----
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// -----
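// Convolution is lowered to linalg.conv; window strides, rhs dilation and
// non-zero padding are carried over as attributes.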
func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: memref<3x5x5x4xf32>) {
%c0 = constant 0 : index
%0 = alloc() : memref<3x5x5x4xf32>
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 2]
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all attributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation is left unspecified; the default dilation is set since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}

@ -0,0 +1,65 @@
// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
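// static_memref_cast is expanded into LLVM ops that build a new memref
// descriptor with the statically known offset, sizes and strides.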
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = xla_lhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
// CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]]
// CHECK: llvm.insertvalue
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]]
// -----
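// dynamic_memref_cast copies the source offset and inserts the given dynamic
// sizes and strides into the new descriptor.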
// CHECK-LABEL: func @dynamic_memref_cast
func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_X = constant 10 : index
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]]
// CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm<"float*"> to !llvm<"float*">
// CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]]
// CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]]

@ -0,0 +1,202 @@
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s
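// Reduction over the middle dimension: an outer scf.parallel iterates over
// the kept dimensions while an inner scf.parallel with scf.reduce accumulates
// along the reduced dimension.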
func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>,
%result: memref<100x5xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
return
}
// CHECK-LABEL: func @reduce(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100x10x5xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<100x5xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C5:%.*]] = constant 5 : index
// CHECK-DAG: [[C10:%.*]] = constant 10 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C100]], [[C5]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[C10]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<100x10x5xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield
// -----
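// When the reduction removes all dimensions there is no outer loop; the
// parallel reduction result is stored directly into the single output element.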
func @reduce_no_outer_loop(%arg: memref<100xf32>,
%init: memref<f32>,
%result: memref<1xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>}
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
return
}
// CHECK-LABEL: func @reduce_no_outer_loop(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<100xf32>,
// CHECK-SAME: [[ELEM_TO_REDUCE_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<1xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C100:%.*]] = constant 100 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[I:%.*]]) = ([[C0]])
// CHECK-SAME: to ([[C100]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]{{\[}}[[I]]{{\]}}
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: }
// CHECK: scf.yield
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[C0]]]
// -----
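// Same lowering with dynamic shapes; the loop bounds are taken from dim ops
// on the operand.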
func @dynamic_reduce(%arg: memref<?x?x?xf32>,
%init: memref<f32>,
%result: memref<?x?xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
return
}
// CHECK-LABEL: func @dynamic_reduce(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<?x?x?xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<?x?xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel ([[J:%.*]]) =
// CHECK-SAME: ([[C0]]) to ([[DIM1]]) step ([[C1]]) init ([[INIT]]) -> f32 {
// CHECK: [[ELEM_TO_REDUCE:%.*]] = load [[ARG_BUF]]
// CHECK-SAME: {{\[}}[[I]], [[J]], [[K]]] : memref<?x?x?xf32>
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[K]]]
// CHECK: scf.yield
// -----
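// reduce_window becomes an outer parallel loop over the output and an inner
// parallel reduction over the window; elements in the padding area contribute
// the init value instead.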
func @reduce_window(%arg: memref<112x112xf32>,
%init: memref<f32>,
%result: memref<56x56xf32>) {
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
window_strides = dense<[2, 2]> : tensor<2xi64>
} : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
return
}
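// The checks below verify the lowering for a 3x3 window with stride 2 and
// padding [0, 1] on a 112x112 input: each of the 56x56 output elements reduces
// a window starting at input index i * 2, and the cmpi "ult" guards substitute
// the init value for out-of-bounds (padded) positions.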
// CHECK-LABEL: func @reduce_window(
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
// CHECK: scf.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C56]], [[C56]]) step ([[C1]], [[C1]]) {
// CHECK: [[REDUCTION_RESULT:%.*]] = scf.parallel
// CHECK-SAME: ([[IW:%.*]], [[JW:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[C3]], [[C3]]) step ([[C1]], [[C1]])
// CHECK-SAME: init ([[INIT]]) -> f32 {
// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[IW]] : index
// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]]
// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[JW]] : index
// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[INDEX_I_FITS]], [[INDEX_J_FITS]]
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] =
// CHECK-SAME: load [[OPERAND_BUF]]{{\[}}[[INDEX_I]], [[INDEX_J]]]
// CHECK: scf.yield [[OPERAND_ELEM]] : f32
// CHECK: } else {
// CHECK: scf.yield [[INIT]] : f32
// CHECK: }
// CHECK: scf.reduce([[ELEM_TO_REDUCE]]) : f32 {
// CHECK: ^bb0([[ELEM:%.*]]: f32, [[ACC:%.*]]: f32):
// CHECK: [[ELEM_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_BUF:%.*]] = alloc() : memref<f32>
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
// CHECK: scf.yield
// CHECK: }
// CHECK: store [[REDUCTION_RESULT]], [[RESULT_BUF]]{{\[}}[[I]], [[J]]]
// CHECK: scf.yield
// CHECK: }
// CHECK: return
// CHECK: }

1045
tests/lhlo_ops.mlir Normal file

File diff suppressed because it is too large

224
tests/lower-complex.mlir Normal file
View File

@ -0,0 +1,224 @@
// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s
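// The tests below verify that ops on complex tensors are rewritten into
// element-wise ops on the real and imaginary components; e.g. a complex add
// becomes one real add per component.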
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.add %arg1, %arg3
%4 = "xla_hlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.subtract %arg1, %arg3
%4 = "xla_hlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.add [[VAL3]], [[VAL4]]
%4 = "xla_hlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// Compute the real-valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
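// Overall this matches multiplication by the conjugate:
// (a + bi) / (c + di) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)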
%4 = "xla_hlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// -----
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "xla_hlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.subtract [[VAL1]], [[VAL2]]
// Compute the real-valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = xla_hlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = xla_hlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = xla_hlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = xla_hlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = xla_hlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = xla_hlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = xla_hlo.divide [[VAL9]], [[VAL6]]
%4 = "xla_hlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "xla_hlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "xla_hlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = xla_hlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = xla_hlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = xla_hlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "xla_hlo.sqrt"([[VAL2]])
%1 = "xla_hlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
}
// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "xla_hlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "xla_hlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "xla_hlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = xla_hlo.multiply [[VAL0]], [[VAL2]]
%1 = "xla_hlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "xla_hlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "xla_hlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>
}

View File

@ -0,0 +1,35 @@
// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
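// The tests below verify that a dot_general with no batching dimensions is
// rewritten into transposes/reshapes around a plain xla_hlo.dot, while a
// dot_general with batching dimensions (@testBatchPassthrough) is left as is.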
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
return %0 : tensor<1x1x3xf32>
}
// -----
// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "xla_hlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "xla_hlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "xla_hlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}
// -----
// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "xla_hlo.dot_general"(%arg0, %arg1)
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}

View File

@ -0,0 +1,11 @@
// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
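// The checks below verify that the scalar min/max operands of clamp get
// explicit broadcasts so that all three operands share the shape of the value
// operand.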
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "xla_hlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "xla_hlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "xla_hlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "xla_hlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

1132
tests/ops.mlir Normal file

File diff suppressed because it is too large

14
tests/reduce.mlir Normal file
View File

@ -0,0 +1,14 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
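// A reduce over an empty dimension list leaves its operand unchanged, so
// canonicalization is expected to fold it away and return the argument.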
// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}

149
tests/reshape.mlir Normal file
View File

@ -0,0 +1,149 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<i32>
%cst = xla_hlo.constant dense<42> : tensor<1x1xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<2xi32>
%cst = xla_hlo.constant dense<42> : tensor<1x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<1xi32>
%cst = xla_hlo.constant dense<42> : tensor<i32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}
// -----
// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
// -----
// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<42> : tensor<16xi64>
%cst = xla_hlo.constant dense<42> : tensor<4x4xi64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
// -----
// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = xla_hlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
// -----
// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = xla_hlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "xla_hlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = xla_hlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = xla_hlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "xla_hlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}
// -----
// CHECK-LABEL: func @non_const_same_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: return [[ARG]]
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[RES]]
return %1 : tensor<6xi32>
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[ARG]]
return %1 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: [[RES:%.+]] = "xla_hlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "xla_hlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "xla_hlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "xla_hlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "xla_hlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "xla_hlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: return [[RES]]
return %4 : tensor<1x2x4x3xi32>
}

9
tests/reverse.mlir Normal file
View File

@ -0,0 +1,9 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32>
}

View File

@ -0,0 +1,60 @@
// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s
// Tests sinking constants to a while loop.
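// The checks below verify that constants defined outside the while op are
// cloned into the regions (condition and body) that use them, making each
// region self-contained.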
// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: xla_hlo.while
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
// CHECK: "xla_hlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "xla_hlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"xla_hlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = xla_hlo.add %[[ARG1B]], %[[ARG1B]]
%2 = xla_hlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]], %[[ADD0]]
%3 = xla_hlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = xla_hlo.add %[[C1]], %[[ADD1]]
%4 = xla_hlo.add %c1, %3 : tensor<i64>
"xla_hlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// Tests sinking constants to a conditional op.
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = xla_hlo.constant dense<1> : tensor<i64>
%c1 = xla_hlo.constant dense<2> : tensor<i64>
%0 = "xla_hlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "xla_hlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: xla_hlo.if
%2 = "xla_hlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = xla_hlo.constant dense<1> : tensor<i64>
%3 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = xla_hlo.add %[[C0]],
%4 = xla_hlo.add %c0, %3 : tensor<i64>
%5 = "xla_hlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = xla_hlo.constant dense<2> : tensor<i64>
%6 = "xla_hlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = xla_hlo.add %[[C1]],
%7 = xla_hlo.add %c1, %6 : tensor<i64>
%8 = "xla_hlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"xla_hlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "xla_hlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

29
tests/transpose.mlir Normal file
View File

@ -0,0 +1,29 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32>
}
// -----
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}
// -----
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32>
}

10
tests/tuple.mlir Normal file
View File

@ -0,0 +1,10 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @fold_access
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "xla_hlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "xla_hlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

View File

@ -0,0 +1,135 @@
// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// CHECK-LABEL: @batchNormInference_2D_inner_features
// CHECK-SAME: %[[X:[^:[:space:]]+]]
// CHECK-SAME: %[[SCALE:[^:[:space:]]+]]
// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]]
// CHECK-SAME: %[[MEAN:[^:[:space:]]+]]
// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]]
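// The checks below spell out the unfused computation:
//   result = (x - mean) * scale / sqrt(variance + epsilon) + offset
// with the 1-D statistics broadcast along the feature dimension.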
func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: return %[[RESULT]]
return %0 : tensor<4x256xf32>
}
// -----
// CHECK-LABEL: @batchNormInference_4D_middle_features
// Just validate that one of the broadcasts happens correctly and rely on
// the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32>
return %0 : tensor<3x4x256x6xf32>
}
// -----
// CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64>
return %0 : tensor<4x256xf64>
}
// -----
// CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly converted to f16
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
return %0 : tensor<4x256xf16>
}
// -----
// Validate that an epsilon that cannot be represented in f16 produces a warning
func @batchNormInference_f16_overflow(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
return %0 : tensor<4x256xf16>
}
// -----
// CHECK-LABEL: @batchNormInference_dynamic_shape
// Validate that dynamic shapes are handled properly.
// CHECK-SAME: %[[X:[^:[:space:]]+]]
// CHECK-SAME: %[[SCALE:[^:[:space:]]+]]
// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]]
// CHECK-SAME: %[[MEAN:[^:[:space:]]+]]
// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]]
func @batchNormInference_dynamic_shape(
%x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
%mean: tensor<?xf32>, %variance: tensor<?xf32>)
-> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}

97
tests/xla-hlo-fusion.mlir Normal file
View File

@ -0,0 +1,97 @@
// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s
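// The tests below verify that chains of element-wise ops with direct
// producer-consumer relationships are grouped into a single xla_hlo.fusion,
// and that a reduce is never fused with its consumer.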
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "xla_hlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "xla_hlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "xla_hlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.abs
// CHECK-NEXT: xla_hlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: xla_hlo.fusion
%1 = "xla_hlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.return
// Currently we do not support fusing arguments and ops without a direct
// producer-consumer relationship. Thus the reduce op should not be fused with
// the two ops above.
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// The two ops above should not be fused, since a reduce op cannot be fused
// with its consumer.
// CHECK-NOT: xla_hlo.fusion
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
// -----
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "xla_hlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "xla_hlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "xla_hlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "xla_hlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "xla_hlo.fusion"
// CHECK-NEXT: xla_hlo.add
// CHECK-NEXT: xla_hlo.subtract
// CHECK-NEXT: xla_hlo.constant
// CHECK-NEXT: xla_hlo.reduce
// CHECK: xla_hlo.return
// The following op should not be fused with the ops above, since a reduce op
// cannot be fused with its consumer.
// CHECK-NOT: xla_hlo.fusion
%4 = "xla_hlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

View File

@ -0,0 +1,88 @@
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
// Check the validity of expected IR.
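// The strategy: reshape the unranked operand to a rank-1 tensor of
// shape.num_elements elements, apply the element-wise op on the flat tensor,
// and reshape the result back to the original (dynamic) shape.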
// CHECK-LABEL: @sqr_transform_result
func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
// Flatten operand shape.
%shape = shape.shape_of %a : tensor<*xf32>
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
// -----
// Check transformation of unranked code.
// CHECK-LABEL: @sqrt
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32>
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}
// -----
// Not transformed when ranked.
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}
// -----
// Not transformed when statically shaped.
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}
// -----
// CHECK-LABEL: @add_unranked
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}