[mhlo] Add legalize to SCF pass
Start of a pass to legalize MHLO control flow to SCF, putting loops into a common form for further optimization. The current version matches only a very simple instance (which also happens to occur a few times in practice). It exposes some further canonicalization opportunities that aren't yet addressed. PiperOrigin-RevId: 329017723
This commit is contained in:
parent
7176fb1839
commit
344c500fca
|
@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
|||
let constructor = "createLegalizeControlFlowPass()";
|
||||
}
|
||||
|
||||
// Converts supported MHLO control-flow ops into the SCF dialect so that
// common loop optimizations can run on them. Constructor is declared in
// the mhlo passes header.
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
  let summary = "Legalize from MHLO control flow to SCF control flow.";
  let constructor = "createControlFlowToScfPass()";
}
|
||||
|
||||
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
|
||||
let summary = "Legalizes gathers to a torch index select.";
|
||||
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
||||
|
|
|
@ -35,6 +35,9 @@ namespace mhlo {
|
|||
/// Lowers HLO control flow ops to the Standard dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
|
||||
|
||||
/// Lowers MHLO control flow ops to the SCF dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
||||
|
||||
/// Lowers from HLO dialect to Standard dialect.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||
|
||||
|
|
|
@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion
|
|||
add_mlir_library(MhloToStandard
|
||||
legalize_control_flow.cc
|
||||
legalize_to_standard.cc
|
||||
mhlo_control_flow_to_scf.cc
|
||||
|
||||
DEPENDS
|
||||
MLIRhlo_opsIncGen
|
||||
|
|
|
@ -0,0 +1,199 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#define DEBUG_TYPE "mhlo-control-flow-to-scf"
|
||||
|
||||
namespace mlir {
|
||||
namespace mhlo {
|
||||
|
||||
namespace {
|
||||
|
||||
/// Convert MHLO While to SCF.
|
||||
void MatchAndRewrite(WhileOp whileOp);
|
||||
|
||||
/// Pass that converts MHLO control flow to SCF.
|
||||
class ControlFlowToScfPass
|
||||
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<scf::SCFDialect>();
|
||||
}
|
||||
void runOnFunction() override {
|
||||
getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(jpienaar): Look into reformulating as a pattern.
|
||||
/// Rewrites a `mhlo.while` implementing a simple counted loop over a tuple of
/// loop-carried values (`while (x < limit) { ...; x += step; }`) into an
/// `scf.for`. Bails out (leaving the op unchanged) whenever the loop does not
/// match the supported pattern.
void MatchAndRewrite(WhileOp whileOp) {
  // Handle pattern:
  //   x = start
  //   step = ...
  //   limit = ...
  //   while (x < limit) { ... x += step; }

  // Only handling multi value while loops at the moment.
  auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
  if (!tupleOp) return;
  auto bodyReturn = whileOp.body()
                        .front()
                        .getTerminator()
                        ->getOperand(0)
                        .getDefiningOp<mhlo::TupleOp>();
  // Note: due to the shape restrictions on While, if the operand to While is a
  // tuple, then so is the return type of the body. But the verifier isn't
  // checking that at the moment, so just bail out here if this doesn't hold.
  if (!bodyReturn) return;

  Value result = whileOp.cond().front().getTerminator()->getOperand(0);
  // TODO(jpienaar): Expand to handle more than simple case with LT compare and
  // constant step.
  auto cmp = result.getDefiningOp<mhlo::CompareOp>();
  if (!cmp || cmp.comparison_direction() != "LT") return;

  const int kConstant = -1;
  // Maps a value used inside the loop to {external value, tuple index}:
  // a constant maps to itself with index kConstant; a get_tuple_element of the
  // loop-carried tuple argument maps to the corresponding While input;
  // anything else is unsupported and yields {nullptr, 0}.
  auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
    if (matchPattern(val, m_Constant())) return {val, kConstant};
    // If it is defined by a tuple, then the tuple has to have been fed in and
    // the external value is captured.
    if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
      if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
      int index = gte.index().getSExtValue();
      return {tupleOp.getOperand(index), index};
    }
    return {nullptr, 0};
  };

  using ValueIndex = std::pair<Value, int>;
  ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
  ValueIndex max = getValueAndIndex(cmp.rhs());
  if (!loopIndVar.first || !max.first) return;
  // The induction variable must be produced by an add whose rhs is a constant
  // step.
  auto add =
      bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
  if (!add) return;
  ValueIndex step = getValueAndIndex(add.rhs());
  if (step.second != kConstant || !step.first) return;

  // Only handle case where tuple isn't propagated as is for now.
  // TODO(jpienaar): Remove this when a tuple is also created inside the loop
  // to propagate.
  for (auto* use : whileOp.body().front().getArgument(0).getUsers())
    if (!isa<GetTupleElementOp>(use)) return;

  LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
             llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
                          << max.second << " step = " << step.second << "\n";
             llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
                          << max.first << " step = " << step.first << "\n";);
  OpBuilder b(whileOp);
  // Inputs to new for loop: all While inputs except the induction variable
  // itself (it becomes the scf.for induction variable).
  llvm::SmallVector<Value, 4> input;
  input.reserve(tupleOp.getNumOperands());
  for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
    input.push_back(r);
  for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
    input.push_back(r);

  auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
  // Converts a rank-0 tensor value into a scalar of index type via
  // index_cast + extract_element.
  auto getAsIndex = [&](Value val) {
    auto loc = whileOp.getLoc();
    return b.create<ExtractElementOp>(
        loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
  };

  // SCF for uses index type, so converted these.
  auto forloopIndVar = getAsIndex(loopIndVar.first);
  auto forMax = getAsIndex(max.first);
  auto forStep = getAsIndex(step.first);
  auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
                                          forMax, forStep, input);
  // Transfer the body without the block arguments.
  forOp.getLoopBody().front().getOperations().splice(
      forOp.getLoopBody().front().getOperations().end(),
      whileOp.body().front().getOperations());

  b.setInsertionPointToStart(&forOp.getLoopBody().front());
  auto loopIndVarElType =
      loopIndVar.first.getType().cast<ShapedType>().getElementType();
  // Rebuild the tensor-typed induction variable inside the loop from the
  // scf.for's index-typed induction variable (index_cast + splat to rank-0).
  Value indVar = b.create<SplatOp>(
      whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
      b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
                            forOp.getInductionVar()));
  // Update all block argument users to the SCF For args.
  for (auto* use :
       llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
    // TODO(jpienaar): Expand here too when we allow using the tuple in the
    // loop.
    auto gte = cast<GetTupleElementOp>(use);
    // If the loop induction var, then refer to the loop induction variable as
    // this operand is not updated.
    if (gte.index() == loopIndVar.second) {
      use->getResult(0).replaceAllUsesWith(indVar);
      use->erase();
      continue;
    }
    int index = gte.index().getSExtValue();
    // If after the loop induction variable, then decrement as we don't include
    // the loop induction variable in the for iter operands.
    if (index > loopIndVar.second) --index;
    use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
    use->erase();
  }

  // Create new yield op without induction var update.
  SmallVector<Value, 4> newYieldOps;
  newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
  for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
    newYieldOps.push_back(r);
  for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
    newYieldOps.push_back(r);
  // Delete return & tuple op (the two trailing ops spliced in from the old
  // body; erased in that order since the return uses the tuple).
  forOp.getLoopBody().front().back().erase();
  forOp.getLoopBody().front().back().erase();
  b.setInsertionPointToEnd(&forOp.getLoopBody().front());
  b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);

  // Recombine output tuple with max value of induction variable.
  llvm::SmallVector<Value, 4> loopOut;
  loopOut.reserve(forOp.getNumResults() + 1);
  for (auto r : forOp.getResults().take_front(loopIndVar.second))
    loopOut.push_back(r);
  loopOut.push_back(max.first);
  for (auto r : forOp.getResults().drop_front(loopIndVar.second))
    loopOut.push_back(r);
  b.setInsertionPoint(whileOp);
  auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
  whileOp.replaceAllUsesWith(newRes.getOperation());
  whileOp.erase();
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
/// Factory for the MHLO control-flow-to-SCF conversion pass.
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<ControlFlowToScfPass>();
  return pass;
}
|
||||
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
|
@ -0,0 +1,38 @@
|
|||
// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
|
||||
|
||||
// Counted loop over a 3-tuple (i, untouched value, limit): condition is
// i < limit with limit = 1000, body does i += 1 and passes the other two
// elements through. Expected to be rewritten into an scf.for by
// --mhlo-control-flow-to-scf.
func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
  %cst = constant dense<-1> : tensor<i32>
  %cst_0 = constant dense<1> : tensor<i32>
  %cst_1 = constant dense<0> : tensor<i32>
  %cst_2 = constant dense<1000> : tensor<i32>
  %0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  %1 = "mhlo.while"(%0) ( {
  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
    %3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
    %4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
    "mhlo.return"(%4) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
    %2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
    %3 = mhlo.add %2, %cst_0 : tensor<i32>
    %4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
    %5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
    %6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
    "mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
  }) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
  return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
}
|
||||
|
||||
// CHECK-LABEL: func @lt_loop(
|
||||
// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
|
||||
// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
|
||||
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
|
||||
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
|
||||
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
|
||||
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
|
||||
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
|
||||
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
|
||||
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])
|
Loading…
Reference in New Issue