Sink standard dialect constants in sink_constants_to_control_flow pass
This is required before exporting HLO dialect ops with standard dialect constant to XLA. Also, sink constants for sort op as well. Added a TODO to generalize this pass to handle more ops and non-const values defined outside. PiperOrigin-RevId: 324301911
This commit is contained in:
		
							parent
							
								
									1c535f1718
								
							
						
					
					
						commit
						577a81a66d
					
				|  | @ -21,6 +21,7 @@ limitations under the License. | ||||||
| #include "mlir/Pass/PassManager.h" | #include "mlir/Pass/PassManager.h" | ||||||
| #include "mlir/Support/LLVM.h" | #include "mlir/Support/LLVM.h" | ||||||
| #include "mlir/Transforms/RegionUtils.h" | #include "mlir/Transforms/RegionUtils.h" | ||||||
|  | #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" | ||||||
| 
 | 
 | ||||||
| namespace mlir { | namespace mlir { | ||||||
| namespace mhlo { | namespace mhlo { | ||||||
|  | @ -29,6 +30,13 @@ namespace { | ||||||
| 
 | 
 | ||||||
| // A pass that sinks constants implicitly captured in control flow regions. This
 | // A pass that sinks constants implicitly captured in control flow regions. This
 | ||||||
| // is necessary to export to XLA.
 | // is necessary to export to XLA.
 | ||||||
|  | // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any
 | ||||||
|  | // value used within the region that is defined outside of op's region should be
 | ||||||
|  | // sank to the regions and not just the constants. Ops such as If and While
 | ||||||
|  | // whose computations doesn't require fixed signature like Sort or Reduce have
 | ||||||
|  | // an option to pass outside values as operands of the op to avoid recomputing
 | ||||||
|  | // those within internally. Note that doing so is the only option in case of
 | ||||||
|  | // BlockArguments.
 | ||||||
| class SinkConstantsToControlFlowPass | class SinkConstantsToControlFlowPass | ||||||
|     : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> { |     : public mlir::PassWrapper<SinkConstantsToControlFlowPass, FunctionPass> { | ||||||
|   void runOnFunction() override { |   void runOnFunction() override { | ||||||
|  | @ -39,6 +47,8 @@ class SinkConstantsToControlFlowPass | ||||||
|       } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) { |       } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) { | ||||||
|         SinkToRegion(&if_op.true_branch()); |         SinkToRegion(&if_op.true_branch()); | ||||||
|         SinkToRegion(&if_op.false_branch()); |         SinkToRegion(&if_op.false_branch()); | ||||||
|  |       } else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) { | ||||||
|  |         SinkToRegion(&sort_op.comparator()); | ||||||
|       } |       } | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
|  | @ -46,26 +56,26 @@ class SinkConstantsToControlFlowPass | ||||||
|  private: |  private: | ||||||
|   // Performs constant sinking into a region.
 |   // Performs constant sinking into a region.
 | ||||||
|   static void SinkToRegion(Region* region) { |   static void SinkToRegion(Region* region) { | ||||||
|     llvm::DenseMap<Value, ConstOp> sunk_constant; |     llvm::DenseMap<Value, Operation*> sunk_constant; | ||||||
|     visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { |     visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { | ||||||
|       Value constant = use->get(); |       Value constant = use->get(); | ||||||
|       auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp()); |       auto op = constant.getDefiningOp(); | ||||||
|       if (!const_op) return; |       if (!op || !isa<ConstOp, ConstantOp>(op)) return; | ||||||
|       auto map_entry = sunk_constant.try_emplace(constant, nullptr); |       auto map_entry = sunk_constant.try_emplace(constant, nullptr); | ||||||
|       if (!map_entry.second) { |       if (!map_entry.second) { | ||||||
|         // This constant has already been cloned into the region, reuse it.
 |         // This constant has already been cloned into the region, reuse it.
 | ||||||
|         use->set(map_entry.first->getSecond().getResult()); |         use->set(map_entry.first->getSecond()->getResult(0)); | ||||||
|         if (constant.use_empty()) const_op.erase(); |         if (op->use_empty()) op->erase(); | ||||||
|         return; |         return; | ||||||
|       } |       } | ||||||
|       if (constant.hasOneUse()) { |       if (constant.hasOneUse()) { | ||||||
|         const_op.getOperation()->moveBefore(®ion->front().front()); |         op->moveBefore(®ion->front().front()); | ||||||
|         return; |         return; | ||||||
|       } |       } | ||||||
|       map_entry.first->getSecond() = const_op.clone(); |       map_entry.first->getSecond() = op->clone(); | ||||||
|       region->front().getOperations().insert(region->front().begin(), |       region->front().getOperations().insert(region->front().begin(), | ||||||
|                                              map_entry.first->getSecond()); |                                              map_entry.first->getSecond()); | ||||||
|       use->set(map_entry.first->getSecond().getResult()); |       use->set(map_entry.first->getSecond()->getResult(0)); | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | @ -58,3 +58,17 @@ func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> { | ||||||
|   %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64> |   %9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64> | ||||||
|   return %9 : tensor<i64> |   return %9 : tensor<i64> | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func @sink_const_to_sort(%arg0: tensor<16xf32>) { | ||||||
|  |   %c0 = constant dense<1.0> : tensor<f32> | ||||||
|  |   // CHECK: "mhlo.sort" | ||||||
|  |   %0 = "mhlo.sort"(%arg0) ( { | ||||||
|  |   ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): | ||||||
|  |     // CHECK: constant dense<1.000000e+00> | ||||||
|  |     %1 = "mhlo.divide"(%arg1, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  |     %2 = "mhlo.divide"(%arg2, %c0) : (tensor<f32>, tensor<f32>) -> tensor<f32> | ||||||
|  |     %3 = "mhlo.compare"(%1, %2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> | ||||||
|  |     "mhlo.return"(%3) : (tensor<i1>) -> () | ||||||
|  |   }) {is_stable = true} : (tensor<16xf32>) -> tensor<16xi32> | ||||||
|  |   return | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue