Lower LSTMOp to Krnl dialect (#73)

* Support dilations and enable e2e tests

* Fix allocating memory for dynamic shape

* Edit comments

* Do dilation by computing an offset from kernel index

* Correct dilation formula, add an example of out-of-bound, and add a test for dilation

* Import optional outputs as NoneType

* Shape inference for ONNXLSTM

* Edit ONNXLSTM::inferShape()

* Shape inference for ONNXLSTMOp

* Create a common function for inferring shape for RNN ops

* CheckInsertDeallocation for a specific result

* Allocate memory for LSTM

* First round of lowering

* Allocate memory for hidden and cell states

* Test with custom Tanh

* Fix an error in Ct's formula

* Add E2E tests

* Return outputs

* Refactor the code

* Enable E2E tests

* Support reverse and bidirectional directions

* Minor revision

* Return all intermediate hidden states

* Call existing activation functions

* Structs for activation functions

* Call existing activations in ONNX

* Minor revision

* Compare strings ignoring case

* Use memreftype of rank 0 for calling activation functions

* Fix getActivationPack()

* Revise the code

* Add one MLIR test

* Add MLIR tests for reverse and bidirectional modes

* Make the order of emiting instructions deterministic

* Use OperandAdaptor instead of directly use an operand index

* Use literal assignments

* Change some variable names

* Use literal assignments

* Use literal assignments

* Format the code

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-05-13 22:08:06 +09:00 committed by GitHub
parent 9a874007ce
commit 24343177b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1429 additions and 12 deletions

View File

@ -9,6 +9,9 @@ add_library(OMONNXToKrnl
NN/Conv.cpp
NN/Normalization.cpp
NN/Pooling.cpp
RNN/RNNBase.cpp
RNN/RNNBase.hpp
RNN/LSTM.cpp
Tensor/Identity.cpp
Tensor/Reshape.cpp
Tensor/PadConstantValuePad.cpp

View File

@ -103,6 +103,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
populateLoweringONNXPoolingOpPattern(patterns, &getContext());
// Recurrent neural network
populateLoweringONNXLSTMOpPattern(patterns, &getContext());
// Entry point
patterns.insert<ONNXEntryPointLowering>(&getContext());

View File

@ -102,16 +102,14 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp) {
bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
auto parentBlock = currentOp->getBlock();
bool insertDealloc = true;
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
assert(currentOp->getNumResults() < 2 &&
"No more than one result supported (for now).");
parentBlock->walk([&insertDealloc, currentOp, resultIndex](ReturnOp op) {
// If there is at least one result to investigate.
if (currentOp->getNumResults() > 0) {
auto result = currentOp->getResult(0);
auto result = currentOp->getResult(resultIndex);
for (const auto &operand : op.getOperands())
if (operand == result)
insertDealloc = false;

View File

@ -46,7 +46,7 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
// Determine if current function returns the result value of the
// current op being lowered. If it does then dealloc should not be
// inserted.
bool checkInsertDealloc(Operation *currentOp);
bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0);
// Create a mapping from result type's dimensions to input type's dimensions,
// given that the result type is the result of a reduction op over the input
@ -218,6 +218,10 @@ void populateLoweringONNXNormalizationOpPattern(
void populateLoweringONNXPoolingOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `RNN` directory methods:
void populateLoweringONNXLSTMOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx);
// `Tensor` directory methods:
void populateLoweringONNXUnsqueezeOpPattern(

View File

@ -0,0 +1,537 @@
//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX LSTM Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
using namespace mlir;
struct LstmState {
Value allH;
Value ht;
Value ct;
};
struct LstmActivationPack {
RNNActivation f;
RNNActivation g;
RNNActivation h;
};
template <>
bool hasAllNoneOutput<ONNXLSTMOp>(ONNXLSTMOp *op) {
return (
isNoneType(op->Y()) && isNoneType(op->Y_h()) && isNoneType(op->Y_c()));
}
template <>
std::tuple<LstmActivationPack, LstmActivationPack>
getActivationPack<ONNXLSTMOp, LstmActivationPack>(ONNXLSTMOp *op) {
auto direction = op->direction();
auto activations = op->activations();
auto activationAlpha = op->activation_alpha();
auto activationBeta = op->activation_beta();
LstmActivationPack activationForward, activationReverse;
// Get activation function name.
// Default forward functions
activationForward.f.name = "sigmoid";
activationForward.g.name = "tanh";
activationForward.h.name = "tanh";
// Default backward functions
activationReverse.f.name = "sigmoid";
activationReverse.g.name = "tanh";
activationReverse.h.name = "tanh";
if (activations) {
ArrayAttr activationArrAttr = activations.getValue();
if (direction == FORWARD || direction == BIDIRECTIONAL) {
// Forward activations.
if (activationArrAttr.size() > 0) {
activationForward.f.name =
activationArrAttr[0].cast<StringAttr>().getValue();
}
if (activationArrAttr.size() > 1) {
activationForward.g.name =
activationArrAttr[1].cast<StringAttr>().getValue();
}
if (activationArrAttr.size() > 2) {
activationForward.h.name =
activationArrAttr[2].cast<StringAttr>().getValue();
}
}
// Reverse activations.
if (direction == REVERSE || direction == BIDIRECTIONAL) {
int startIndex = (direction == REVERSE) ? 0 : 3;
if (activationArrAttr.size() > startIndex) {
activationReverse.f.name =
activationArrAttr[startIndex].cast<StringAttr>().getValue();
}
if (activationArrAttr.size() > startIndex + 1) {
activationReverse.g.name =
activationArrAttr[startIndex + 1].cast<StringAttr>().getValue();
}
if (activationArrAttr.size() > startIndex + 2) {
activationReverse.h.name =
activationArrAttr[startIndex + 2].cast<StringAttr>().getValue();
}
}
}
// Get alpha attributes.
if (activationAlpha) {
ArrayAttr activationArrAttr = activationAlpha.getValue();
if (direction == FORWARD || direction == BIDIRECTIONAL) {
// Forward activations.
if (activationArrAttr.size() > 0) {
activationForward.f.alpha = activationArrAttr[0].cast<FloatAttr>();
}
if (activationArrAttr.size() > 1) {
activationForward.g.alpha = activationArrAttr[1].cast<FloatAttr>();
}
if (activationArrAttr.size() > 2) {
activationForward.h.alpha = activationArrAttr[2].cast<FloatAttr>();
}
}
// Reverse activations.
if (direction == REVERSE || direction == BIDIRECTIONAL) {
int startIndex = (direction == REVERSE) ? 0 : 3;
if (activationArrAttr.size() > startIndex) {
activationReverse.f.alpha =
activationArrAttr[startIndex].cast<FloatAttr>();
}
if (activationArrAttr.size() > startIndex + 1) {
activationReverse.g.alpha =
activationArrAttr[startIndex + 1].cast<FloatAttr>();
}
if (activationArrAttr.size() > startIndex + 2) {
activationReverse.h.alpha =
activationArrAttr[startIndex + 2].cast<FloatAttr>();
}
}
}
// Get beta attributes.
if (activationBeta) {
ArrayAttr activationArrAttr = activationBeta.getValue();
if (direction == FORWARD || direction == BIDIRECTIONAL) {
// Forward activations.
if (activationArrAttr.size() > 0) {
activationForward.f.beta = activationArrAttr[0].cast<FloatAttr>();
}
if (activationArrAttr.size() > 1) {
activationForward.g.beta = activationArrAttr[1].cast<FloatAttr>();
}
if (activationArrAttr.size() > 2) {
activationForward.h.beta = activationArrAttr[2].cast<FloatAttr>();
}
}
// Reverse activations.
if (direction == REVERSE || direction == BIDIRECTIONAL) {
int startIndex = (direction == REVERSE) ? 0 : 3;
if (activationArrAttr.size() > startIndex) {
activationReverse.f.beta =
activationArrAttr[startIndex].cast<FloatAttr>();
}
if (activationArrAttr.size() > startIndex + 1) {
activationReverse.g.beta =
activationArrAttr[startIndex + 1].cast<FloatAttr>();
}
if (activationArrAttr.size() > startIndex + 2) {
activationReverse.h.beta =
activationArrAttr[startIndex + 2].cast<FloatAttr>();
}
}
}
return std::make_tuple(activationForward, activationReverse);
}
template <>
LstmState allocAndInitializeStates<ONNXLSTMOp, LstmState>(
ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op,
OperandAdaptor<ONNXLSTMOp> operandAdaptor) {
LstmState state;
// Insert allocation and deallocation for the results of this operation.
if (!isNoneType(op->Y())) {
auto yMemRefType = convertToMemRefType(op->Y().getType());
if (hasAllConstantDimensions(yMemRefType))
state.allH = insertAllocAndDealloc(yMemRefType, loc, rewriter,
checkInsertDealloc(op->getOperation(), 0));
else
emitError(loc, "Unsupported dynamic dimensions.");
} else {
state.allH = op->Y();
}
// Y_h :: [num_directions, batch_size, hidden_size]
if (!isNoneType(op->Y_h())) {
auto yhMemRefType = convertToMemRefType(op->Y_h().getType());
if (hasAllConstantDimensions(yhMemRefType))
state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter,
checkInsertDealloc(op->getOperation(), 1));
else
emitError(loc, "Unsupported dynamic dimensions.");
} else {
auto yhMemRefType = MemRefType::get(
{dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
dimAt(operandAdaptor.R(), 2)},
operandAdaptor.X().getType().cast<ShapedType>().getElementType());
state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter, true);
}
// Y_c :: [num_directions, batch_size, hidden_size]
if (!isNoneType(op->Y_c())) {
auto ycMemRefType = convertToMemRefType(op->Y_c().getType());
if (hasAllConstantDimensions(ycMemRefType))
state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter,
checkInsertDealloc(op->getOperation(), 2));
else
emitError(loc, "Unsupported dynamic dimensions.");
} else {
auto ycMemRefType = MemRefType::get(
{dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
dimAt(operandAdaptor.R(), 2)},
operandAdaptor.X().getType().cast<ShapedType>().getElementType());
state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter, true);
}
// Initialize ht and ct.
Value zero = emitConstantOp(rewriter, loc,
operandAdaptor.X().getType().cast<ShapedType>().getElementType(), 0);
int nLoops = 3;
BuildKrnlLoop initializationLoops(rewriter, loc, nLoops);
initializationLoops.createDefineOptimizeAndIterateOp(state.ht);
auto ipInitializationLoops = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(initializationLoops.getIterateBlock());
{
SmallVector<Value, 4> IVs;
for (int i = 0; i < nLoops; ++i)
IVs.emplace_back(initializationLoops.getInductionVar(i));
Value hiddenVal = zero;
if (!isNoneType(operandAdaptor.initial_h()))
hiddenVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_h(), IVs);
rewriter.create<StoreOp>(loc, hiddenVal, state.ht, IVs);
Value cellVal = zero;
if (!isNoneType(operandAdaptor.initial_c()))
cellVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_c(), IVs);
rewriter.create<StoreOp>(loc, cellVal, state.ct, IVs);
}
rewriter.restoreInsertionPoint(ipInitializationLoops);
return state;
}
template <>
void calculateState<ONNXLSTMOp, LstmState, LstmActivationPack>(
ConversionPatternRewriter &rewriter, Location loc,
OperandAdaptor<ONNXLSTMOp> operandAdaptor, LstmState state,
LstmActivationPack activationPack, Value directionIV, Value sequenceIV) {
bool hasBiasForInput = false, hasPeepholes = false;
if (!isNoneType(operandAdaptor.B()))
hasBiasForInput = true;
if (!isNoneType(operandAdaptor.P()))
hasPeepholes = true;
// Prepare dimensions.
auto batchDimSize = dimAt(operandAdaptor.X(), 1);
auto inputDimSize = dimAt(operandAdaptor.X(), 2);
auto hiddenDimSize = dimAt(operandAdaptor.R(), 2);
Value hiddenDimVal =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), hiddenDimSize);
auto elementType =
operandAdaptor.X().getType().cast<ShapedType>().getElementType();
// Prepare AffineMap to access bias, peepholes tensors.
AffineMap accessByOffsetMap;
{
AffineExpr iv = rewriter.getAffineDimExpr(0);
AffineExpr index = rewriter.getAffineSymbolExpr(0);
AffineExpr size = rewriter.getAffineSymbolExpr(1);
AffineExpr accessByOffsetExpr = index * size + iv;
accessByOffsetMap = AffineMap::get(1, 2, accessByOffsetExpr);
}
// Prepare constant indices.
SmallVector<Value, 4> constantIndices;
for (int i = 0; i < 8; i++)
constantIndices.emplace_back(
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i));
// Equations for LSTM.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// Ct = ft (.) Ct-1 + it (.) ct
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)
//
// The following code will emit loops as follows:
// for b in 0 .. BatchDimSize
// for h in 0 .. HiddenDimSize
// for i in 0 .. InputDimSize {
// compute Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T),
// Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T)
// }
// compute it, ft, ct, Ct, ot, Ht
BuildKrnlLoop stateLoops(rewriter, loc, 2);
stateLoops.createDefineAndOptimizeOp();
stateLoops.pushBounds(0, batchDimSize);
stateLoops.pushBounds(0, hiddenDimSize);
stateLoops.createIterateOp();
rewriter.setInsertionPointToStart(stateLoops.getIterateBlock());
{
auto batchIV = stateLoops.getInductionVar(0);
auto hiddenIV = stateLoops.getInductionVar(1);
// IVs to access tensors.
// IVs for the hidden and cell state tensors.
SmallVector<Value, 4> hIVs, cIVs;
// IVs for the bias tensors for W and R.
SmallVector<SmallVector<Value, 4>, 4> wbIOFCIVs, rbIOFCIVs;
// IVs for the peepholes.
SmallVector<SmallVector<Value, 4>, 4> pIOFIVs;
{ // Compute IVs.
// H :: [num_directions, batch_size, hidden_size]
hIVs = {directionIV, batchIV, hiddenIV};
// C :: [num_directions, batch_size, hidden_size]
cIVs = {directionIV, batchIV, hiddenIV};
// Bias [Wb[iofc], Rb[iofc]] :: [num_directions, 8*hidden_size]
if (hasBiasForInput) {
// Wb[iofc]
for (unsigned i = 0; i < 4; ++i) {
Value wHiddenIV =
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
/*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
wbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, wHiddenIV});
}
// Rb[iofc]
for (unsigned i = 4; i < 8; ++i) {
SmallVector<Value, 4> rbIVs;
Value rHiddenIV =
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
/*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
rbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, rHiddenIV});
}
}
// Peepholes P[iof] :: [num_directions, 3*hidden_size]
if (hasPeepholes) {
for (unsigned i = 0; i < 3; ++i) {
SmallVector<Value, 4> pIVs;
Value pHiddenIV =
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
ValueRange(std::vector<Value>{
hiddenIV, constantIndices[i], hiddenDimVal}));
pIOFIVs.emplace_back(SmallVector<Value, 2>{directionIV, pHiddenIV});
}
}
}
Value loadH = rewriter.create<LoadOp>(loc, state.ht, hIVs);
Value loadC = rewriter.create<LoadOp>(loc, state.ct, cIVs);
// Emit instructions for matrix multiplications:
// Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T)
// Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T)
// Allocate memory for storing matrix multiplication results.
SmallVector<Value, 4> xwIOFC, hrIOFC;
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
for (unsigned i = 0; i < 4; ++i) {
Value xwAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
rewriter.create<StoreOp>(loc, zero, xwAlloc);
Value hrAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
rewriter.create<StoreOp>(loc, zero, hrAlloc);
xwIOFC.emplace_back(xwAlloc);
hrIOFC.emplace_back(hrAlloc);
}
{ // Emit instructions for matrix multiplications.
// input_size is the reduction dimension.
BuildKrnlLoop reductionLoops(rewriter, loc, 1);
reductionLoops.createDefineAndOptimizeOp();
reductionLoops.pushBounds(0, inputDimSize);
reductionLoops.createIterateOp();
auto ipReductionLoops = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(reductionLoops.getIterateBlock());
{
auto reductionIV = reductionLoops.getInductionVar(0);
// Prepare IVs for accessing the input tensor and parameters.
SmallVector<Value, 4> xIVs;
SmallVector<SmallVector<Value, 4>, 4> wIOFCIVs, rIOFCIVs;
// X :: [seq_length, batch_size, input_size]
xIVs = {sequenceIV, batchIV, reductionIV};
// W[iofc] :: [num_directions, 4*hidden_size, input_size]
// R[iofc] :: [num_directions, 4*hidden_size, input_size]
for (unsigned i = 0; i < 4; ++i) {
SmallVector<Value, 4> wIVs, rIVs;
Value wHiddenIV =
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
ValueRange(std::vector<Value>{
hiddenIV, constantIndices[i], hiddenDimVal}));
wIVs = {directionIV, wHiddenIV, reductionIV};
wIOFCIVs.emplace_back(wIVs);
rIVs = {directionIV, wHiddenIV, reductionIV};
rIOFCIVs.emplace_back(rIVs);
}
Value loadX = rewriter.create<LoadOp>(loc, operandAdaptor.X(), xIVs);
for (unsigned i = 0; i < 4; ++i) {
// Xt * Wiofc
Value loadW =
rewriter.create<LoadOp>(loc, operandAdaptor.W(), wIOFCIVs[i]);
Value xwVal = rewriter.create<MulFOp>(loc, loadX, loadW);
Value loadXW = rewriter.create<LoadOp>(loc, xwIOFC[i]);
Value nextXW = rewriter.create<AddFOp>(loc, loadXW, xwVal);
rewriter.create<StoreOp>(loc, nextXW, xwIOFC[i]);
// Ht-1 * Riofc
Value loadR =
rewriter.create<LoadOp>(loc, operandAdaptor.R(), rIOFCIVs[i]);
Value hrVal = rewriter.create<MulFOp>(loc, loadH, loadR);
Value loadHR = rewriter.create<LoadOp>(loc, hrIOFC[i]);
Value nextHR = rewriter.create<AddFOp>(loc, loadHR, hrVal);
rewriter.create<StoreOp>(loc, nextHR, hrIOFC[i]);
}
}
rewriter.restoreInsertionPoint(ipReductionLoops);
}
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
Value loadXWI = rewriter.create<LoadOp>(loc, xwIOFC[0]);
Value loadHRI = rewriter.create<LoadOp>(loc, hrIOFC[0]);
Value it = rewriter.create<AddFOp>(loc, loadXWI, loadHRI);
if (hasPeepholes) {
Value loadP =
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[0]);
Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
it = rewriter.create<AddFOp>(loc, it, PC);
}
if (hasBiasForInput) {
Value loadWB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[0]);
it = rewriter.create<AddFOp>(loc, it, loadWB);
Value loadRB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[0]);
it = rewriter.create<AddFOp>(loc, it, loadRB);
}
it = applyActivation(rewriter, loc, activationPack.f, it);
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
Value loadXWF = rewriter.create<LoadOp>(loc, xwIOFC[2]);
Value loadHRF = rewriter.create<LoadOp>(loc, hrIOFC[2]);
Value ft = rewriter.create<AddFOp>(loc, loadXWF, loadHRF);
if (hasPeepholes) {
Value loadP =
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[2]);
Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
ft = rewriter.create<AddFOp>(loc, ft, PC);
}
if (hasBiasForInput) {
Value loadWB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[2]);
ft = rewriter.create<AddFOp>(loc, ft, loadWB);
Value loadRB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[2]);
ft = rewriter.create<AddFOp>(loc, ft, loadRB);
}
ft = applyActivation(rewriter, loc, activationPack.f, ft);
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
Value loadXWC = rewriter.create<LoadOp>(loc, xwIOFC[3]);
Value loadHRC = rewriter.create<LoadOp>(loc, hrIOFC[3]);
Value ct = rewriter.create<AddFOp>(loc, loadXWC, loadHRC);
if (hasBiasForInput) {
Value loadWB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[3]);
ct = rewriter.create<AddFOp>(loc, ct, loadWB);
Value loadRB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[3]);
ct = rewriter.create<AddFOp>(loc, ct, loadRB);
}
ct = applyActivation(rewriter, loc, activationPack.g, ct);
// Ct = ft (.) Ct-1 + it (.) ct
Value FtCt1 = rewriter.create<MulFOp>(loc, ft, loadC);
Value itct = rewriter.create<MulFOp>(loc, it, ct);
Value Ct = rewriter.create<AddFOp>(loc, FtCt1, itct);
rewriter.create<StoreOp>(loc, Ct, state.ct, cIVs);
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
Value loadXWO = rewriter.create<LoadOp>(loc, xwIOFC[1]);
Value loadHRO = rewriter.create<LoadOp>(loc, hrIOFC[1]);
Value ot = rewriter.create<AddFOp>(loc, loadXWO, loadHRO);
if (hasPeepholes) {
Value loadP =
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[1]);
Value PC = rewriter.create<MulFOp>(loc, loadP, Ct);
ot = rewriter.create<AddFOp>(loc, ot, PC);
}
if (hasBiasForInput) {
Value loadWB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[1]);
ot = rewriter.create<AddFOp>(loc, ot, loadWB);
Value loadRB =
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[1]);
ot = rewriter.create<AddFOp>(loc, ot, loadRB);
}
ot = applyActivation(rewriter, loc, activationPack.f, ot);
// Ht = ot (.) h(Ct)
Value hCt = applyActivation(rewriter, loc, activationPack.h, Ct);
Value Ht = rewriter.create<MulFOp>(loc, ot, hCt);
rewriter.create<StoreOp>(loc, Ht, state.ht, hIVs);
// Store the current Ht if required.
if (!isNoneType(state.allH)) {
SmallVector<Value, 4> allHIVs{sequenceIV, directionIV, batchIV, hiddenIV};
rewriter.create<StoreOp>(loc, Ht, state.allH, allHIVs);
}
// Deallocate the temporary results of matrix multiplications.
for (Value v : xwIOFC)
rewriter.create<DeallocOp>(loc, v);
for (Value v : hrIOFC)
rewriter.create<DeallocOp>(loc, v);
}
}
template <>
void stateToOutput<ONNXLSTMOp, LstmState>(
ONNXLSTMOp *op, LstmState state, std::vector<Value> &outputs) {
Value noneValue;
outputs.emplace_back((isNoneType(op->Y()) ? noneValue : state.allH));
outputs.emplace_back((isNoneType(op->Y_h()) ? noneValue : state.ht));
outputs.emplace_back((isNoneType(op->Y_c()) ? noneValue : state.ct));
}
void populateLoweringONNXLSTMOpPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
patterns.insert<ONNXRNNOpLowering<ONNXLSTMOp, LstmState, LstmActivationPack>>(
ctx);
}

View File

@ -0,0 +1,73 @@
//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines base functions for lowerng the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
using namespace mlir;
// Check a Value's type is none or not.
bool isNoneType(Value val) { return val.getType().isa<NoneType>(); }
// Get a dimension of the tensor's shape.
int64_t dimAt(Value val, int index) {
return val.getType().cast<ShapedType>().getShape()[index];
}
// Apply an activation function on a given scalar operand.
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
RNNActivation activation, Value scalarOperand) {
Value res;
MemRefType scalarMemRefType =
MemRefType::get({}, scalarOperand.getType(), {}, 0);
Value alloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
rewriter.create<StoreOp>(loc, scalarOperand, alloc);
std::vector<mlir::NamedAttribute> attributes;
if (activation.alpha) {
attributes.emplace_back(
rewriter.getNamedAttr("alpha", activation.alpha.getValue()));
}
if (activation.beta) {
attributes.emplace_back(
rewriter.getNamedAttr("beta", activation.beta.getValue()));
}
if (activation.name.equals_lower("relu"))
res = rewriter.create<ONNXReluOp>(loc, scalarMemRefType, alloc);
else if (activation.name.equals_lower("tanh"))
res = rewriter.create<ONNXTanhOp>(loc, scalarMemRefType, alloc);
else if (activation.name.equals_lower("sigmoid"))
res = rewriter.create<ONNXSigmoidOp>(loc, scalarMemRefType, alloc);
else if (activation.name.equals_lower("affine"))
emitError(loc, "Unsupported activation");
else if (activation.name.equals_lower("leakyrelu"))
res = rewriter.create<ONNXLeakyReluOp>(
loc, scalarMemRefType, alloc, attributes);
else if (activation.name.equals_lower("thresholdedrelu"))
res = rewriter.create<ONNXThresholdedReluOp>(
loc, scalarMemRefType, alloc, attributes);
else if (activation.name.equals_lower("scaledtanh"))
emitError(loc, "Unsupported activation");
else if (activation.name.equals_lower("hardsigmoid"))
res = rewriter.create<ONNXHardSigmoidOp>(
loc, scalarMemRefType, alloc, attributes);
else if (activation.name.equals_lower("elu"))
res = rewriter.create<ONNXEluOp>(loc, scalarMemRefType, alloc, attributes);
else if (activation.name.equals_lower("softsign"))
res = rewriter.create<ONNXSoftsignOp>(loc, scalarMemRefType, alloc);
else if (activation.name.equals_lower("softplus"))
res = rewriter.create<ONNXSoftplusOp>(loc, scalarMemRefType, alloc);
else
llvm_unreachable("Unsupported activation");
Value result = rewriter.create<LoadOp>(loc, res);
return result;
}

View File

@ -0,0 +1,144 @@
//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines base functions for lowerng the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/AffineExpr.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
static const StringRef FORWARD = "forward";
static const StringRef REVERSE = "reverse";
static const StringRef BIDIRECTIONAL = "bidirectional";
struct RNNActivation {
StringRef name;
Optional<FloatAttr> alpha;
Optional<FloatAttr> beta;
};
// Check a Value's type is none or not.
bool isNoneType(Value val);
// Get a dimension of the tensor's shape.
int64_t dimAt(Value val, int index);
// Apply an activation function on a given scalar operand.
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
RNNActivation activation, Value scalarOperand);
// Override the following methods when lowering an RNN operation:
// - hasAllNoneOutput
// - getActivationPack
// - allocAndInitializeStates
// - calculateState
// - stateToOutput
// Check whether all outputs have NoneType or not.
template <typename RNNOp>
bool hasAllNoneOutput(RNNOp *op);
// Obtain activations functions for a specific operation.
template <typename RNNOp, typename A>
std::tuple<A, A> getActivationPack(RNNOp *op);
// Allocate memory for RNN states and initialize them.
template <typename RNNOp, typename S>
S allocAndInitializeStates(ConversionPatternRewriter &rewriter, Location loc,
RNNOp *op, OperandAdaptor<RNNOp> operandAdaptor);
// Calculate new states from the current input and states.
template <typename RNNOp, typename S, typename A>
void calculateState(ConversionPatternRewriter &rewriter, Location loc,
OperandAdaptor<RNNOp> operandAdaptor, S state, A activationSet,
Value directionIV, Value sequenceIV);
// Write states to the RNN's outputs.
template <typename RNNOp, typename S>
void stateToOutput(RNNOp *op, S state, std::vector<Value> &outputs);
// A common template for lowering an RNN operation.
template <typename RNNOp, typename S, typename A>
struct ONNXRNNOpLowering : public ConversionPattern {
ONNXRNNOpLowering(MLIRContext *ctx)
: ConversionPattern(RNNOp::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
RNNOp rnnOp = llvm::dyn_cast<RNNOp>(op);
OperandAdaptor<RNNOp> operandAdaptor(operands);
if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
rewriter.eraseOp(op);
return success();
}
S state = allocAndInitializeStates<RNNOp, S>(
rewriter, loc, &rnnOp, operandAdaptor);
A activationForward, activationReverse;
std::tie(activationForward, activationReverse) =
getActivationPack<RNNOp, A>(&rnnOp);
int64_t sequenceDimSize = dimAt(rnnOp.X(), 0);
auto direction = rnnOp.direction();
if (direction == FORWARD || direction == BIDIRECTIONAL) {
BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
sequenceLoops.createDefineAndOptimizeOp();
sequenceLoops.pushBounds(0, sequenceDimSize);
sequenceLoops.createIterateOp();
auto ipSequenceLoops = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
{
Value directionIV =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
Value sequenceIV = sequenceLoops.getInductionVar(0);
// Emit calculation for one RNN step.
calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
activationForward, directionIV, sequenceIV);
}
rewriter.restoreInsertionPoint(ipSequenceLoops);
}
if (direction == REVERSE || direction == BIDIRECTIONAL) {
BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
sequenceLoops.createDefineAndOptimizeOp();
sequenceLoops.pushBounds(0, sequenceDimSize);
sequenceLoops.createIterateOp();
auto ipSequenceLoops = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
{
AffineMap reverseIVMap = AffineMap::get(1, 1,
rewriter.getAffineSymbolExpr(0) - rewriter.getAffineDimExpr(0) - 1);
Value directionIV = emitConstantOp(rewriter, loc,
rewriter.getIndexType(), (direction == REVERSE) ? 0 : 1);
Value reverseSequenceIV =
rewriter.create<AffineApplyOp>(loc, reverseIVMap,
ValueRange(std::vector<Value>{sequenceLoops.getInductionVar(0),
emitConstantOp(rewriter, loc, rewriter.getIndexType(),
sequenceDimSize)}));
// Emit calculation for one RNN step.
calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
activationReverse, directionIV, reverseSequenceIV);
}
rewriter.restoreInsertionPoint(ipSequenceLoops);
}
std::vector<Value> outputs;
stateToOutput<RNNOp, S>(&rnnOp, state, outputs);
rewriter.replaceOp(op, outputs);
return success();
}
};

View File

@ -299,6 +299,112 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
}
}
//===----------------------------------------------------------------------===//
// Support function that infers shape for RNN operations.
template <typename T>
static bool RNNShapeInference(T *op) {
Value X = op->X();
Value W = op->W();
Value R = op->R();
if (!X.getType().isa<RankedTensorType>() ||
!W.getType().isa<RankedTensorType>() ||
!R.getType().isa<RankedTensorType>())
return false;
auto xTy = X.getType().cast<RankedTensorType>();
auto elementType = xTy.getElementType();
// xShape :: [seq_length, batch_size, input_size]
auto xShape = xTy.getShape();
// wShape :: [num_directions, 4*hidden_size, input_size]
auto wShape = W.getType().cast<RankedTensorType>().getShape();
// rShape :: [num_directions, 4*hidden_size, hidden_size]
auto rShape = R.getType().cast<RankedTensorType>().getShape();
if (xShape.size() != 3) {
op->emitError("The first input tensor must have rank 3");
return false;
}
if (wShape.size() != 3) {
op->emitError("The second input tensor must have rank 3");
return false;
}
if (rShape.size() != 3) {
op->emitError("The third input tensor must have rank 3");
return false;
}
// Get sequence length, batch size and input size.
auto sequenceLength = xShape[0];
auto batchSize = xShape[1];
auto inputSize = xShape[2];
// Get hidden size from hidden_size attribute.
int64_t hiddenSize = -1;
if (op->hidden_size().hasValue()) {
hiddenSize = op->hidden_size().getValue().getSExtValue();
} else {
// Infer hidden_size from wShape and rShape if possible.
if (rShape[2] != -1)
hiddenSize = rShape[2];
else if (rShape[1] != -1)
hiddenSize = rShape[1] / 4;
else if (wShape[1] != -1)
hiddenSize = wShape[1] / 4;
// Update hidden_size attribute.
if (hiddenSize != -1) {
auto builder = mlir::Builder(op->getContext());
op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize));
}
}
// Get direction.
int numDirection;
if ((op->direction() == "forward") || (op->direction() == "reverse"))
numDirection = 1;
else if (op->direction() == "bidirectional")
numDirection = 2;
else
numDirection = -1;
if (numDirection == -1) {
op->emitError("direction attribute muse be one of the strings: forward, "
"reverse, and bidirectional");
return false;
}
// Set result types.
unsigned numOfResults = op->getNumResults();
if (numOfResults > 0) {
// Y :: [seq_length, num_directions, batch_size, hidden_size]
Type yTy = op->getResults()[0].getType();
if (!yTy.isa<NoneType>()) {
yTy = RankedTensorType::get(
{sequenceLength, numDirection, batchSize, hiddenSize}, elementType);
op->getResults()[0].setType(yTy);
}
}
if (numOfResults > 1) {
// Y_h :: [num_directions, batch_size, hidden_size]
Type yhTy = op->getResults()[1].getType();
if (!yhTy.isa<NoneType>()) {
yhTy = RankedTensorType::get(
{numDirection, batchSize, hiddenSize}, elementType);
op->getResults()[1].setType(yhTy);
}
}
if (numOfResults > 2) {
// Y_c :: [num_directions, batch_size, hidden_size]
Type ycTy = op->getResults()[2].getType();
if (!ycTy.isa<NoneType>()) {
ycTy = RankedTensorType::get(
{numDirection, batchSize, hiddenSize}, elementType);
op->getResults()[2].setType(ycTy);
}
}
return true;
}
//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//
@ -1472,7 +1578,6 @@ bool ONNXConstantOp::inferShapes() {
return true;
}
//===----------------------------------------------------------------------===//
// Concat
bool ONNXConcatOp::inferShapes() {
@ -1537,6 +1642,21 @@ bool ONNXConcatOp::inferShapes() {
return true;
}
//===----------------------------------------------------------------------===//
// RNN
bool ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===//
// LSTM
bool ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===//
// GRU
bool ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===//
// Split

View File

@ -738,7 +738,7 @@ def ONNXFloorOp:ONNX_Op<"Floor",
}
def ONNXGRUOp:ONNX_Op<"GRU",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX GRU operation";
let description = [{
"Computes an one-layer GRU. This operator is usually supported via some custom"
@ -1246,7 +1246,7 @@ def ONNXLRNOp:ONNX_Op<"LRN",
}
def ONNXLSTMOp:ONNX_Op<"LSTM",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX LSTM operation";
let description = [{
"Computes an one-layer LSTM. This operator is usually supported via some"
@ -2086,7 +2086,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
}
def ONNXRNNOp:ONNX_Op<"RNN",
[NoSideEffect]> {
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX RNN operation";
let description = [{
"Computes an one-layer simple RNN. This operator is usually supported"

View File

@ -126,6 +126,9 @@ public:
op->getName().getStringRef() != "onnx.Concat" &&
op->getName().getStringRef() != "onnx.Split" &&
op->getName().getStringRef() != "onnx.Neg" &&
op->getName().getStringRef() != "onnx.RNN" &&
op->getName().getStringRef() != "onnx.LSTM" &&
op->getName().getStringRef() != "onnx.GRU" &&
op->getName().getStringRef() != "onnx.Unsqueeze")
return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) {

View File

@ -356,6 +356,11 @@ test_to_enable = [
"test_averagepool_2d_strides_cpu",
"test_averagepool_3d_default_cpu",
# LSTM
"test_lstm_defaults_cpu",
"test_lstm_with_initial_bias_cpu",
"test_lstm_with_peepholes_cpu",
]
# Extract name of all test cases.

View File

@ -1769,3 +1769,314 @@ func @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*x
// CHECK: }
}
// -----
func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>
// CHECK-DAG: [[ACCESS_BY_OFFSET_MAP:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>
// CHECK-LABEL: @test_lstm_general_computation
// CHECK: [[CELL_STATE:%.+]] = alloc() : memref<1x3x3xf32>
// CHECK: [[HIDDEN_STATE:%.+]] = alloc() : memref<1x3x3xf32>
// CHECK: {{.*}} = constant unit
// CHECK: [[INITIAL_VALUE:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[INITIALIZE_LOOPS:%.+]]:3 = krnl.define_loops 3
// CHECK: [[INITIALIZE_OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[INITIALIZE_LOOPS]]#0, [[INITIALIZE_LOOPS]]#1, [[INITIALIZE_LOOPS]]#2
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[INITIALIZE_OPT_LOOPS]]#0, [[INITIALIZE_OPT_LOOPS]]#1, [[INITIALIZE_OPT_LOOPS]]#2) with ([[INITIALIZE_LOOPS]]#0 -> %arg3 = 0 to 1, [[INITIALIZE_LOOPS]]#1 -> %arg4 = 0 to 3, [[INITIALIZE_LOOPS]]#2 -> %arg5 = 0 to 3) {
// CHECK: store [[INITIAL_VALUE]], [[HIDDEN_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: store [[INITIAL_VALUE]], [[CELL_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: }
// CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: {{.*}} = constant 0 : index
// CHECK: {{.*}} = constant 3 : index
// CHECK: {{.*}} = constant 0 : index
// CHECK: {{.*}} = constant 1 : index
// CHECK: {{.*}} = constant 2 : index
// CHECK: {{.*}} = constant 3 : index
// CHECK: {{.*}} = constant 4 : index
// CHECK: {{.*}} = constant 5 : index
// CHECK: {{.*}} = constant 6 : index
// CHECK: {{.*}} = constant 7 : index
// CHECK: [[DATA_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[DATA_OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DATA_LOOPS]]#0, [[DATA_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[DATA_OPT_LOOPS]]#0, [[DATA_OPT_LOOPS]]#1) with ([[DATA_LOOPS]]#0 -> %arg4 = 0 to 3, [[DATA_LOOPS]]#1 -> %arg5 = 0 to 3) {
// CHECK: [[hCt:%.+]] = alloc() : memref<f32>
// CHECK: [[Ot:%.+]] = alloc() : memref<f32>
// CHECK: [[ct:%.+]] = alloc() : memref<f32>
// CHECK: [[Ft:%.+]] = alloc() : memref<f32>
// CHECK: [[It:%.+]] = alloc() : memref<f32>
// CHECK: [[Ht1_LOAD:%.+]] = load [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: [[Ct1_LOAD:%.+]] = load [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: [[ZERO_FLOAT:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[XtWi_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[XtWi_GEMM]][] : memref<f32>
// CHECK: [[Ht1Ri_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[Ht1Ri_GEMM]][] : memref<f32>
// CHECK: [[XtWo_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[XtWo_GEMM]][] : memref<f32>
// CHECK: [[Ht1Ro_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[Ht1Ro_GEMM]][] : memref<f32>
// CHECK: [[XtWf_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[XtWf_GEMM]][] : memref<f32>
// CHECK: [[Ht1Rf_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[Ht1Rf_GEMM]][] : memref<f32>
// CHECK: [[XtWc_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[XtWc_GEMM]][] : memref<f32>
// CHECK: [[Ht1Rc_GEMM:%.+]] = alloc() : memref<f32>
// CHECK: store [[ZERO_FLOAT]], [[Ht1Rc_GEMM]][] : memref<f32>
// CHECK: [[REDUCTION_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[REDUCTION_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[REDUCTION_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[REDUCTION_OPT_LOOPS]]) with ([[REDUCTION_LOOPS]] -> %arg6 = 0 to 2) {
// CHECK: [[INPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c0_1, %c3]
// CHECK: [[OUTPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c1, %c3]
// CHECK: [[FORGET_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c2, %c3]
// CHECK: [[CELL_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c3_2, %c3]
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, %arg4, %arg6] : memref<4x3x2xf32>
// CHECK: [[Wi_LOAD:%.+]] = load %arg1[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
// CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
// CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
// CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
// CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
// CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
// CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
// CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
// CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
// CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
// CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
// CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
// CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
// CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
// CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
// CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
// CHECK: }
// CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
// CHECK: [[Ht1Ri_LOAD:%.+]] = load [[Ht1Ri_GEMM]][] : memref<f32>
// CHECK: [[It_OUTPUT:%.+]] = addf [[XtWi_LOAD]], [[Ht1Ri_LOAD]] : f32
// CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
// CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
// CHECK: krnl.define_loops 0
// CHECK: krnl.optimize_loops {
// CHECK: krnl.return_loops
// CHECK: } : () -> ()
// CHECK: krnl.iterate() with () {
// CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
// CHECK: {{.*}} = constant 0.000000e+00 : f32
// CHECK: {{.*}} = constant 1.000000e+00 : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
// CHECK: store {{.*}}, [[It]][] : memref<f32>
// CHECK: }
// CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>
// CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
// CHECK: [[Ht1Rf_LOAD:%.+]] = load [[Ht1Rf_GEMM]][] : memref<f32>
// CHECK: [[Ft_OUTPUT:%.+]] = addf [[XtWf_LOAD]], [[Ht1Rf_LOAD]] : f32
// CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
// CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
// CHECK: krnl.define_loops 0
// CHECK: krnl.optimize_loops {
// CHECK: krnl.return_loops
// CHECK: } : () -> ()
// CHECK: krnl.iterate() with () {
// CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
// CHECK: {{.*}} = constant 0.000000e+00 : f32
// CHECK: {{.*}} = constant 1.000000e+00 : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
// CHECK: store {{.*}}, [[Ft]][] : memref<f32>
// CHECK: }
// CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>
// CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
// CHECK: [[Ht1Rc_LOAD:%.+]] = load [[Ht1Rc_GEMM]][] : memref<f32>
// CHECK: [[ct_OUTPUT:%.+]] = addf [[XtWc_LOAD]], [[Ht1Rc_LOAD]] : f32
// CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
// CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
// CHECK: krnl.define_loops 0
// CHECK: krnl.optimize_loops {
// CHECK: krnl.return_loops
// CHECK: } : () -> ()
// CHECK: krnl.iterate() with () {
// CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
// CHECK: {{.*}} = constant 0.000000e+00 : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
// CHECK: store {{.*}}, [[ct]][] : memref<f32>
// CHECK: }
// CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>
// CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
// CHECK: [[Itct:%.+]] = mulf [[It_LOAD]], [[ct_LOAD]] : f32
// CHECK: [[Ct:%.+]] = addf [[FtCt1]], [[Itct]] : f32
// CHECK: store [[Ct]], [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: [[XtWo_LOAD:%.+]] = load [[XtWo_GEMM]][] : memref<f32>
// CHECK: [[Ht1Ro_LOAD:%.+]] = load [[Ht1Ro_GEMM]][] : memref<f32>
// CHECK: [[Ot_OUTPUT:%.+]] = addf [[XtWo_LOAD]], [[Ht1Ro_LOAD]] : f32
// CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
// CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
// CHECK: krnl.define_loops 0
// CHECK: krnl.optimize_loops {
// CHECK: krnl.return_loops
// CHECK: } : () -> ()
// CHECK: krnl.iterate() with () {
// CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
// CHECK: {{.*}} = constant 0.000000e+00 : f32
// CHECK: {{.*}} = constant 1.000000e+00 : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
// CHECK: store {{.*}}, [[Ot]][] : memref<f32>
// CHECK: }
// CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>
// CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
// CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
// CHECK: krnl.define_loops 0
// CHECK: krnl.optimize_loops {
// CHECK: krnl.return_loops
// CHECK: } : () -> ()
// CHECK: krnl.iterate() with () {
// CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
// CHECK: {{.*}} = constant 0.000000e+00 : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = exp {{.*}} : f32
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
// CHECK: store {{.*}}, [[hCt]][] : memref<f32>
// CHECK: }
// CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>
// CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32
// CHECK: store [[Ht]], [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
// CHECK: dealloc [[XtWi_GEMM]] : memref<f32>
// CHECK: dealloc [[XtWo_GEMM]] : memref<f32>
// CHECK: dealloc [[XtWf_GEMM]] : memref<f32>
// CHECK: dealloc [[XtWc_GEMM]] : memref<f32>
// CHECK: dealloc [[Ht1Ri_GEMM]] : memref<f32>
// CHECK: dealloc [[Ht1Ro_GEMM]] : memref<f32>
// CHECK: dealloc [[Ht1Rf_GEMM]] : memref<f32>
// CHECK: dealloc [[Ht1Rc_GEMM]] : memref<f32>
// CHECK: dealloc [[It]] : memref<f32>
// CHECK: dealloc [[Ft]] : memref<f32>
// CHECK: dealloc [[ct]] : memref<f32>
// CHECK: dealloc [[Ot]] : memref<f32>
// CHECK: dealloc [[hCt]] : memref<f32>
// CHECK: }
// CHECK: }
// CHECK: dealloc [[CELL_STATE]] : memref<1x3x3xf32>
// CHECK: return [[HIDDEN_STATE]] : memref<1x3x3xf32>
}
// -----
func @test_lstm_reverse_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "reverse"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>
// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
// CHECK-LABEL: @test_lstm_reverse_mode
// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
}
// -----
func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "bidirectional"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>
// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
// CHECK-LABEL: @test_lstm_bidirectional_mode
// CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, {{.*}}, {{.*}}] : memref<4x3x2xf32>
// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
}

View File

@ -613,6 +613,222 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
// -----
func @test_rnn_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_rnn_all_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
return
// CHECK-LABEL: test_rnn_no_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
// CHECK: return
}
// -----
func @test_rnn_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_rnn_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_rnn_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
return
// CHECK-LABEL: test_rnn_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
// CHECK: return
}
// -----
func @test_rnn_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_rnn_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_rnn_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_rnn_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}
// -----
func @test_gru_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_gru_all_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_gru_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
return
// CHECK-LABEL: test_gru_no_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
// CHECK: return
}
// -----
func @test_gru_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_gru_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_gru_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
return
// CHECK-LABEL: test_gru_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
// CHECK: return
}
// -----
func @test_gru_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_gru_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_gru_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_gru_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}
// -----
func @test_lstm_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_lstm_all_results
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_lstm_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
return
// CHECK-LABEL: test_lstm_no_results
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
// CHECK: return
}
// -----
func @test_lstm_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_lstm_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_lstm_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_lstm_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, none)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_lstm_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_lstm_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
// -----
func @test_lstm_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>
// CHECK-LABEL: test_lstm_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}
// -----
func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
%0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
"std.return"(%0) : (tensor<*xf32>) -> ()

View File

@ -63,7 +63,8 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
'LSTM', 'GRU', 'Split'
]
# Operations supporting canonicalization.