[KERNEL_GEN] Restrict broadcast -> reshape canonicalization to identity dims.
This is needed to avoid the case, when the broadcast_in_dims also performs permutation. PiperOrigin-RevId: 350650342
This commit is contained in:
parent
a8c0f2b944
commit
6c42f54298
|
@ -43,4 +43,8 @@ def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
|
||||||
class GetScalarOfType<int value> : NativeCodeCall<
|
class GetScalarOfType<int value> : NativeCodeCall<
|
||||||
"hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
"hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||||
|
|
||||||
|
// Constraint that Attr has values [0, 1, ...].
|
||||||
|
def IdentityBroadcastDims : AttrConstraint<
|
||||||
|
CPred<"hlo::IsSequenceStartingWith0($_self)">>;
|
||||||
|
|
||||||
#endif // HLO_UTILS
|
#endif // HLO_UTILS
|
||||||
|
|
|
@ -88,6 +88,9 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit);
|
||||||
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||||
mlir::MLIRContext* context);
|
mlir::MLIRContext* context);
|
||||||
|
|
||||||
|
// Return true if Attr has values [0, 1, ...].
|
||||||
|
bool IsSequenceStartingWith0(DenseIntElementsAttr attr);
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -38,5 +38,6 @@ def RemoveRedundantDynamicReshape : Pat<
|
||||||
// is a dynamic reshape.
|
// is a dynamic reshape.
|
||||||
def RemoveRedundantDynamicBroadcast : Pat<
|
def RemoveRedundantDynamicBroadcast : Pat<
|
||||||
(HLO_DynamicBroadcastInDimOp
|
(HLO_DynamicBroadcastInDimOp
|
||||||
(HLO_DynamicReshapeOp $operand, $shape), $shape, $dims),
|
(HLO_DynamicReshapeOp $operand, $shape),
|
||||||
|
$shape, IdentityBroadcastDims:$dims),
|
||||||
(HLO_DynamicReshapeOp $operand, $shape)>;
|
(HLO_DynamicReshapeOp $operand, $shape)>;
|
||||||
|
|
|
@ -140,5 +140,11 @@ std::string LmhloToMhloOpName(llvm::StringRef op_name,
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool IsSequenceStartingWith0(DenseIntElementsAttr attr) {
|
||||||
|
for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i)
|
||||||
|
if (attr.getValue<IntegerAttr>(i).getInt() != i) return false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace hlo
|
} // namespace hlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
|
@ -1542,8 +1542,10 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @broadcast_of_reshape
|
// CHECK-LABEL: @broadcast_of_reshape
|
||||||
func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
func @broadcast_of_reshape(%arg: tensor<?xf32>,
|
||||||
%0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||||
|
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||||
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>
|
||||||
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
@ -1551,3 +1553,16 @@ func @broadcast_of_reshape(%arg: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
|
||||||
}
|
}
|
||||||
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
|
// CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape"
|
||||||
// CHECK: return [[RESHAPE]]
|
// CHECK: return [[RESHAPE]]
|
||||||
|
|
||||||
|
// CHECK-LABEL: @permutation_broadcast_of_reshape
|
||||||
|
func @permutation_broadcast_of_reshape(%arg: tensor<?xf32>,
|
||||||
|
%shape: tensor<2xindex>) -> tensor<?x?xf32> {
|
||||||
|
%0 = "mhlo.dynamic_reshape"(%arg, %shape)
|
||||||
|
: (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
%1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) {
|
||||||
|
broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>
|
||||||
|
} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
|
||||||
|
return %1 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: mhlo.dynamic_reshape
|
||||||
|
// CHECK: mhlo.dynamic_broadcast_in_dim
|
||||||
|
|
Loading…
Reference in New Issue