Add canonicalization patterns for dynamic_broadcast_in_dim where the target shape is the shape of the operand.

PiperOrigin-RevId: 321312182
This commit is contained in:
Stephan Herhut 2020-07-15 07:37:58 +00:00 committed by Mehdi Amini
parent 86f290896d
commit 7a6adc6a84
4 changed files with 46 additions and 4 deletions

View File

@ -21,9 +21,9 @@ limitations under the License.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "mhlo";

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
@ -59,6 +60,7 @@ limitations under the License.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo {
@ -744,7 +746,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicBroadcastToOwnShape>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Canonicalization patterns for the MHLO dialect.
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
// Canonicalization patterns.
def DynamicBroadcastToOwnShape : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0,
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;

View File

@ -365,6 +365,16 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape
func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%0 = shape.shape_of %arg0 : tensor<?xf32>
%1 = shape.to_extent_tensor %0 : tensor<1xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: return %[[ARG]] : tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>