2020-07-07 07:28:26 +08:00
|
|
|
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
|
|
|
|
|
|
|
|
// Check the validity of expected IR.
// This is the hand-written flatten -> apply -> restore sequence for an unranked
// `mhlo.sqrt` (the same op sequence the @sqrt case matches). Feeding it through
// the RUN pipeline ensures this IR parses and verifies.
// CHECK-LABEL: @sqr_transform_result
func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {

  // Flatten operand shape: compute the total element count and reshape the
  // unranked operand into a rank-1 tensor of that size.
  %shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
  %flat_shape = tensor_from_elements %num_elements : tensor<1xindex>
  %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

  // Apply operation on the ranked (rank-1) operand.
  %flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>

  // Restore original shape using the shape captured before flattening.
  %b = "mhlo.dynamic_reshape"(%flat_b, %shape)
      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

  return %b : tensor<*xf32>
}
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
// Check transformation of unranked code.
// An unranked unary op is expected to be rewritten into: take the operand
// shape, flatten to rank 1, apply the op, and reshape back to that shape.
// CHECK-LABEL: @sqrt
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>)
func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %[[A]] : tensor<*xf32> -> tensor<?xindex>
  // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
  // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  // CHECK-NEXT: return %[[B]] : tensor<*xf32>
  %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
  return %b : tensor<*xf32>
}
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
// Not transformed when ranked.
// A ranked operand with a dynamic dimension passes through unchanged: the
// only expected ops are the original sqrt and the return.
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
  // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
  // CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
  %b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
  return %b : tensor<3x?xf32>
}
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
// Not transformed when statically shaped.
// A fully static operand passes through unchanged: the only expected ops are
// the original sqrt and the return.
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
  // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  // CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
  %b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %b : tensor<2x3xf32>
}
|
|
|
|
|
|
|
|
// -----
|
|
|
|
|
|
|
|
// Binary op on two unranked operands: a single shape is selected from both
// operand shapes via `shape.any`. Both operands are flattened with the same
// rank-1 shape, added, and the result is reshaped back to the selected shape.
// CHECK-LABEL: @add_unranked
// CHECK-SAME: (%[[A:.*]]: tensor<*xf32>, %[[B:.*]]: tensor<*xf32>) -> tensor<*xf32>
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
  // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
  // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
  // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
  // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
  // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements %[[NUM_ELEMENTS]] : tensor<1xindex>
  // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
  // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  // CHECK: return %[[RESULT]] : tensor<*xf32>
  %result = mhlo.add %a, %b : tensor<*xf32>
  return %result : tensor<*xf32>
}
|