[mhlo] Add legalize to SCF pass
Start of pass to legalize MHLO control flow to SCF for further optimization in common form. The current version just matches a very simple instance (which also happens to occur a few times). Exposes some further canonicalization opportunities that aren't yet addressed. PiperOrigin-RevId: 329017723
This commit is contained in:
parent
7176fb1839
commit
344c500fca
|
@ -30,6 +30,11 @@ def LegalizeControlFlowPass : Pass<"mhlo-legalize-control-flow", "FuncOp"> {
|
||||||
let constructor = "createLegalizeControlFlowPass()";
|
let constructor = "createLegalizeControlFlowPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Declares the "mhlo-control-flow-to-scf" pass, anchored on FuncOp. The pass
// instance is obtained by calling createControlFlowToScfPass().
def LegalizeControlFlowToScfPass : Pass<"mhlo-control-flow-to-scf", "FuncOp"> {
|
||||||
|
let summary = "Legalize from MHLO control flow to SCF control flow.";
|
||||||
|
let constructor = "createControlFlowToScfPass()";
|
||||||
|
}
|
||||||
|
|
||||||
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
|
def LegalizeGatherToTorchIndexSelectPass : Pass<"mhlo-legalize-gather-to-torch-index-select", "FuncOp"> {
|
||||||
let summary = "Legalizes gathers to a torch index select.";
|
let summary = "Legalizes gathers to a torch index select.";
|
||||||
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
let constructor = "createLegalizeGatherToTorchIndexSelectPass()";
|
||||||
|
|
|
@ -35,6 +35,9 @@ namespace mhlo {
|
||||||
/// Lowers HLO control flow ops to the Standard dialect.
|
/// Lowers HLO control flow ops to the Standard dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
|
||||||
|
|
||||||
|
/// Lowers MHLO control flow ops to the SCF dialect.
/// Returns a newly created pass that runs on FuncOp; the returned unique_ptr
/// owns the pass.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass();
|
||||||
|
|
||||||
/// Lowers from HLO dialect to Standard dialect.
|
/// Lowers from HLO dialect to Standard dialect.
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,7 @@ add_mlir_library(MhloToLhloConversion
|
||||||
add_mlir_library(MhloToStandard
|
add_mlir_library(MhloToStandard
|
||||||
legalize_control_flow.cc
|
legalize_control_flow.cc
|
||||||
legalize_to_standard.cc
|
legalize_to_standard.cc
|
||||||
|
mhlo_control_flow_to_scf.cc
|
||||||
|
|
||||||
DEPENDS
|
DEPENDS
|
||||||
MLIRhlo_opsIncGen
|
MLIRhlo_opsIncGen
|
||||||
|
|
|
@ -0,0 +1,199 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "mhlo-control-flow-to-scf"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace mhlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Convert MHLO While to SCF.
/// Forward declaration: the pass class below calls this for every WhileOp it
/// visits, and the definition follows the class. The function returns nothing;
/// a while op that does not match the supported form is left in place.
|
||||||
|
void MatchAndRewrite(WhileOp whileOp);
|
||||||
|
|
||||||
|
/// Pass that converts MHLO control flow to SCF.
|
||||||
|
class ControlFlowToScfPass
|
||||||
|
: public mlir::PassWrapper<ControlFlowToScfPass, FunctionPass> {
|
||||||
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
|
registry.insert<scf::SCFDialect>();
|
||||||
|
}
|
||||||
|
void runOnFunction() override {
|
||||||
|
getFunction().walk([&](WhileOp whileOp) { MatchAndRewrite(whileOp); });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(jpienaar): Look into reformulating as a pattern.
|
||||||
|
void MatchAndRewrite(WhileOp whileOp) {
|
||||||
|
// Handle pattern:
|
||||||
|
// x = start
|
||||||
|
// step = ...
|
||||||
|
// limit = ...
|
||||||
|
// while (x < limit) { ... x += step; }
|
||||||
|
|
||||||
|
// Only handling multi value while loops at the moment.
|
||||||
|
auto tupleOp = whileOp.getOperand().getDefiningOp<TupleOp>();
|
||||||
|
if (!tupleOp) return;
|
||||||
|
auto bodyReturn = whileOp.body()
|
||||||
|
.front()
|
||||||
|
.getTerminator()
|
||||||
|
->getOperand(0)
|
||||||
|
.getDefiningOp<mhlo::TupleOp>();
|
||||||
|
// Note: due to the shape restrictions on While, if the operand to While is a
|
||||||
|
// tuple, then so is the return type of the body. But the verifier isn't
|
||||||
|
// checking that at the moment, so just bail out here if this doesn't hold.
|
||||||
|
if (!bodyReturn) return;
|
||||||
|
|
||||||
|
Value result = whileOp.cond().front().getTerminator()->getOperand(0);
|
||||||
|
// TODO(jpienaar): Expand to handle more than simple case with LT compare and
|
||||||
|
// constant step.
|
||||||
|
auto cmp = result.getDefiningOp<mhlo::CompareOp>();
|
||||||
|
if (!cmp || cmp.comparison_direction() != "LT") return;
|
||||||
|
|
||||||
|
const int kConstant = -1;
|
||||||
|
auto getValueAndIndex = [&](Value val) -> std::pair<Value, int> {
|
||||||
|
if (matchPattern(val, m_Constant())) return {val, kConstant};
|
||||||
|
// If it is defined by a tuple, then the tuple has to have been fed in and
|
||||||
|
// the external value is captured.
|
||||||
|
if (auto gte = val.getDefiningOp<GetTupleElementOp>()) {
|
||||||
|
if (!gte.getOperand().isa<mlir::BlockArgument>()) return {nullptr, 0};
|
||||||
|
int index = gte.index().getSExtValue();
|
||||||
|
return {tupleOp.getOperand(index), index};
|
||||||
|
}
|
||||||
|
return {nullptr, 0};
|
||||||
|
};
|
||||||
|
|
||||||
|
using ValueIndex = std::pair<Value, int>;
|
||||||
|
ValueIndex loopIndVar = getValueAndIndex(cmp.lhs());
|
||||||
|
ValueIndex max = getValueAndIndex(cmp.rhs());
|
||||||
|
if (!loopIndVar.first || !max.first) return;
|
||||||
|
auto add =
|
||||||
|
bodyReturn.getOperand(loopIndVar.second).getDefiningOp<mhlo::AddOp>();
|
||||||
|
if (!add) return;
|
||||||
|
ValueIndex step = getValueAndIndex(add.rhs());
|
||||||
|
if (step.second != kConstant || !step.first) return;
|
||||||
|
|
||||||
|
// Only handle case where tuple isn't propagated as is for now.
|
||||||
|
// TODO(jpienaar): Remove this when a tuple is also created inside the loop
|
||||||
|
// to propagate.
|
||||||
|
for (auto* use : whileOp.body().front().getArgument(0).getUsers())
|
||||||
|
if (!isa<GetTupleElementOp>(use)) return;
|
||||||
|
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "Found for (" << whileOp.getLoc() << "):\n";
|
||||||
|
llvm::dbgs() << " loopIndVar = " << loopIndVar.second << " max = "
|
||||||
|
<< max.second << " step = " << step.second << "\n";
|
||||||
|
llvm::dbgs() << " loopIndVar = " << loopIndVar.first << " max = "
|
||||||
|
<< max.first << " step = " << step.first << "\n";);
|
||||||
|
OpBuilder b(whileOp);
|
||||||
|
// Inputs to new for loop.
|
||||||
|
llvm::SmallVector<Value, 4> input;
|
||||||
|
input.reserve(tupleOp.getNumOperands());
|
||||||
|
for (auto r : tupleOp.getOperands().take_front(loopIndVar.second))
|
||||||
|
input.push_back(r);
|
||||||
|
for (auto r : tupleOp.getOperands().drop_front(loopIndVar.second + 1))
|
||||||
|
input.push_back(r);
|
||||||
|
|
||||||
|
auto tensorIndexType = RankedTensorType::get({}, b.getIndexType());
|
||||||
|
auto getAsIndex = [&](Value val) {
|
||||||
|
auto loc = whileOp.getLoc();
|
||||||
|
return b.create<ExtractElementOp>(
|
||||||
|
loc, b.create<IndexCastOp>(loc, tensorIndexType, val), ValueRange());
|
||||||
|
};
|
||||||
|
|
||||||
|
// SCF for uses index type, so converted these.
|
||||||
|
auto forloopIndVar = getAsIndex(loopIndVar.first);
|
||||||
|
auto forMax = getAsIndex(max.first);
|
||||||
|
auto forStep = getAsIndex(step.first);
|
||||||
|
auto forOp = b.create<mlir::scf::ForOp>(whileOp.getLoc(), forloopIndVar,
|
||||||
|
forMax, forStep, input);
|
||||||
|
// Transfer the body without the block arguments.
|
||||||
|
forOp.getLoopBody().front().getOperations().splice(
|
||||||
|
forOp.getLoopBody().front().getOperations().end(),
|
||||||
|
whileOp.body().front().getOperations());
|
||||||
|
|
||||||
|
b.setInsertionPointToStart(&forOp.getLoopBody().front());
|
||||||
|
auto loopIndVarElType =
|
||||||
|
loopIndVar.first.getType().cast<ShapedType>().getElementType();
|
||||||
|
Value indVar = b.create<SplatOp>(
|
||||||
|
whileOp.getLoc(), RankedTensorType::get({}, loopIndVarElType),
|
||||||
|
b.create<IndexCastOp>(whileOp.getLoc(), loopIndVarElType,
|
||||||
|
forOp.getInductionVar()));
|
||||||
|
// Update all block argument users to the SCF For args.
|
||||||
|
for (auto* use :
|
||||||
|
llvm::make_early_inc_range(whileOp.body().getArgument(0).getUsers())) {
|
||||||
|
// TODO(jpienaar): Expand here too when we allow using the tuple in the
|
||||||
|
// loop.
|
||||||
|
auto gte = cast<GetTupleElementOp>(use);
|
||||||
|
// If the loop induction var, then refer to the loop induction variable as
|
||||||
|
// this operand is not updated.
|
||||||
|
if (gte.index() == loopIndVar.second) {
|
||||||
|
use->getResult(0).replaceAllUsesWith(indVar);
|
||||||
|
use->erase();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
int index = gte.index().getSExtValue();
|
||||||
|
// If after the loop induction variable, then decrement as we don't include
|
||||||
|
// the loop induction variable in the for iter operands.
|
||||||
|
if (index > loopIndVar.second) --index;
|
||||||
|
use->getResult(0).replaceAllUsesWith(forOp.getIterOperands()[index]);
|
||||||
|
use->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new yield op without induction var update.
|
||||||
|
SmallVector<Value, 4> newYieldOps;
|
||||||
|
newYieldOps.reserve(bodyReturn.getNumOperands() - 1);
|
||||||
|
for (auto r : bodyReturn.getOperands().take_front(loopIndVar.second))
|
||||||
|
newYieldOps.push_back(r);
|
||||||
|
for (auto r : bodyReturn.getOperands().drop_front(loopIndVar.second + 1))
|
||||||
|
newYieldOps.push_back(r);
|
||||||
|
// Delete return & tuple op.
|
||||||
|
forOp.getLoopBody().front().back().erase();
|
||||||
|
forOp.getLoopBody().front().back().erase();
|
||||||
|
b.setInsertionPointToEnd(&forOp.getLoopBody().front());
|
||||||
|
b.create<scf::YieldOp>(whileOp.getLoc(), newYieldOps);
|
||||||
|
|
||||||
|
// Recombine output tuple with max value of induction variable.
|
||||||
|
llvm::SmallVector<Value, 4> loopOut;
|
||||||
|
loopOut.reserve(forOp.getNumResults() + 1);
|
||||||
|
for (auto r : forOp.getResults().take_front(loopIndVar.second))
|
||||||
|
loopOut.push_back(r);
|
||||||
|
loopOut.push_back(max.first);
|
||||||
|
for (auto r : forOp.getResults().drop_front(loopIndVar.second))
|
||||||
|
loopOut.push_back(r);
|
||||||
|
b.setInsertionPoint(whileOp);
|
||||||
|
auto newRes = b.create<mhlo::TupleOp>(whileOp.getLoc(), loopOut);
|
||||||
|
whileOp.replaceAllUsesWith(newRes.getOperation());
|
||||||
|
whileOp.erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
/// Factory for the MHLO control flow to SCF pass; the returned pointer owns
/// the newly created pass instance.
std::unique_ptr<OperationPass<FuncOp>> createControlFlowToScfPass() {
  std::unique_ptr<OperationPass<FuncOp>> pass =
      std::make_unique<ControlFlowToScfPass>();
  return pass;
}
|
||||||
|
|
||||||
|
} // namespace mhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,38 @@
|
||||||
|
// RUN: mlir-hlo-opt --mhlo-control-flow-to-scf %s | FileCheck %s
|
||||||
|
|
||||||
|
// Counted while loop over a 3-element tuple: element 0 starts at 0 and is
// incremented by the constant 1 in the body, element 2 (1000) is the bound of
// the "LT" compare in the condition region, and element 1 is passed through
// unchanged.
func @lt_loop(%arg0: tensor<4xf32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<4xf32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<i32>) -> (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) {
|
||||||
|
%cst = constant dense<-1> : tensor<i32>
|
||||||
|
%cst_0 = constant dense<1> : tensor<i32>
|
||||||
|
%cst_1 = constant dense<0> : tensor<i32>
|
||||||
|
%cst_2 = constant dense<1000> : tensor<i32>
|
||||||
|
%0 = "mhlo.tuple"(%cst_1, %cst, %cst_2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||||
|
%1 = "mhlo.while"(%0) ( {
|
||||||
|
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
|
||||||
|
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||||
|
%3 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||||
|
%4 = "mhlo.compare"(%2, %3) {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
|
"mhlo.return"(%4) : (tensor<i1>) -> ()
|
||||||
|
}, {
|
||||||
|
^bb0(%arg9: tuple<tensor<i32>, tensor<i32>, tensor<i32>>): // no predecessors
|
||||||
|
%2 = "mhlo.get_tuple_element"(%arg9) {index = 0 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||||
|
%3 = mhlo.add %2, %cst_0 : tensor<i32>
|
||||||
|
%4 = "mhlo.get_tuple_element"(%arg9) {index = 1 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||||
|
%5 = "mhlo.get_tuple_element"(%arg9) {index = 2 : i32} : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tensor<i32>
|
||||||
|
%6 = "mhlo.tuple"(%3, %4, %5) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||||
|
"mhlo.return"(%6) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> ()
|
||||||
|
}) : (tuple<tensor<i32>, tensor<i32>, tensor<i32>>) -> tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||||
|
return %1 : tuple<tensor<i32>, tensor<i32>, tensor<i32>>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @lt_loop(
|
||||||
|
// CHECK: %[[VAL_9:.*]] = constant dense<-1> : tensor<i32>
|
||||||
|
// CHECK: %[[VAL_10:.*]] = constant dense<1> : tensor<i32>
|
||||||
|
// CHECK: %[[VAL_11:.*]] = constant dense<0> : tensor<i32>
|
||||||
|
// CHECK: %[[VAL_12:.*]] = constant dense<1000> : tensor<i32>
|
||||||
|
// CHECK: %[[VAL_14:.*]] = index_cast %[[VAL_11]] : tensor<i32> to tensor<index>
|
||||||
|
// CHECK: %[[VAL_15:.*]] = extract_element %[[VAL_14]][] : tensor<index>
|
||||||
|
// CHECK: %[[VAL_16:.*]] = index_cast %[[VAL_12]] : tensor<i32> to tensor<index>
|
||||||
|
// CHECK: %[[VAL_17:.*]] = extract_element %[[VAL_16]][] : tensor<index>
|
||||||
|
// CHECK: %[[VAL_18:.*]] = index_cast %[[VAL_10]] : tensor<i32> to tensor<index>
|
||||||
|
// CHECK: %[[VAL_19:.*]] = extract_element %[[VAL_18]][] : tensor<index>
|
||||||
|
// CHECK: scf.for %[[VAL_21:.*]] = %[[VAL_15]] to %[[VAL_17]] step %[[VAL_19]] iter_args(%[[VAL_22:.*]] = %[[VAL_9]], %[[VAL_23:.*]] = %[[VAL_12]])
|
Loading…
Reference in New Issue