[KERNEL_GEN] Restrict broadcast -> reshape canonicalization to identity dims.
This is needed to avoid the case, when the broadcast_in_dims also performs permutation. PiperOrigin-RevId: 350650342
This commit is contained in:
parent
a8c0f2b944
commit
6c42f54298
|
@ -43,4 +43,8 @@ def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
|
|||
class GetScalarOfType<int value> : NativeCodeCall<
|
||||
"hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||
|
||||
// Constraint that Attr has values [0, 1, ...].
|
||||
def IdentityBroadcastDims : AttrConstraint<
|
||||
CPred<"hlo::IsSequenceStartingWith0($_self)">>;
|
||||
|
||||
#endif // HLO_UTILS
|
||||
|
|
|
@ -88,6 +88,9 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
|||
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Return true if Attr has values [0, 1, ...].
|
||||
bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -38,5 +38,6 @@ def RemoveRedundantDynamicReshape : Pat<
|
|||
// is a dynamic reshape.
|
||||
def RemoveRedundantDynamicBroadcast : Pat<
|
||||
(HLO_DynamicBroadcastInDimOp
|
||||
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
|
||||
(HLO_DynamicReshapeOp $operand, $shape),
|
||||
$shape, IdentityBroadcastDims:$dims),
|
||||
(HLO_DynamicReshapeOp $operand, $shape)>;
|
||||
|
|
|
@ -140,5 +140,11 @@ std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
|||
return "";
|
||||
}
|
||||
|
||||
bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
|
||||
for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
|
||||
if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace hlo
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1542,8 +1542,10 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
|
|||
}
|
||||
|
||||
// CHECK-LABEL: @broadcast_of_reshape
|
||||
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
func @broadcast_of_reshape(%arg: tensor<?xf32>,
|
||||
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
|
@ -1551,3 +1553,16 @@ func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
|
|||
}
|
||||
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
|
||||
// CHECK: return [[RESHAPE]]
|
||||
|
||||
// CHECK-LABEL: @permutation_broadcast_of_reshape
|
||||
func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
|
||||
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: mhlo.dynamic_reshape
|
||||
// CHECK: mhlo.dynamic_broadcast_in_dim
|
||||
|
|
Loading…
Reference in New Issue