diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 4ce8ba5..653d742 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -21,9 +21,9 @@ limitations under the License. include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" -include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" -include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" def HLO_Dialect : Dialect { let name = "mhlo"; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 16c19cb..3715c41 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -35,6 +35,7 @@ limitations under the License. #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" @@ -59,6 +60,7 @@ limitations under the License. #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" namespace mlir { +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc" #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" namespace mhlo { @@ -744,7 +746,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic void DynamicBroadcastInDimOp::getCanonicalizationPatterns( OwningRewritePatternList& results, MLIRContext* context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/mhlo/IR/hlo_patterns.td b/lib/Dialect/mhlo/IR/hlo_patterns.td new file mode 100644 index 0000000..f3e0181 --- /dev/null +++ b/lib/Dialect/mhlo/IR/hlo_patterns.td @@ -0,0 +1,29 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Canonicalization patterns for the MHLO dialect. + +include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" + +def EqualBinaryOperands : Constraint>; + +// Canonicalization patterns. + +def DynamicBroadcastToOwnShape : Pat< + (HLO_DynamicBroadcastInDimOp:$op $arg0, + (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr), + (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; + diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 8777412..f773c95 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -365,6 +365,16 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar return %0 : tensor<5x4xf32> } +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape +func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor) -> tensor { +// CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor + %1 = shape.to_extent_tensor %0 : tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + return %2 : tensor +} + // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { %cst = mhlo.constant dense<0.000000e+00> : tensor