2020-11-09 20:23:54 +08:00
|
|
|
|
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
|
|
|
|
|
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
|
|
|
|
|
// RUN: | FILECHECK_OPTS="" FileCheck %s
|
2020-07-07 07:28:26 +08:00
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Attributes attached to an mhlo op are expected to show up verbatim on the
// lmhlo op that replaces it.
// CHECK-LABEL: func @attrs_copy
func @attrs_copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %result = "mhlo.exponential"(%operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
  return %result : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
|
|
// A function that only forwards its argument: the converted signature must use
// one and the same type for argument and result, and the argument itself is
// returned on the very next line.
// The label anchors the signature pattern to this function instead of letting
// it match whichever signature happens to come next in the output.
// CHECK-LABEL: func @return_func
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
  return %arg0 : tensor<4xf32>
}
// CHECK-NEXT: return %[[ARG0]]
|
2020-07-07 07:28:26 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Chain of five element-wise ops. The expected output allocates one buffer per
// op result and frees each intermediate buffer after its last use; only the
// final product buffer is returned.
// CHECK-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %max = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  %sum = mhlo.add %arg0, %max : tensor<4xf32>
  %min = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  %diff = mhlo.subtract %arg1, %min : tensor<4xf32>
  %prod = mhlo.multiply %sum, %diff : tensor<4xf32>
  return %prod : tensor<4xf32>
}
//      CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
|
|
// An add feeding a multiply: the buffer holding the sum is expected to be
// released right after the multiply has consumed it.
// CHECK-LABEL: func @fusion
func @fusion(%multiplier: tensor<2x2xf32>, %summand_1: tensor<2x2xf32>,
             %summand_2: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}})
  // CHECK-NEXT:  %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %added = "mhlo.add"(%summand_1, %summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
  // CHECK-NEXT:  %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %scaled = "mhlo.multiply"(%added, %multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
  // CHECK-NEXT:  dealloc %[[ADD_RESULT]] : memref<2x2xf32>
  // CHECK-NEXT: return %[[MUL_RESULT]] : memref<2x2xf32>
  return %scaled : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Explicit mhlo.copy. Only the function label is verified for now; the copy
// pattern below is parked behind a TODO prefix and is therefore inactive.
// CHECK-LABEL: func @copy
func @copy(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %copied = "mhlo.copy"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // TODO(herhut): An explicit copy should not be removed.
  // TODO-CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}})
  return %copied : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.exponential is rewritten to lmhlo.exponential.
// CHECK-LABEL: func @exp
func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %exp = "mhlo.exponential"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential"(%{{.*}}, %{{.*}})
  return %exp : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2021-02-18 20:53:34 +08:00
|
|
|
|
// mhlo.exponential_minus_one is rewritten to lmhlo.exponential_minus_one.
// CHECK-LABEL: func @expm1
func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %expm1 = "mhlo.exponential_minus_one"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}})
  return %expm1 : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.log is rewritten to lmhlo.log.
