diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td index 32940cb..461527f 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -43,4 +43,8 @@ def BinBroadcastDimensionsNonEmpty : NativeCodeCall< class GetScalarOfType : NativeCodeCall< "hlo::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; +// Constraint that Attr has values [0, 1, ...]. +def IdentityBroadcastDims : AttrConstraint< + CPred<"hlo::IsSequenceStartingWith0($_self)">>; + #endif // HLO_UTILS diff --git a/include/mlir-hlo/utils/hlo_utils.h b/include/mlir-hlo/utils/hlo_utils.h index 602ca96..ca00bb6 100644 --- a/include/mlir-hlo/utils/hlo_utils.h +++ b/include/mlir-hlo/utils/hlo_utils.h @@ -88,6 +88,9 @@ DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit); std::string LmhloToMhloOpName(llvm::StringRef op_name, mlir::MLIRContext* context); +// Return true if Attr has values [0, 1, ...]. +bool IsSequenceStartingWith0(DenseIntElementsAttr attr); + } // namespace hlo } // namespace mlir diff --git a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td index a4c480b..5974289 100644 --- a/lib/Dialect/mhlo/IR/mhlo_canonicalize.td +++ b/lib/Dialect/mhlo/IR/mhlo_canonicalize.td @@ -38,5 +38,6 @@ def RemoveRedundantDynamicReshape : Pat< // is a dynamic reshape. def RemoveRedundantDynamicBroadcast : Pat< (HLO_DynamicBroadcastInDimOp - (HLO_DynamicReshapeOp $operand, $shape), $shape, $dims), + (HLO_DynamicReshapeOp $operand, $shape), + $shape, IdentityBroadcastDims:$dims), (HLO_DynamicReshapeOp $operand, $shape)>; diff --git a/lib/utils/hlo_utils.cc b/lib/utils/hlo_utils.cc index 8ff1ce3..7f3325a 100644 --- a/lib/utils/hlo_utils.cc +++ b/lib/utils/hlo_utils.cc @@ -140,5 +140,11 @@ std::string LmhloToMhloOpName(llvm::StringRef op_name, return ""; } +bool IsSequenceStartingWith0(DenseIntElementsAttr attr) { + for (int64_t i = 0, e = attr.getNumElements(); i < e; ++i) + if (attr.getValue(i).getInt() != i) return false; + return true; +} + } // namespace hlo } // namespace mlir diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 1b6ff8b..c3e0143 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1542,8 +1542,10 @@ func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf3 } // CHECK-LABEL: @broadcast_of_reshape -func @broadcast_of_reshape(%arg: tensor, %shape: tensor<2xindex>) -> tensor { - %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor +func @broadcast_of_reshape(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg, %shape) + : (tensor, tensor<2xindex>) -> tensor %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor, tensor<2xindex>) -> tensor @@ -1551,3 +1553,16 @@ func @broadcast_of_reshape(%arg: tensor, %shape: tensor<2xindex>) -> tens } // CHECK: [[RESHAPE:%.*]] = "mhlo.dynamic_reshape" // CHECK: return [[RESHAPE]] + +// CHECK-LABEL: @permutation_broadcast_of_reshape +func @permutation_broadcast_of_reshape(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg, %shape) + : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { + broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> + } : (tensor, tensor<2xindex>) -> tensor + return %1 : tensor +} +// CHECK: mhlo.dynamic_reshape +// CHECK: mhlo.dynamic_broadcast_in_dim