diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
index fa3bde2..aa0f4c3 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/mhlo_passes.td
@@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
   let constructor = "createLegalizeControlFlowPass()";
 }
 
+def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
+  let summary = "Legalize from MHLO control flow to SCF control flow.";
+  let constructor = "createControlFlowToScfPass()";
+}
+
 def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
   let summary = "Legalizes gathers to a torch index select.";
   let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
index efa116f..541d8e4 100644
--- a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
+++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h
@@ -35,6 +35,9 @@ namespace mhlo {
 /// Lowers HLO control flow ops to the Standard dialect.
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
 
+/// Lowers MHLO control flow ops to the SCF dialect.
+std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
+
 /// Lowers from HLO dialect to Standard dialect.
std::unique_ptr> createLegalizeToStdPass(); diff --git a/lib/Dialect/mhlo/transforms/CMakeLists.txt b/lib/Dialect/mhlo/transforms/CMakeLists.txt index bb9f98d..945fa0e 100644 --- a/lib/Dialect/mhlo/transforms/CMakeLists.txt +++ b/lib/Dialect/mhlo/transforms/CMakeLists.txt @@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion add_mlir_library(MhloToStandard legalize_control_flow.cc legalize_to_standard.cc + mhlo_control_flow_to_scf.cc DEPENDS MLIRhlo_opsIncGen diff --git a/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc new file mode 100644 index 0000000..aba7b07 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/mhlo_control_flow_to_scf.cc @@ -0,0 +1,199 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" + +#define DEBUG_TYPE "mhlo-control-flow-to-scf" + +namespace mlir { +namespace mhlo { + +namespace { + +/// Convert MHLO While to SCF. 
+void MatchAndRewrite(WhileOp whileOp);
+
+/// Pass that converts MHLO control flow to SCF.
+class ControlFlowToScfPass
+    : public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<scf::SCFDialect>();
+  }
+  void runOnFunction() override {
+    getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
+  }
+};
+
+// TODO(jpienaar): Look into reformulating as a pattern.
+void MatchAndRewrite(WhileOp whileOp) {
+  // Handle pattern:
+  //    x = start
+  //    step = ...
+  //    limit = ...
+  //    while (x < limit) { ... x += step; }
+
+  // Only handling multi value while loops at the moment.
+  auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
+  if (!tupleOp) return;
+  auto bodyReturn = whileOp.body()
+                        .front()
+                        .getTerminator()
+                        ->getOperand(0)
+                        .getDefiningOp<TupleOp>();
+  // Note: due to the shape restrictions on While, if the operand to While is a
+  // tuple, then so is the return type of the body. But the verifier isn't
+  // checking that at the moment, so just bail out here if this doesn't hold.
+  if (!bodyReturn) return;
+
+  Value result = whileOp.cond().front().getTerminator()->getOperand(0);
+  // TODO(jpienaar): Expand to handle more than simple case with LT compare and
+  // constant step.
+  auto cmp = result.getDefiningOp<CompareOp>();
+  if (!cmp || cmp.comparison_direction() != "LT") return;
+
+  const int kConstant = -1;
+  auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
+    if (matchPattern(val, m_Constant())) return {val, kConstant};
+    // If it is defined by a tuple, then the tuple has to have been fed in and
+    // the external value is captured.
+    if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
+      if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
+      int index = gte.index().getSExtValue();
+      return {tupleOp.getOperand(index), index};
+    }
+    return {nullptr, 0};
+  };
+
+  using ValueIndex = std::pair<Value, int>;
+  ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
+  ValueIndex max = getValueAndIndex(cmp.rhs());
+  if (!loopIndVar.first || !max.first) return;
+  auto add =
+      bodyReturn.getOperand(loopIndVar.second).getDefiningOp<AddOp>();
+  if (!add) return;
+  ValueIndex step = getValueAndIndex(add.rhs());
+  if (step.second != kConstant || !step.first) return;
+
+  // Only handle case where tuple isn't propagated as is for now.
+  // TODO(jpienaar): Remove this when a tuple is also created inside the loop
+  // to propagate.
+  for (auto* use : whileOp.body().front().getArgument(0).getUsers())
+    if (!isa<GetTupleElementOp>(use)) return;
+
+  LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
+             llvm::dbgs() << "  loopIndVar = " << loopIndVar.second << " max = "
+                          << max.second << " step = " << step.second << "\n";
+             llvm::dbgs() << "  loopIndVar = " << loopIndVar.first << " max = "
+                          << max.first << " step = " << step.first << "\n";);
+  OpBuilder b(whileOp);
+  // Inputs to new for loop.
+  llvm::SmallVector<Value, 4> input;
+  input.reserve(tupleOp.getNumOperands());
+  for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
+    input.push_back(r);
+  for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
+    input.push_back(r);
+
+  auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
+  auto getAsIndex = [&](Value val) {
+    auto loc = whileOp.getLoc();
+    return b.create<ExtractElementOp>(
+        loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
+  };
+
+  // SCF for uses index type, so converted these.
+  auto forloopIndVar = getAsIndex(loopIndVar.first);
+  auto forMax = getAsIndex(max.first);
+  auto forStep = getAsIndex(step.first);
+  auto forOp = b.create<scf::ForOp>(whileOp.getLoc(), forloopIndVar,
+                                    forMax, forStep, input);
+  // Transfer the body without the block arguments.
+  forOp.getLoopBody().front().getOperations().splice(
+      forOp.getLoopBody().front().getOperations().end(),
+      whileOp.body().front().getOperations());
+
+  b.setInsertionPointToStart(&forOp.getLoopBody().front());
+  auto loopIndVarElType =
+      loopIndVar.first.getType().cast<ShapedType>().getElementType();
+  Value indVar = b.create<SplatOp>(
+      whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
+      b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
+                            forOp.getInductionVar()));
+  // Update all block argument users to the SCF For args.
+  for (auto* use :
+       llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
+    // TODO(jpienaar): Expand here too when we allow using the tuple in the
+    // loop.
+    auto gte = cast<GetTupleElementOp>(use);
+    // If the loop induction var, then refer to the loop induction variable as
+    // this operand is not updated.
+    if (gte.index() == loopIndVar.second) {
+      use->getResult(0).replaceAllUsesWith(indVar);
+      use->erase();
+      continue;
+    }
+    int index = gte.index().getSExtValue();
+    // If after the loop induction variable, then decrement as we don't include
+    // the loop induction variable in the for iter operands.
+    if (index > loopIndVar.second) --index;
+    use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
+    use->erase();
+  }
+
+  // Create new yield op without induction var update.
+  SmallVector<Value, 4> newYieldOps;
+  newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
+  for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
+    newYieldOps.push_back(r);
+  for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
+    newYieldOps.push_back(r);
+  // Delete return & tuple op.
+  forOp.getLoopBody().front().back().erase();
+  forOp.getLoopBody().front().back().erase();
+  b.setInsertionPointToEnd(&forOp.getLoopBody().front());
+  b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
+
+  // Recombine output tuple with max value of induction variable.
+  llvm::SmallVector<Value, 4> loopOut;
+  loopOut.reserve(forOp.getNumResults() + 1);
+  for (auto r : forOp.getResults().take_front(loopIndVar.second))
+    loopOut.push_back(r);
+  loopOut.push_back(max.first);
+  for (auto r : forOp.getResults().drop_front(loopIndVar.second))
+    loopOut.push_back(r);
+  b.setInsertionPoint(whileOp);
+  auto newRes = b.create<TupleOp>(whileOp.getLoc(), loopOut);
+  whileOp.replaceAllUsesWith(newRes.getOperation());
+  whileOp.erase();
+}
+
+}  // anonymous namespace
+
+std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
+  return std::make_unique<ControlFlowToScfPass>();
+}
+
+}  // namespace mhlo
+}  // namespace mlir
diff --git a/tests/legalize_to_scf.mlir b/tests/legalize_to_scf.mlir
new file mode 100644
index 0000000..9c887a7
--- /dev/null
+++ b/tests/legalize_to_scf.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
+
+func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
+  %cst = constant dense<-1> : tensor<i32>
+  %cst_0 = constant dense<1> : tensor<i32>
+  %cst_1 = constant dense<0> : tensor<i32>
+  %cst_2 = constant dense<1000> : tensor<i32>
+  %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  %1 = "mhlo.while"(%0) ( {
+  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):  // no predecessors
+    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    "mhlo.return"(%4) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>):  // no predecessors
+    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %3 = mhlo.add %2, %cst_0 : tensor<i32>
+    %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
+    %6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+    "mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
+  }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+  return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
+}
+
+// CHECK-LABEL: func @lt_loop(
+// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
+// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
+// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
+// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
+// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
+// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
+// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
+// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
+// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
+// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
+// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])