Add canonicalization patterns for dynamic_broadcast_in_dim where the target shape is the shape of the operand.
PiperOrigin-RevId: 321312182
This commit is contained in:
		
							parent
							
								
									86f290896d
								
							
						
					
					
						commit
						7a6adc6a84
					
				|  | @ -21,9 +21,9 @@ limitations under the License. | |||
| include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" | ||||
| include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" | ||||
| include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" | ||||
| include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" | ||||
| include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" | ||||
| include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" | ||||
| include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" | ||||
| include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" | ||||
| include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" | ||||
| 
 | ||||
| def HLO_Dialect : Dialect { | ||||
|   let name = "mhlo"; | ||||
|  |  | |||
|  | @ -35,6 +35,7 @@ limitations under the License. | |||
| #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" | ||||
| #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" | ||||
| #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" | ||||
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" | ||||
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" | ||||
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" | ||||
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" | ||||
|  | @ -59,6 +60,7 @@ limitations under the License. | |||
| #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" | ||||
| 
 | ||||
| namespace mlir { | ||||
| #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc" | ||||
| #include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" | ||||
| namespace mhlo { | ||||
| 
 | ||||
|  | @ -744,7 +746,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic | |||
| 
 | ||||
| void DynamicBroadcastInDimOp::getCanonicalizationPatterns( | ||||
|     OwningRewritePatternList& results, MLIRContext* context) { | ||||
|   results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context); | ||||
|   results.insert<DynamicBroadcastInDimOpNotActuallyDynamic, | ||||
|                  DynamicBroadcastToOwnShape>(context); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
|  |  | |||
|  | @ -0,0 +1,29 @@ | |||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // Canonicalization patterns for the MHLO dialect. | ||||
| 
 | ||||
| include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td" | ||||
| include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" | ||||
| 
 | ||||
| def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>; | ||||
| 
 | ||||
| // Canonicalization patterns. | ||||
| 
 | ||||
| def DynamicBroadcastToOwnShape : Pat< | ||||
|   (HLO_DynamicBroadcastInDimOp:$op $arg0, | ||||
|       (Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr), | ||||
|   (replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>; | ||||
| 
 | ||||
|  | @ -365,6 +365,16 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar | |||
|   return %0 : tensor<5x4xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape | ||||
| func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor<?xf32>) -> tensor<?xf32> { | ||||
| // CHECK-SAME: %[[ARG:.*]]: tensor<?xf32> | ||||
|   %0 = shape.shape_of %arg0 : tensor<?xf32> | ||||
|   %1 = shape.to_extent_tensor %0 : tensor<1xindex> | ||||
|   %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32> | ||||
|   // CHECK: return %[[ARG]] : tensor<?xf32> | ||||
|   return %2 : tensor<?xf32> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d | ||||
| func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { | ||||
|   %cst = mhlo.constant dense<0.000000e+00> : tensor<f32> | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue