Updates LLVM usage to match
[eed333149d17](https://github.com/llvm/llvm-project/commit/eed333149d17)

PiperOrigin-RevId: 323354988
This commit is contained in:
Thomas Joerg 2020-07-27 07:13:38 -07:00 committed by TensorFlow MLIR Team
parent 8023baa959
commit 739758f9cc
7 changed files with 54 additions and 35 deletions

View File

@ -747,7 +747,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicBroadcastToOwnShape>(context);
DynamicBroadcastToOwnShape_1, DynamicBroadcastToOwnShape_2>(
context);
}
//===----------------------------------------------------------------------===//

View File

@ -22,8 +22,11 @@ def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
// Canonicalization patterns.
def DynamicBroadcastToOwnShape : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0,
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
def DynamicBroadcastToOwnShape_1 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0,
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;
def DynamicBroadcastToOwnShape_2 : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0, (Shape_ShapeOfOp $arg1), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;

View File

@ -29,7 +29,6 @@ limitations under the License.
namespace mlir {
namespace chlo {
namespace {
// Converts binary ops that statically are determined to not broadcast directly

View File

@ -61,10 +61,9 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
// Generate IR to flatten the operand.
auto loc = op.getLoc();
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
Value numElements = rewriter.create<shape::NumElementsOp>(
loc, rewriter.getType<shape::SizeType>(), shape);
Value numElementsAsIndex = rewriter.create<shape::SizeToIndexOp>(
loc, rewriter.getIndexType(), numElements);
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value numElementsAsIndex =
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
Value flatShapeAsDimTensor =
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},

View File

@ -365,11 +365,20 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape
func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%0 = shape.shape_of %arg0 : tensor<?xf32>
%1 = shape.to_extent_tensor %0 : tensor<1xindex>
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_1
func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: return %[[ARG]] : tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2
func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%0 = shape.shape_of %arg0 : tensor<?xf32> -> !shape.shape
%1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: return %[[ARG]] : tensor<?xf32>
return %2 : tensor<?xf32>

View File

@ -18,7 +18,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
@ -39,7 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -60,7 +64,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-DAG: %[[ARG0_SS:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_SS:.+]] = shape.shape_of %[[ARG1]]
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_SS]], %[[ARG1_SS]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -253,8 +259,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// to a 1D tensor.
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
@ -263,8 +269,9 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : tensor<1xindex>
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : !shape.shape -> tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
@ -272,8 +279,8 @@ func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf3
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex> -> tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
@ -290,8 +297,8 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// to a 1D tensor.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[NUM_ELEMENTS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_INDEX]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
@ -299,15 +306,16 @@ func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf3
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] : tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.to_extent_tensor %[[SHAPE_OF]]
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_RESHAPED]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_0]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[SHAPE_2:.*]] = shape.to_extent_tensor %[[SHAPE_0]]
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[ASSUMING_RESULT:.*]], %[[SHAPE_2]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }

View File

@ -5,9 +5,9 @@
func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
// Flatten operand shape.
%shape = shape.shape_of %a : tensor<*xf32>
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%shape = shape.shape_of %a : tensor<*xf32> -> tensor<?xindex>
%num_elements = shape.num_elements %shape : tensor<?xindex> -> index
%num_elements_as_index = shape.size_to_index %num_elements : index
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
@ -16,7 +16,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> -> tensor<?xindex>
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -73,14 +73,14 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]]
// CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]]
// CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]]
// CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]])
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]]
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>