[KERNEL_GEN] Restrict broadcast -> reshape canonicalization to identity dims.

This is needed to avoid the case, when the broadcast_in_dims also performs permutation.

PiperOrigin-RevId: 350650342
This commit is contained in:
Alexander Belyaev 2021-01-07 15:29:30 -08:00 committed by TensorFlow MLIR Team
parent a8c0f2b944
commit 6c42f54298
5 changed files with 32 additions and 3 deletions

View File

@ -43,4 +43,8 @@ def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
class GetScalarOfType<int value> : NativeCodeCall<
"hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
// Constraint that Attr has values [0, 1, ...].
def IdentityBroadcastDims : AttrConstraint<
CPred<"hlo::IsSequenceStartingWith0($_self)">>;
#endif // HLO_UTILS

View File

@ -88,6 +88,9 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
std::string LmhloToMhloOpName(llvm::StringRef op_name,
mlir::MLIRContext* context);
// Return true if Attr has values [0, 1, ...].
bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
} // namespace hlo
} // namespace mlir

View File

@ -38,5 +38,6 @@ def RemoveRedundantDynamicReshape : Pat<
// is a dynamic reshape.
def RemoveRedundantDynamicBroadcast : Pat<
(HLO_DynamicBroadcastInDimOp
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
(HLO_DynamicReshapeOp $operand, $shape),
$shape, IdentityBroadcastDims:$dims),
(HLO_DynamicReshapeOp $operand, $shape)>;

View File

@ -140,5 +140,11 @@ std::string LmhloToMhloOpName(llvm::StringRef op_name,
return "";
}
bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
return true;
}
} // namespace hlo
} // namespace mlir

View File

@ -1542,8 +1542,10 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
}
// CHECK-LABEL: @broadcast_of_reshape
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
func @broadcast_of_reshape(%arg: tensor<?xf32>,
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -1551,3 +1553,16 @@ func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
}
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
// CHECK: return [[RESHAPE]]
// CHECK-LABEL: @permutation_broadcast_of_reshape
func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dynamic_broadcast_in_dim