[mhlo] Add legalize to SCF pass

Start of a pass to legalize MHLO control flow to SCF, putting it into a common form for further optimization. The current version matches only a very simple instance (which also happens to occur a few times in practice). It exposes some further canonicalization opportunities that aren't yet addressed.
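
For illustration, a minimal sketch of the rewrite this performs (distilled from the test case below; value names are illustrative):

  // Before: mhlo.while over a tuple whose cond is an LT compare of the
  // induction variable against a loop-invariant bound, and whose body
  // increments it by a constant step.
  %0 = "mhlo.tuple"(%init, %val, %limit) : (...) -> tuple<...>
  %1 = "mhlo.while"(%0) ({ /* mhlo.compare LT */ }, { /* mhlo.add step */ }) : ...

  // After: scf.for over index-typed bounds; the remaining tuple elements
  // become iter_args and the results are re-packed into a tuple.
  scf.for %iv = %lb to %ub step %step iter_args(%a = %val, %b = %limit) { ... }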

PiperOrigin-RevId: 329017723
Jacques Pienaar 2020-08-28 15:10:56 -07:00 committed by TensorFlow MLIR Team
parent 7176fb1839
commit 344c500fca
5 changed files with 246 additions and 0 deletions

@@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
let constructor = "createLegalizeControlFlowPass()"; let constructor = "createLegalizeControlFlowPass()";
} }
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
let summary = "Legalize from MHLO control flow to SCF control flow.";
let constructor = "createControlFlowToScfPass()";
}
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
let summary = "Legalizes gathers to a torch index select.";
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";

@@ -35,6 +35,9 @@ namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers MHLO control flow ops to the SCF dialect.
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

@@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion
add_mlir_library(MhloToStandard
legalize_control_flow.cc
legalize_to_standard.cc
mhlo_control_flow_to_scf.cc
DEPENDS
MLIRhlo_opsIncGen

@@ -0,0 +1,199 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/Casting.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#define DEBUG_TYPE "mhlo-control-flow-to-scf"
namespace mlir {
namespace mhlo {
namespace {
/// Convert MHLO While to SCF.
void MatchAndRewrite(WhileOp whileOp);
/// Pass that converts MHLO control flow to SCF.
class ControlFlowToScfPass
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<scf::SCFDialect>();
}
void runOnFunction() override {
getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
}
};
// TODO(jpienaar): Look into reformulating as a pattern.
void MatchAndRewrite(WhileOp whileOp) {
// Handle pattern:
// x = start
// step = ...
// limit = ...
// while (x < limit) { ... x += step; }
// Only handling multi-value while loops at the moment.
auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
if (!tupleOp) return;
auto bodyReturn = whileOp.body()
.front()
.getTerminator()
->getOperand(0)
.getDefiningOp<mhlo::TupleOp>();
// Note: due to the shape restrictions on While, if the operand to While is a
// tuple, then so is the return type of the body. But the verifier isn't
// checking that at the moment, so just bail out here if this doesn't hold.
if (!bodyReturn) return;
Value result = whileOp.cond().front().getTerminator()->getOperand(0);
// TODO(jpienaar): Expand to handle more than simple case with LT compare and
// constant step.
auto cmp = result.getDefiningOp<mhlo::CompareOp>();
if (!cmp || cmp.comparison_direction() != "LT") return;
const int kConstant = -1;
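// Map a value used in the cond/body back to the while's operand tuple:
// constants yield {value, kConstant}; a get_tuple_element of the tuple
// block argument yields the corresponding external operand and its
// index; anything else yields {nullptr, 0} and aborts the match.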
auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
if (matchPattern(val, m_Constant())) return {val, kConstant};
// If it is defined by a tuple, then the tuple has to have been fed in and
// the external value is captured.
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
int index = gte.index().getSExtValue();
return {tupleOp.getOperand(index), index};
}
return {nullptr, 0};
};
using ValueIndex = std::pair<Value, int>;
ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
ValueIndex max = getValueAndIndex(cmp.rhs());
if (!loopIndVar.first || !max.first) return;
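// The body must update the induction variable via an mhlo.add whose
// rhs resolves to a constant step.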
auto add =
bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
if (!add) return;
ValueIndex step = getValueAndIndex(add.rhs());
if (step.second != kConstant || !step.first) return;
// Only handle case where tuple isn't propagated as is for now.
// TODO(jpienaar): Remove this when a tuple is also created inside the loop
// to propagate.
for (auto* use : whileOp.body().front().getArgument(0).getUsers())
if (!isa<GetTupleElementOp>(use)) return;
LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
<< max.second << " step = " << step.second << "\n";
llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
<< max.first << " step = " << step.first << "\n";);
OpBuilder b(whileOp);
// Inputs to new for loop.
llvm::SmallVector<Value, 4> input;
input.reserve(tupleOp.getNumOperands());
for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
input.push_back(r);
for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
input.push_back(r);
auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
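// Convert a 0-d integer tensor to a scalar index value: index_cast to
// a 0-d index tensor, then extract its single element.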
auto getAsIndex = [&](Value val) {
auto loc = whileOp.getLoc();
return b.create<ExtractElementOp>(
loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
};
// scf.for uses the index type for its bounds, so convert these values.
auto forloopIndVar = getAsIndex(loopIndVar.first);
auto forMax = getAsIndex(max.first);
auto forStep = getAsIndex(step.first);
auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
forMax, forStep, input);
// Transfer the body without the block arguments.
forOp.getLoopBody().front().getOperations().splice(
forOp.getLoopBody().front().getOperations().end(),
whileOp.body().front().getOperations());
b.setInsertionPointToStart(&forOp.getLoopBody().front());
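// The spliced body ops expect the induction variable as a 0-d tensor of
// its original element type, while scf.for provides a scalar index, so
// cast it back and splat it into a 0-d tensor.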
auto loopIndVarElType =
loopIndVar.first.getType().cast<ShapedType>().getElementType();
Value indVar = b.create<SplatOp>(
whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
forOp.getInductionVar()));
// Update all block argument users to the SCF For args.
for (auto* use :
llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
// TODO(jpienaar): Expand here too when we allow using the tuple in the
// loop.
auto gte = cast<GetTupleElementOp>(use);
// If the loop induction var, then refer to the loop induction variable as
// this operand is not updated.
if (gte.index() == loopIndVar.second) {
use->getResult(0).replaceAllUsesWith(indVar);
use->erase();
continue;
}
int index = gte.index().getSExtValue();
// If after the loop induction variable, then decrement as we don't include
// the loop induction variable in the for iter operands.
if (index > loopIndVar.second) --index;
use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
use->erase();
}
// Create new yield op without induction var update.
SmallVector<Value, 4> newYieldOps;
newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
newYieldOps.push_back(r);
for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
newYieldOps.push_back(r);
// Delete return & tuple op.
forOp.getLoopBody().front().back().erase();
forOp.getLoopBody().front().back().erase();
b.setInsertionPointToEnd(&forOp.getLoopBody().front());
b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
// Recombine output tuple with max value of induction variable.
llvm::SmallVector<Value, 4> loopOut;
loopOut.reserve(forOp.getNumResults() + 1);
for (auto r : forOp.getResults().take_front(loopIndVar.second))
loopOut.push_back(r);
loopOut.push_back(max.first);
for (auto r : forOp.getResults().drop_front(loopIndVar.second))
loopOut.push_back(r);
b.setInsertionPoint(whileOp);
auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
whileOp.replaceAllUsesWith(newRes.getOperation());
whileOp.erase();
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
return std::make_unique<ControlFlowToScfPass>();
}
} // namespace mhlo
} // namespace mlir

@@ -0,0 +1,38 @@
// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
%cst = constant dense<-1> : tensor<i32>
%cst_0 = constant dense<1> : tensor<i32>
%cst_1 = constant dense<0> : tensor<i32>
%cst_2 = constant dense<1000> : tensor<i32>
%0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
%1 = "mhlo.while"(%0) ( {
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"mhlo.return"(%4) : (tensor<i1>) -> ()
}, {
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%3 = mhlo.add %2, %cst_0 : tensor<i32>
%4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
%6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
"mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
}) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
}
// CHECK-LABEL: func @lt_loop(
// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])