// CHECK-LABEL: func @log
func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %log = "mhlo.log"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.log"(%{{.*}}, %{{.*}})
  return %log : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Ternary mhlo.select is rewritten to lmhlo.select with four operands.
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
             %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %selected = "mhlo.select"(%pred, %lhs, %rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
  return %selected : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.compare is rewritten to lmhlo.compare; the comparison direction
// attribute must be kept.
// CHECK-LABEL: func @compare
func @compare(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xi1> {
  %is_equal = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  // CHECK: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
  return %is_equal : tensor<2x2xi1>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Static mhlo.broadcast_in_dim is rewritten to lmhlo.broadcast_in_dim; the
// broadcast_dimensions attribute must be kept.
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: tensor<5xf32>) -> tensor<10x5xf32> {
  %bcast = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<5xf32>) -> tensor<10x5xf32>
  // CHECK: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
  return %bcast : tensor<10x5xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-09 20:23:54 +08:00
|
|
|
|
// Dynamic broadcast: the operand is reinterpreted as a strided rank-3 view
// (stride 0 along expanded dimensions) and that view is copied into a freshly
// allocated result buffer.
//
// The layout map alias is captured once as a global FileCheck variable (the
// `$` prefix keeps it alive across the function label even when variable
// scoping is enabled) and reused below, so the test does not depend on the
// printer naming the alias exactly `#map`.
// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>

// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
  %c1 = constant 1 : i64
  %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xi64>
  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  return %result : tensor<?x?x?xf32>
}
// CHECK: %[[SHAPE:.*]] = tensor.from_elements

// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>

// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>

// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi slt, %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index

// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index

// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #[[$MAP]]>

// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>

// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #[[$MAP]]>, memref<?x?x?xf32>) -> ()
// CHECK: return %[[RESULT]] : memref<?x?x?xf32>
|
2020-07-07 07:28:26 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.complex is rewritten to lmhlo.complex. The pattern lists all three
// operands (real, imag, output buffer), matching the other binary-op tests in
// this file.
// CHECK-LABEL: func @complex
func @complex(%real: tensor<2x2xf32>, %imag: tensor<2x2xf32>)
    -> tensor<2x2xcomplex<f32>> {
  %result = "mhlo.complex"(%real, %imag)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xcomplex<f32>>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Dynamically shaped variant of the mhlo.complex test. The pattern lists all
// three operands (real, imag, output buffer), matching the other binary-op
// tests in this file.
// CHECK-LABEL: func @complex_dyn
func @complex_dyn(%real: tensor<?xf32>, %imag: tensor<?xf32>)
    -> tensor<?xcomplex<f32>> {
  %result = "mhlo.complex"(%real, %imag)
      : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xcomplex<f32>>
  // CHECK: "lmhlo.complex"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<?xcomplex<f32>>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.real on a statically shaped complex tensor is rewritten to lmhlo.real.
// CHECK-LABEL: func @real
func @real(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
  %re = "mhlo.real"(%operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
  return %re : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.real on a dynamically shaped complex tensor is rewritten to lmhlo.real.
// CHECK-LABEL: func @real_dyn
func @real_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
  %re = "mhlo.real"(%operand)
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
  // CHECK: "lmhlo.real"(%{{.*}}, %{{.*}})
  return %re : tensor<?xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.imag on a statically shaped complex tensor is rewritten to lmhlo.imag.
// CHECK-LABEL: func @imag
func @imag(%operand: tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> {
  %im = "mhlo.imag"(%operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
  return %im : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.gather with a full set of dimension numbers is rewritten to
// lmhlo.gather with three operands.
// CHECK-LABEL: func @gather
func @gather(%operand: tensor<13x7xf32>, %idxs: tensor<5xi32>)
    -> tensor<5x7xf32> {
  %gathered = "mhlo.gather"(%operand, %idxs) {
    dimension_numbers = {
      collapsed_slice_dims = dense<0> : tensor<1xi64>,
      index_vector_dim = 1 : i64,
      offset_dims = dense<1> : tensor<1xi64>,
      start_index_map = dense<0> : tensor<1xi64>
    },
    indices_are_sorted = false,
    name = "gather.71",
    slice_sizes = dense<[1, 7]> : tensor<2xi64>
  } : (tensor<13x7xf32>, tensor<5xi32>) -> tensor<5x7xf32>
  // CHECK: "lmhlo.gather"(%{{.*}}, %{{.*}}, %{{.*}})
  return %gathered : tensor<5x7xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.imag on a dynamically shaped complex tensor is rewritten to lmhlo.imag.
// CHECK-LABEL: func @imag_dyn
func @imag_dyn(%operand: tensor<?xcomplex<f32>>) -> tensor<?xf32> {
  %im = "mhlo.imag"(%operand)
      : (tensor<?xcomplex<f32>>) -> tensor<?xf32>
  // CHECK: "lmhlo.imag"(%{{.*}}, %{{.*}})
  return %im : tensor<?xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Operand-less mhlo.iota is rewritten to lmhlo.iota; the iota_dimension
// attribute must be kept.
// CHECK-LABEL: func @iota
// TODO(herhut): Dummy should not be required here.
func @iota(%dummy: tensor<?xf32>) -> tensor<10xi32> {
  %sequence = "mhlo.iota"() {iota_dimension = 0 : i64}
      : () -> tensor<10xi32>
  // CHECK: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  return %sequence : tensor<10xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.abs is rewritten to lmhlo.abs.
// CHECK-LABEL: func @abs
func @abs(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %abs = "mhlo.abs"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.abs"(%{{.*}}, %{{.*}})
  return %abs : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-12-08 22:38:26 +08:00
|
|
|
|
// Binary mhlo.and is rewritten to lmhlo.and with three operands.
// CHECK-LABEL: func @and
func @and(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %conj = "mhlo.and"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.and"(%{{.*}}, %{{.*}}, %{{.*}})
  return %conj : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.ceil is rewritten to lmhlo.ceil.
// CHECK-LABEL: func @ceil
func @ceil(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %ceil = "mhlo.ceil"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.ceil"(%{{.*}}, %{{.*}})
  return %ceil : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.convert (f32 to i32 element type) is rewritten to lmhlo.convert.
// CHECK-LABEL: func @convert
func @convert(%operand: tensor<2x2xf32>) -> tensor<2x2xi32> {
  %converted = "mhlo.convert"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.convert"(%{{.*}}, %{{.*}})
  return %converted : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.cosine is rewritten to lmhlo.cosine.
// CHECK-LABEL: func @cos
func @cos(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %cos = "mhlo.cosine"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.cosine"(%{{.*}}, %{{.*}})
  return %cos : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.floor is rewritten to lmhlo.floor.
// CHECK-LABEL: func @floor
func @floor(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %floor = "mhlo.floor"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.floor"(%{{.*}}, %{{.*}})
  return %floor : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.negate is rewritten to lmhlo.negate.
// CHECK-LABEL: func @neg
func @neg(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %negated = "mhlo.negate"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.negate"(%{{.*}}, %{{.*}})
  return %negated : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.not is rewritten to lmhlo.not.
// CHECK-LABEL: func @not
func @not(%operand: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %inverted = "mhlo.not"(%operand)
      : (tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.not"(%{{.*}}, %{{.*}})
  return %inverted : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-12-08 22:38:26 +08:00
|
|
|
|
// Binary mhlo.or is rewritten to lmhlo.or with three operands.
// CHECK-LABEL: func @or
func @or(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %disj = "mhlo.or"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.or"(%{{.*}}, %{{.*}}, %{{.*}})
  return %disj : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.rsqrt is rewritten to lmhlo.rsqrt.
// CHECK-LABEL: func @rsqrt
func @rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %rsqrt = "mhlo.rsqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
  return %rsqrt : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.sign is rewritten to lmhlo.sign.
// CHECK-LABEL: func @sign
func @sign(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %sign = "mhlo.sign"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sign"(%{{.*}}, %{{.*}})
  return %sign : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.sqrt is rewritten to lmhlo.sqrt.
// CHECK-LABEL: func @sqrt
func @sqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %sqrt = "mhlo.sqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
  return %sqrt : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-12-08 05:01:25 +08:00
|
|
|
|
// Binary mhlo.shift_left is rewritten to lmhlo.shift_left. The pattern lists
// all three operands (lhs, rhs, output buffer), matching the other binary-op
// tests in this file.
// CHECK-LABEL: func @shift_left
func @shift_left(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_left"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_left"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
|
|
// Binary mhlo.shift_right_arithmetic is rewritten to
// lmhlo.shift_right_arithmetic. The pattern lists all three operands (lhs,
// rhs, output buffer), matching the other binary-op tests in this file.
// CHECK-LABEL: func @shift_right_arithmetic
func @shift_right_arithmetic(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_arithmetic"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_arithmetic"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
|
|
// Binary mhlo.shift_right_logical is rewritten to lmhlo.shift_right_logical.
// The pattern lists all three operands (lhs, rhs, output buffer), matching the
// other binary-op tests in this file.
// CHECK-LABEL: func @shift_right_logical
func @shift_right_logical(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.shift_right_logical"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.shift_right_logical"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.tanh on a statically shaped tensor is rewritten to lmhlo.tanh.
// CHECK-LABEL: func @tanh
func @tanh(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %tanh = "mhlo.tanh"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.tanh"(%{{.*}}, %{{.*}})
  return %tanh : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Binary mhlo.remainder is rewritten to lmhlo.remainder with three operands.
// CHECK-LABEL: func @remainder
func @remainder(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
    -> tensor<2x2xf32> {
  %rem = "mhlo.remainder"(%lhs, %rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  return %rem : tensor<2x2xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-12-08 22:38:26 +08:00
|
|
|
|
// Binary mhlo.xor is rewritten to lmhlo.xor. The pattern lists all three
// operands (both inputs plus the output buffer), matching the sibling tests
// for the `and` and `or` ops.
// CHECK-LABEL: func @xor
func @xor(%operand0: tensor<2x2xi32>, %operand1: tensor<2x2xi32>)
    -> tensor<2x2xi32> {
  %result = "mhlo.xor"(%operand0, %operand1)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  // CHECK: "lmhlo.xor"(%{{.*}}, %{{.*}}, %{{.*}})
  return %result : tensor<2x2xi32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-07-07 07:28:26 +08:00
|
|
|
|
// Dynamic shape binary element-wise operation.
// The result buffer is allocated from the two extents extracted out of the
// first operand's shape.
// CHECK-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // The index constants are spelled literally (as %arg0/%arg1 already are)
  // instead of through FileCheck variables: no pattern in this section defines
  // such variables, so using them would silently rely on values captured by an
  // earlier, unrelated test section.
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<2xindex>
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<2xindex>
  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
  // CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
|
|
// Dynamic shape unary element-wise operation.
// The result buffer is allocated from the two extents extracted out of the
// operand's shape.
// CHECK-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %result = "mhlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // The index constants are spelled literally (as %arg0 already is) instead of
  // through FileCheck variables: no pattern in this section defines such
  // variables, so using them would silently rely on values captured by an
  // earlier, unrelated test section.
  // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : memref<?x?xf32> -> tensor<2xindex>
  // CHECK: %[[EE0:.*]] = tensor.extract %[[SHAPE]][%c0] : tensor<2xindex>
  // CHECK: %[[EE1:.*]] = tensor.extract %[[SHAPE]][%c1] : tensor<2xindex>
  // CHECK: %[[RESULT:.*]] = alloc(%[[EE0]], %[[EE1]])
  // CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
  return %result : tensor<?x?xf32>
  // CHECK: return %[[RESULT]]
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.dot of a matrix with itself is rewritten to lmhlo.dot writing into a
// newly allocated buffer, which is then returned.
// NOTE(review): the dimension-number lines below carry no FileCheck prefix, so
// they are plain comments and are not verified.
// CHECK-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
  // CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
  // CHECK-NEXT: %[[ALLOC:.*]] = alloc
  //      CHECK: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) {
  //               dot_dimension_numbers = {
  //                 lhs_batching_dimensions = dense<> : tensor<0xi64>,
  //                 lhs_contracting_dimensions = dense<1> : tensor<1xi64>,
  //                 rhs_batching_dimensions = dense<> : tensor<0xi64>,
  //                 rhs_contracting_dimensions = dense<0> : tensor<1xi64>}}
  //                 : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %product = "mhlo.dot"(%arg0, %arg0)
      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>)
      -> tensor<1024x1024xf32>
  // CHECK: return %[[ALLOC]]
  return %product : tensor<1024x1024xf32>
}
|
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.convolution is rewritten to lmhlo.convolution writing into a newly
// allocated output buffer; padding, rhs_dilation and window_strides must be
// kept on the new op.
// CHECK-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>)
    -> tensor<3x5x5x4xf32> {
  %c0 = constant 0 : index
  // CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
  // CHECK: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
  // CHECK-SAME: padding = dense<[
  // CHECK-SAME:                  [0, 1], [0, 1]]> : tensor<2x2xi64>
  // CHECK-SAME: rhs_dilation = dense<[1, 2]>
  // CHECK-SAME: window_strides = dense<[2, 1]>
  %conv = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
    rhs_dilation = dense<[1, 2]> : tensor<2xi64>,
    window_strides = dense<[2, 1]> : tensor<2xi64>
  } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
  return %conv : tensor<3x5x5x4xf32>
}
|
2020-07-16 19:40:32 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// mhlo.reduce with an add body is rewritten to lmhlo.reduce. The expected
// region takes three memref arguments, computes the sum into a temporary
// buffer, copies it into the third argument and ends with lmhlo.terminator.
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
  // CHECK: %[[OUT:.*]] = alloc() : memref<1xf32>
  // CHECK:  "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
  // CHECK:  ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
  // CHECK-SAME:  %[[ARG3:.*]]: memref<f32>):
  // CHECK:    %[[TMP:.*]] = alloc() : memref<f32>
  // CHECK:    "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
  // CHECK:    "lmhlo.copy"(%[[TMP]], %[[ARG3]])
  // CHECK:    "lmhlo.terminator"() : () -> ()
  // CHECK:  }) {dimensions = dense<1> : tensor<1xi64>}
  // CHECK-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
  %reduced = "mhlo.reduce"(%arg0, %arg1) ( {
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):  // no predecessors
    %partial = mhlo.add %lhs, %rhs : tensor<f32>
    "mhlo.return"(%partial) : (tensor<f32>) -> ()
  }) {dimensions = dense<1> : tensor<1xi64>}
      : (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
  return %reduced : tensor<1xf32>
}
|
2020-09-05 05:58:10 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Test case: legalization of "mhlo.transpose" on a tensor<2x2xf32> operand.
// The single directive in the body expects a two-operand "lmhlo.transpose"
// that keeps the permutation = dense<[1, 0]> attribute unchanged.
// CHECK-LABEL: func @transpose
|
2020-12-22 22:27:57 +08:00
|
|
|
|
func @transpose(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
|
|
|
|
%result = "mhlo.transpose"(%operand) {permutation = dense<[1, 0]> : tensor<2xi64>}
|
|
|
|
|
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// CHECK: "lmhlo.transpose"(%{{.*}}, %{{.*}}) {permutation = dense<[1, 0]> : tensor<2xi64>}
|
2020-12-22 22:27:57 +08:00
|
|
|
|
return %result : tensor<2x2xf32>
|
2020-09-05 05:58:10 +08:00
|
|
|
|
}
|
2020-10-01 11:55:49 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Test case: legalization of a single-result "mhlo.custom_call".
// The directives expect:
//   * the function arguments to become memref<2x2xf32> and memref<2x3xf32>;
//   * an "lmhlo.custom_call" taking those two arguments plus one further
//     operand;
//   * backend_config, call_target_name and has_side_effect carried over as-is;
//   * an added operand_segment_sizes = dense<[2, 1]> attribute.
// CHECK-LABEL: func @custom_call
|
2020-12-22 22:27:57 +08:00
|
|
|
|
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
|
|
|
|
|
func @custom_call(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
|
2020-11-13 01:45:39 +08:00
|
|
|
|
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
|
2020-12-22 22:27:57 +08:00
|
|
|
|
%result = "mhlo.custom_call"(%arg0, %arg1)
|
|
|
|
|
{backend_config = "", call_target_name = "foo", has_side_effect = false}
|
|
|
|
|
: (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<4x4xf16>
|
|
|
|
|
return %result : tensor<4x4xf16>
|
2020-10-01 11:55:49 +08:00
|
|
|
|
}
|
2020-10-02 18:07:56 +08:00
|
|
|
|
|
2020-12-22 22:27:57 +08:00
|
|
|
|
// -----
|
2020-10-02 18:07:56 +08:00
|
|
|
|
|
2020-11-13 01:45:39 +08:00
|
|
|
|
// Test case: legalization of an "mhlo.custom_call" that returns two results.
// The two tensor<4x4xf16> results are combined with "mhlo.add" so the function
// still returns a single tensor.
// The directives expect an "lmhlo.custom_call" with the two memref function
// arguments followed by two further operands, and
// operand_segment_sizes = dense<2> (a splat over vector<2xi32>).
// CHECK-LABEL: func @custom_call_multiout
|
2020-12-22 22:27:57 +08:00
|
|
|
|
// CHECK-SAME:([[ARG0:%.*]]: memref<2x2xf32>, [[ARG1:%.*]]: memref<2x3xf32>)
|
|
|
|
|
func @custom_call_multiout(%arg0: tensor<2x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<4x4xf16> {
|
2020-11-13 01:45:39 +08:00
|
|
|
|
// CHECK: "lmhlo.custom_call"([[ARG0]], [[ARG1]], %{{.*}}, %{{.*}}) {backend_config = "", call_target_name = "foo", has_side_effect = false, operand_segment_sizes = dense<2> : vector<2xi32>}
|
2020-12-22 22:27:57 +08:00
|
|
|
|
%temp:2 = "mhlo.custom_call"(%arg0, %arg1)
|
2020-11-13 01:45:39 +08:00
|
|
|
|
{backend_config = "", call_target_name = "foo", has_side_effect = false}
|
|
|
|
|
: (tensor<2x2xf32>, tensor<2x3xf32>) -> (tensor<4x4xf16>, tensor<4x4xf16>)
|
2020-12-22 22:27:57 +08:00
|
|
|
|
%result = "mhlo.add"(%temp#0, %temp#1) : (tensor<4x4xf16>, tensor<4x4xf16>) -> tensor<4x4xf16>
|
|
|
|
|
return %result : tensor<4x4xf16>
|
2020-11-13 01:45:39 +08:00
|
|
|
|
}
|
|
|
|
|
|
2020-12-22 22:27:57 +08:00
|
|
|
|
// -----
|
2020-11-13 01:45:39 +08:00
|
|
|
|
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// Test case: legalization of "mhlo.is_finite" (tensor<2x2xf32> in,
// tensor<2x2xi1> out). The directive in the body only requires that a
// two-operand "lmhlo.is_finite" is produced; operand types are not matched.
// CHECK-LABEL: func @isfinite
|
2020-12-22 22:27:57 +08:00
|
|
|
|
func @isfinite(%arg0: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// CHECK: "lmhlo.is_finite"(%{{.*}}, %{{.*}})
|
2020-12-22 22:27:57 +08:00
|
|
|
|
%result = "mhlo.is_finite"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
|
|
|
|
return %result : tensor<2x2xi1>
|
2020-10-02 18:07:56 +08:00
|
|
|
|
}
|
2020-10-09 22:13:14 +08:00
|
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
2020-12-22 22:27:57 +08:00
|
|
|
|
// Test that shape.assuming ops written on tensor types are rewritten to memref
// types: the directives below expect the op's result type, the
// "lmhlo.maximum" operands inside the region, and the shape.assuming_yield
// operand type to all be memref<?xf16> in the output.
|
|
|
|
|
|
|
// CHECK-LABEL: func @shape_assuming_tensor
|
|
|
|
|
|
|
func @shape_assuming_tensor(%arg0: tensor<?xf16>) -> tensor<?xf16> {
|
2020-10-09 22:13:14 +08:00
|
|
|
|
%0 = mhlo.constant dense<0.000000e+00> : tensor<f16>
|
|
|
|
|
|
|
%1 = shape.const_witness true
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// CHECK: shape.assuming %{{.*}} -> (memref<?xf16>)
|
2020-10-09 22:13:14 +08:00
|
|
|
|
%2 = shape.assuming %1 -> (tensor<?xf16>) {
|
|
|
|
|
|
|
%3 = shape.shape_of %arg0 : tensor<?xf16> -> tensor<?xindex>
|
2020-12-24 15:53:08 +08:00
|
|
|
|
%4 = tensor.cast %3 : tensor<?xindex> to tensor<1xindex>
|
2020-10-09 22:13:14 +08:00
|
|
|
|
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
|
|
|
|
|
|
|
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
|
2020-11-09 20:23:54 +08:00
|
|
|
|
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
|
2020-10-09 22:13:14 +08:00
|
|
|
|
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
|
2020-11-04 01:49:13 +08:00
|
|
|
|
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
|
2020-10-09 22:13:14 +08:00
|
|
|
|
shape.assuming_yield %7 : tensor<?xf16>
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
return %2 : tensor<?xf16>
|
2020-10-16 06:08:30 +08:00
|
|
|
|
}
|
2020-12-24 15:53:08 +08:00
|
|
|
|
|
|
|
|
|
|