Canonicalize MHLO Case and If Ops with constant conditions
ReplaceOpWithRegion was taken directly from ScfOps. We should maybe put that somewhere common in core. PiperOrigin-RevId: 365936724
This commit is contained in:
		
							parent
							
								
									2fb2a92c6e
								
							
						
					
					
						commit
						5d65758e8c
					
				| 
						 | 
					@ -549,6 +549,8 @@ def HLO_IfOp: HLO_Op<"if", [
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
 | 
					  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
 | 
				
			||||||
  let hasCustomHLOConverter = 1;
 | 
					  let hasCustomHLOConverter = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Xla Client API has two separate calls for indexed and predicated conditional,
 | 
					// Xla Client API has two separate calls for indexed and predicated conditional,
 | 
				
			||||||
| 
						 | 
					@ -569,6 +571,8 @@ def HLO_CaseOp: HLO_Op<"case", [
 | 
				
			||||||
  let results = (outs Variadic<HLO_TensorOrTuple>);
 | 
					  let results = (outs Variadic<HLO_TensorOrTuple>);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let hasCustomHLOConverter = 1;
 | 
					  let hasCustomHLOConverter = 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -45,6 +45,7 @@ limitations under the License.
 | 
				
			||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
					#include "mlir/Dialect/Tensor/IR/Tensor.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
#include "mlir/IR/Builders.h"
 | 
					#include "mlir/IR/Builders.h"
 | 
				
			||||||
 | 
					#include "mlir/IR/BuiltinAttributes.h"
 | 
				
			||||||
#include "mlir/IR/BuiltinTypes.h"
 | 
					#include "mlir/IR/BuiltinTypes.h"
 | 
				
			||||||
#include "mlir/IR/Dialect.h"
 | 
					#include "mlir/IR/Dialect.h"
 | 
				
			||||||
#include "mlir/IR/Location.h"
 | 
					#include "mlir/IR/Location.h"
 | 
				
			||||||
| 
						 | 
					@ -131,6 +132,19 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
 | 
				
			||||||
  return GetI64ElementsAttr(slice_limits, builder);
 | 
					  return GetI64ElementsAttr(slice_limits, builder);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Replaces the given op with the contents of the given single-block region,
 | 
				
			||||||
 | 
					/// using the operands of the block terminator to replace operation results.
 | 
				
			||||||
 | 
					static void ReplaceOpWithRegion(PatternRewriter& rewriter, Operation* op,
 | 
				
			||||||
 | 
					                                Region& region, ValueRange blockArgs = {}) {
 | 
				
			||||||
 | 
					  assert(llvm::hasSingleElement(region) && "expected single-block region");
 | 
				
			||||||
 | 
					  Block* block = ®ion.front();
 | 
				
			||||||
 | 
					  Operation* terminator = block->getTerminator();
 | 
				
			||||||
 | 
					  ValueRange results = terminator->getOperands();
 | 
				
			||||||
 | 
					  rewriter.mergeBlockBefore(block, op, blockArgs);
 | 
				
			||||||
 | 
					  rewriter.replaceOp(op, results);
 | 
				
			||||||
 | 
					  rewriter.eraseOp(terminator);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mhlo_canonicalize.inc"
 | 
					#include "mhlo_canonicalize.inc"
 | 
				
			||||||
}  // namespace
 | 
					}  // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2129,6 +2143,24 @@ static LogicalResult Verify(IfOp op) {
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static LogicalResult InlineIfConstantCondition(IfOp ifOp,
 | 
				
			||||||
 | 
					                                               PatternRewriter& rewriter) {
 | 
				
			||||||
 | 
					  DenseIntElementsAttr pred_attr;
 | 
				
			||||||
 | 
					  if (!matchPattern(ifOp.pred(), m_Constant(&pred_attr))) return failure();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (pred_attr.getSplatValue<BoolAttr>().getValue()) {
 | 
				
			||||||
 | 
					    ReplaceOpWithRegion(rewriter, ifOp, ifOp.true_branch(), ifOp.true_arg());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    ReplaceOpWithRegion(rewriter, ifOp, ifOp.false_branch(), ifOp.false_arg());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void IfOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 | 
				
			||||||
 | 
					                                       MLIRContext* context) {
 | 
				
			||||||
 | 
					  results.add(&InlineIfConstantCondition);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// Case Op
 | 
					// Case Op
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -2150,6 +2182,31 @@ static LogicalResult Verify(CaseOp op) {
 | 
				
			||||||
  return success();
 | 
					  return success();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static LogicalResult InlineCaseConstantCondition(CaseOp caseOp,
 | 
				
			||||||
 | 
					                                                 PatternRewriter& rewriter) {
 | 
				
			||||||
 | 
					  DenseIntElementsAttr index_attr;
 | 
				
			||||||
 | 
					  if (!matchPattern(caseOp.index(), m_Constant(&index_attr))) {
 | 
				
			||||||
 | 
					    return failure();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  int64_t index =
 | 
				
			||||||
 | 
					      index_attr.getSplatValue<IntegerAttr>().getValue().getSExtValue();
 | 
				
			||||||
 | 
					  // For an OOB index, the last branch is executed as the default branch:
 | 
				
			||||||
 | 
					  // https://www.tensorflow.org/xla/operation_semantics#conditional
 | 
				
			||||||
 | 
					  if (index < 0 || index >= caseOp.getNumRegions())
 | 
				
			||||||
 | 
					    index = caseOp.getNumRegions() - 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Region& region = caseOp.getRegion(index);
 | 
				
			||||||
 | 
					  if (!llvm::hasSingleElement(region)) return failure();
 | 
				
			||||||
 | 
					  ReplaceOpWithRegion(rewriter, caseOp, region,
 | 
				
			||||||
 | 
					                      caseOp.branch_operands()[index]);
 | 
				
			||||||
 | 
					  return success();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
 | 
				
			||||||
 | 
					                                         MLIRContext* context) {
 | 
				
			||||||
 | 
					  results.add(&InlineCaseConstantCondition);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// SqrtOp
 | 
					// SqrtOp
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1280,6 +1280,108 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
 | 
				
			||||||
  return %1 : tensor<4xf32>
 | 
					  return %1 : tensor<4xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_if_true(
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME: )
 | 
				
			||||||
 | 
					func @fold_if_true(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  // CHECK-NOT: mhlo.if
 | 
				
			||||||
 | 
					  // CHECK: return %[[ARG0]]
 | 
				
			||||||
 | 
					  %true = mhlo.constant dense<true> : tensor<i1>
 | 
				
			||||||
 | 
					  %0 = "mhlo.if"(%true, %arg0, %arg1) ( {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg0: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg1: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_if_false(
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME: )
 | 
				
			||||||
 | 
					func @fold_if_false(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  // CHECK-NOT: mhlo.if
 | 
				
			||||||
 | 
					  // CHECK: return %[[ARG1]]
 | 
				
			||||||
 | 
					  %false = mhlo.constant dense<false> : tensor<i1>
 | 
				
			||||||
 | 
					  %0 = "mhlo.if"(%false, %arg0, %arg1) ( {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg0: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg1: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_case(
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME: )
 | 
				
			||||||
 | 
					func @fold_case(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  // CHECK-NOT: mhlo.case
 | 
				
			||||||
 | 
					  // CHECK: return %[[ARG1]]
 | 
				
			||||||
 | 
					  %c1 = mhlo.constant dense<1> : tensor<i32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.case"(%c1, %arg0, %arg1, %arg2) ( {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg0: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					    },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg1: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg2: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_case_negative_index(
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME: )
 | 
				
			||||||
 | 
					func @fold_case_negative_index(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  // CHECK-NOT: mhlo.case
 | 
				
			||||||
 | 
					  // CHECK: return %[[ARG2]]
 | 
				
			||||||
 | 
					  %m1000 = mhlo.constant dense<-1000> : tensor<i32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.case"(%m1000, %arg0, %arg1, %arg2) ( {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg0: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					    },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg1: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg2: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: func @fold_case_oob_index(
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]
 | 
				
			||||||
 | 
					//  CHECK-SAME: )
 | 
				
			||||||
 | 
					func @fold_case_oob_index(%arg0 : tensor<f32>, %arg1 : tensor<f32>, %arg2 : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
 | 
					  // CHECK-NOT: mhlo.case
 | 
				
			||||||
 | 
					  // CHECK: return %[[ARG2]]
 | 
				
			||||||
 | 
					  %c1000 = mhlo.constant dense<1000> : tensor<i32>
 | 
				
			||||||
 | 
					  %0 = "mhlo.case"(%c1000, %arg0, %arg1, %arg2) ( {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg0: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg0) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					    },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg1: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg1) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  },  {
 | 
				
			||||||
 | 
					    ^bb0(%bbarg2: tensor<f32>):
 | 
				
			||||||
 | 
					      "mhlo.return"(%bbarg2) : (tensor<f32>) -> ()
 | 
				
			||||||
 | 
					  }) : (tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  return %0 : tensor<f32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CHECK-LABEL: @tensor_flow_scatter_v1_update
 | 
					// CHECK-LABEL: @tensor_flow_scatter_v1_update
 | 
				
			||||||
func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
 | 
					func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
 | 
				
			||||||
  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
 | 
					  %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue