Lower LSTMOp to Krnl dialect (#73)
* Support dilations and enable e2e tests
* Fix allocating memory for dynamic shape
* Edit comments
* Do dilation by computing an offset from kernel index
* Correct dilation formula, add an example of out-of-bound, and add a test for dilation
* Import optional outputs as NoneType
* Shape inference for ONNXLSTM
* Edit ONNXLSTM::inferShape()
* Shape inference for ONNXLSTMOp
* Create a common function for inferring shape for RNN ops
* CheckInsertDeallocation for a specific result
* Allocate memory for LSTM
* First round of lowering
* Allocate memory for hidden and cell states
* Test with custom Tanh
* Fix an error in Ct's formula
* Add E2E tests
* Return outputs
* Refactor the code
* Enable E2E tests
* Support reverse and bidirectional directions
* Minor revision
* Return all intermediate hidden states
* Call existing activation functions
* Structs for activation functions
* Call existing activations in ONNX
* Minor revision
* Compare strings ignoring case
* Use memreftype of rank 0 for calling activation functions
* Fix getActivationPack()
* Revise the code
* Add one MLIR test
* Add MLIR tests for reverse and bidirectional modes
* Make the order of emitting instructions deterministic
* Use OperandAdaptor instead of directly using an operand index
* Use literal assignments
* Change some variable names
* Use literal assignments
* Use literal assignments
* Format the code

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent 9a874007ce
commit 24343177b8
@@ -9,6 +9,9 @@ add_library(OMONNXToKrnl
   NN/Conv.cpp
   NN/Normalization.cpp
   NN/Pooling.cpp
+  RNN/RNNBase.cpp
+  RNN/RNNBase.hpp
+  RNN/LSTM.cpp
   Tensor/Identity.cpp
   Tensor/Reshape.cpp
   Tensor/PadConstantValuePad.cpp
@@ -103,6 +103,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
   populateLoweringONNXConvOpPattern(patterns, &getContext());
   populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
   populateLoweringONNXPoolingOpPattern(patterns, &getContext());
+  // Recurrent neural network
+  populateLoweringONNXLSTMOpPattern(patterns, &getContext());
   // Entry point
   patterns.insert<ONNXEntryPointLowering>(&getContext());
@@ -102,16 +102,14 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
 // Determine if current function returns the result value of the
 // current op being lowered. If it does then dealloc should not be
 // inserted.
-bool checkInsertDealloc(Operation *currentOp) {
+bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
   auto parentBlock = currentOp->getBlock();

   bool insertDealloc = true;
-  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
-    assert(currentOp->getNumResults() < 2 &&
-           "No more than one result supported (for now).");
+  parentBlock->walk([&insertDealloc, currentOp, resultIndex](ReturnOp op) {
     // If there is at least one result to investigate.
     if (currentOp->getNumResults() > 0) {
-      auto result = currentOp->getResult(0);
+      auto result = currentOp->getResult(resultIndex);
       for (const auto &operand : op.getOperands())
         if (operand == result)
           insertDealloc = false;
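Since LSTM produces up to three results (Y, Y_h, Y_c), each of which may independently be returned by the enclosing function, the lowering now probes each result separately. A minimal sketch of the intended call sites (mirroring the allocation code in LSTM.cpp below):

// Illustrative only; matches the calls made in allocAndInitializeStates.
bool deallocY  = checkInsertDealloc(op->getOperation(), /*resultIndex=*/0); // Y
bool deallocYh = checkInsertDealloc(op->getOperation(), /*resultIndex=*/1); // Y_h
bool deallocYc = checkInsertDealloc(op->getOperation(), /*resultIndex=*/2); // Y_c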
@@ -46,7 +46,7 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
 // Determine if current function returns the result value of the
 // current op being lowered. If it does then dealloc should not be
 // inserted.
-bool checkInsertDealloc(Operation *currentOp);
+bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0);

 // Create a mapping from result type's dimensions to input type's dimensions,
 // given that the result type is the result of a reduction op over the input
@@ -218,6 +218,10 @@ void populateLoweringONNXNormalizationOpPattern(
 void populateLoweringONNXPoolingOpPattern(
     OwningRewritePatternList &patterns, MLIRContext *ctx);

+// `RNN` directory methods:
+void populateLoweringONNXLSTMOpPattern(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
 // `Tensor` directory methods:

 void populateLoweringONNXUnsqueezeOpPattern(
@@ -0,0 +1,537 @@ New file: src/Conversion/ONNXToKrnl/RNN/LSTM.cpp
//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX LSTM Operators to Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"

using namespace mlir;

struct LstmState {
  Value allH;
  Value ht;
  Value ct;
};

struct LstmActivationPack {
  RNNActivation f;
  RNNActivation g;
  RNNActivation h;
};

template <>
bool hasAllNoneOutput<ONNXLSTMOp>(ONNXLSTMOp *op) {
  return (
      isNoneType(op->Y()) && isNoneType(op->Y_h()) && isNoneType(op->Y_c()));
}

template <>
std::tuple<LstmActivationPack, LstmActivationPack>
getActivationPack<ONNXLSTMOp, LstmActivationPack>(ONNXLSTMOp *op) {
  auto direction = op->direction();
  auto activations = op->activations();
  auto activationAlpha = op->activation_alpha();
  auto activationBeta = op->activation_beta();

  LstmActivationPack activationForward, activationReverse;

  // Get activation function name.
  // Default forward functions
  activationForward.f.name = "sigmoid";
  activationForward.g.name = "tanh";
  activationForward.h.name = "tanh";
  // Default backward functions
  activationReverse.f.name = "sigmoid";
  activationReverse.g.name = "tanh";
  activationReverse.h.name = "tanh";
  if (activations) {
    ArrayAttr activationArrAttr = activations.getValue();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.name =
            activationArrAttr[0].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.name =
            activationArrAttr[1].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.name =
            activationArrAttr[2].cast<StringAttr>().getValue();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.name =
            activationArrAttr[startIndex].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.name =
            activationArrAttr[startIndex + 1].cast<StringAttr>().getValue();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.name =
            activationArrAttr[startIndex + 2].cast<StringAttr>().getValue();
      }
    }
  }

  // Get alpha attributes.
  if (activationAlpha) {
    ArrayAttr activationArrAttr = activationAlpha.getValue();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.alpha = activationArrAttr[0].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.alpha = activationArrAttr[1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.alpha = activationArrAttr[2].cast<FloatAttr>();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.alpha =
            activationArrAttr[startIndex].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.alpha =
            activationArrAttr[startIndex + 1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.alpha =
            activationArrAttr[startIndex + 2].cast<FloatAttr>();
      }
    }
  }

  // Get beta attributes.
  if (activationBeta) {
    ArrayAttr activationArrAttr = activationBeta.getValue();
    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      // Forward activations.
      if (activationArrAttr.size() > 0) {
        activationForward.f.beta = activationArrAttr[0].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 1) {
        activationForward.g.beta = activationArrAttr[1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > 2) {
        activationForward.h.beta = activationArrAttr[2].cast<FloatAttr>();
      }
    }

    // Reverse activations.
    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      int startIndex = (direction == REVERSE) ? 0 : 3;
      if (activationArrAttr.size() > startIndex) {
        activationReverse.f.beta =
            activationArrAttr[startIndex].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 1) {
        activationReverse.g.beta =
            activationArrAttr[startIndex + 1].cast<FloatAttr>();
      }
      if (activationArrAttr.size() > startIndex + 2) {
        activationReverse.h.beta =
            activationArrAttr[startIndex + 2].cast<FloatAttr>();
      }
    }
  }

  return std::make_tuple(activationForward, activationReverse);
}
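// Note (illustrative; follows the indexing above and the ONNX LSTM spec): the
// activations attribute lists forward entries first, then reverse ones, e.g.
//   activations = ["Sigmoid", "Tanh", "Tanh",   // forward f, g, h
//                  "Sigmoid", "Tanh", "Tanh"]   // reverse f, g, h
// for a bidirectional LSTM, hence startIndex = 3 above; a pure "reverse"
// LSTM packs its three entries at indices 0..2.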
template <>
LstmState allocAndInitializeStates<ONNXLSTMOp, LstmState>(
    ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op,
    OperandAdaptor<ONNXLSTMOp> operandAdaptor) {
  LstmState state;

  // Insert allocation and deallocation for the results of this operation.
  if (!isNoneType(op->Y())) {
    auto yMemRefType = convertToMemRefType(op->Y().getType());
    if (hasAllConstantDimensions(yMemRefType))
      state.allH = insertAllocAndDealloc(yMemRefType, loc, rewriter,
          checkInsertDealloc(op->getOperation(), 0));
    else
      emitError(loc, "Unsupported dynamic dimensions.");
  } else {
    state.allH = op->Y();
  }

  // Y_h :: [num_directions, batch_size, hidden_size]
  if (!isNoneType(op->Y_h())) {
    auto yhMemRefType = convertToMemRefType(op->Y_h().getType());
    if (hasAllConstantDimensions(yhMemRefType))
      state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter,
          checkInsertDealloc(op->getOperation(), 1));
    else
      emitError(loc, "Unsupported dynamic dimensions.");
  } else {
    auto yhMemRefType = MemRefType::get(
        {dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
            dimAt(operandAdaptor.R(), 2)},
        operandAdaptor.X().getType().cast<ShapedType>().getElementType());
    state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter, true);
  }

  // Y_c :: [num_directions, batch_size, hidden_size]
  if (!isNoneType(op->Y_c())) {
    auto ycMemRefType = convertToMemRefType(op->Y_c().getType());
    if (hasAllConstantDimensions(ycMemRefType))
      state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter,
          checkInsertDealloc(op->getOperation(), 2));
    else
      emitError(loc, "Unsupported dynamic dimensions.");
  } else {
    auto ycMemRefType = MemRefType::get(
        {dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
            dimAt(operandAdaptor.R(), 2)},
        operandAdaptor.X().getType().cast<ShapedType>().getElementType());
    state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter, true);
  }

  // Initialize ht and ct.
  Value zero = emitConstantOp(rewriter, loc,
      operandAdaptor.X().getType().cast<ShapedType>().getElementType(), 0);
  int nLoops = 3;
  BuildKrnlLoop initializationLoops(rewriter, loc, nLoops);
  initializationLoops.createDefineOptimizeAndIterateOp(state.ht);
  auto ipInitializationLoops = rewriter.saveInsertionPoint();
  rewriter.setInsertionPointToStart(initializationLoops.getIterateBlock());
  {
    SmallVector<Value, 4> IVs;
    for (int i = 0; i < nLoops; ++i)
      IVs.emplace_back(initializationLoops.getInductionVar(i));

    Value hiddenVal = zero;
    if (!isNoneType(operandAdaptor.initial_h()))
      hiddenVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_h(), IVs);
    rewriter.create<StoreOp>(loc, hiddenVal, state.ht, IVs);

    Value cellVal = zero;
    if (!isNoneType(operandAdaptor.initial_c()))
      cellVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_c(), IVs);
    rewriter.create<StoreOp>(loc, cellVal, state.ct, IVs);
  }
  rewriter.restoreInsertionPoint(ipInitializationLoops);
  return state;
}
template <>
void calculateState<ONNXLSTMOp, LstmState, LstmActivationPack>(
    ConversionPatternRewriter &rewriter, Location loc,
    OperandAdaptor<ONNXLSTMOp> operandAdaptor, LstmState state,
    LstmActivationPack activationPack, Value directionIV, Value sequenceIV) {

  bool hasBiasForInput = false, hasPeepholes = false;
  if (!isNoneType(operandAdaptor.B()))
    hasBiasForInput = true;
  if (!isNoneType(operandAdaptor.P()))
    hasPeepholes = true;

  // Prepare dimensions.
  auto batchDimSize = dimAt(operandAdaptor.X(), 1);
  auto inputDimSize = dimAt(operandAdaptor.X(), 2);
  auto hiddenDimSize = dimAt(operandAdaptor.R(), 2);
  Value hiddenDimVal =
      emitConstantOp(rewriter, loc, rewriter.getIndexType(), hiddenDimSize);

  auto elementType =
      operandAdaptor.X().getType().cast<ShapedType>().getElementType();

  // Prepare AffineMap to access bias, peepholes tensors.
  AffineMap accessByOffsetMap;
  {
    AffineExpr iv = rewriter.getAffineDimExpr(0);
    AffineExpr index = rewriter.getAffineSymbolExpr(0);
    AffineExpr size = rewriter.getAffineSymbolExpr(1);
    AffineExpr accessByOffsetExpr = index * size + iv;
    accessByOffsetMap = AffineMap::get(1, 2, accessByOffsetExpr);
  }
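  // Worked example (illustrative): with hidden_size = 3 (s1), gate index
  // s0 = 2 (the f gate in [iofc] order) and loop IV d0 = 1, the map computes
  // 2 * 3 + 1 = 7, i.e. element 7 along the packed 4*hidden_size gate axis.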
  // Prepare constant indices.
  SmallVector<Value, 4> constantIndices;
  for (int i = 0; i < 8; i++)
    constantIndices.emplace_back(
        emitConstantOp(rewriter, loc, rewriter.getIndexType(), i));

  // Equations for LSTM.
  // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
  // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
  // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
  // Ct = ft (.) Ct-1 + it (.) ct
  // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
  // Ht = ot (.) h(Ct)
  //
  // The following code will emit loops as follows:
  // for b in 0 .. BatchDimSize
  //   for h in 0 .. HiddenDimSize
  //     for i in 0 .. InputDimSize {
  //       compute Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^T), Xt*(Wc^T),
  //               Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^T), Ht-1*(Rc^T)
  //     }
  //     compute it, ft, ct, Ct, ot, Ht

  BuildKrnlLoop stateLoops(rewriter, loc, 2);
  stateLoops.createDefineAndOptimizeOp();
  stateLoops.pushBounds(0, batchDimSize);
  stateLoops.pushBounds(0, hiddenDimSize);
  stateLoops.createIterateOp();

  rewriter.setInsertionPointToStart(stateLoops.getIterateBlock());
  {
    auto batchIV = stateLoops.getInductionVar(0);
    auto hiddenIV = stateLoops.getInductionVar(1);

    // IVs to access tensors.
    // IVs for the hidden and cell state tensors.
    SmallVector<Value, 4> hIVs, cIVs;
    // IVs for the bias tensors for W and R.
    SmallVector<SmallVector<Value, 4>, 4> wbIOFCIVs, rbIOFCIVs;
    // IVs for the peepholes.
    SmallVector<SmallVector<Value, 4>, 4> pIOFIVs;

    { // Compute IVs.
      // H :: [num_directions, batch_size, hidden_size]
      hIVs = {directionIV, batchIV, hiddenIV};
      // C :: [num_directions, batch_size, hidden_size]
      cIVs = {directionIV, batchIV, hiddenIV};

      // Bias [Wb[iofc], Rb[iofc]] :: [num_directions, 8*hidden_size]
      if (hasBiasForInput) {
        // Wb[iofc]
        for (unsigned i = 0; i < 4; ++i) {
          Value wHiddenIV =
              rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
                  ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
                      /*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
          wbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, wHiddenIV});
        }
        // Rb[iofc]
        for (unsigned i = 4; i < 8; ++i) {
          SmallVector<Value, 4> rbIVs;
          Value rHiddenIV =
              rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
                  ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
                      /*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
          rbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, rHiddenIV});
        }
      }

      // Peepholes P[iof] :: [num_directions, 3*hidden_size]
      if (hasPeepholes) {
        for (unsigned i = 0; i < 3; ++i) {
          SmallVector<Value, 4> pIVs;
          Value pHiddenIV =
              rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
                  ValueRange(std::vector<Value>{
                      hiddenIV, constantIndices[i], hiddenDimVal}));
          pIOFIVs.emplace_back(SmallVector<Value, 2>{directionIV, pHiddenIV});
        }
      }
    }

    Value loadH = rewriter.create<LoadOp>(loc, state.ht, hIVs);
    Value loadC = rewriter.create<LoadOp>(loc, state.ct, cIVs);

    // Emit instructions for matrix multiplications:
    //   Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^T), Xt*(Wc^T)
    //   Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^T), Ht-1*(Rc^T)

    // Allocate memory for storing matrix multiplication results.
    SmallVector<Value, 4> xwIOFC, hrIOFC;
    Value zero = emitConstantOp(rewriter, loc, elementType, 0);
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    for (unsigned i = 0; i < 4; ++i) {
      Value xwAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
      rewriter.create<StoreOp>(loc, zero, xwAlloc);
      Value hrAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
      rewriter.create<StoreOp>(loc, zero, hrAlloc);
      xwIOFC.emplace_back(xwAlloc);
      hrIOFC.emplace_back(hrAlloc);
    }

    { // Emit instructions for matrix multiplications.
      // input_size is the reduction dimension.
      BuildKrnlLoop reductionLoops(rewriter, loc, 1);
      reductionLoops.createDefineAndOptimizeOp();
      reductionLoops.pushBounds(0, inputDimSize);
      reductionLoops.createIterateOp();

      auto ipReductionLoops = rewriter.saveInsertionPoint();
      rewriter.setInsertionPointToStart(reductionLoops.getIterateBlock());
      {
        auto reductionIV = reductionLoops.getInductionVar(0);
        // Prepare IVs for accessing the input tensor and parameters.
        SmallVector<Value, 4> xIVs;
        SmallVector<SmallVector<Value, 4>, 4> wIOFCIVs, rIOFCIVs;

        // X :: [seq_length, batch_size, input_size]
        xIVs = {sequenceIV, batchIV, reductionIV};

        // W[iofc] :: [num_directions, 4*hidden_size, input_size]
        // R[iofc] :: [num_directions, 4*hidden_size, hidden_size]
        for (unsigned i = 0; i < 4; ++i) {
          SmallVector<Value, 4> wIVs, rIVs;
          Value wHiddenIV =
              rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
                  ValueRange(std::vector<Value>{
                      hiddenIV, constantIndices[i], hiddenDimVal}));

          wIVs = {directionIV, wHiddenIV, reductionIV};
          wIOFCIVs.emplace_back(wIVs);

          rIVs = {directionIV, wHiddenIV, reductionIV};
          rIOFCIVs.emplace_back(rIVs);
        }

        Value loadX = rewriter.create<LoadOp>(loc, operandAdaptor.X(), xIVs);
        for (unsigned i = 0; i < 4; ++i) {
          // Xt * Wiofc
          Value loadW =
              rewriter.create<LoadOp>(loc, operandAdaptor.W(), wIOFCIVs[i]);
          Value xwVal = rewriter.create<MulFOp>(loc, loadX, loadW);
          Value loadXW = rewriter.create<LoadOp>(loc, xwIOFC[i]);
          Value nextXW = rewriter.create<AddFOp>(loc, loadXW, xwVal);
          rewriter.create<StoreOp>(loc, nextXW, xwIOFC[i]);
          // Ht-1 * Riofc
          Value loadR =
              rewriter.create<LoadOp>(loc, operandAdaptor.R(), rIOFCIVs[i]);
          Value hrVal = rewriter.create<MulFOp>(loc, loadH, loadR);
          Value loadHR = rewriter.create<LoadOp>(loc, hrIOFC[i]);
          Value nextHR = rewriter.create<AddFOp>(loc, loadHR, hrVal);
          rewriter.create<StoreOp>(loc, nextHR, hrIOFC[i]);
        }
      }
      rewriter.restoreInsertionPoint(ipReductionLoops);
    }

    // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    Value loadXWI = rewriter.create<LoadOp>(loc, xwIOFC[0]);
    Value loadHRI = rewriter.create<LoadOp>(loc, hrIOFC[0]);
    Value it = rewriter.create<AddFOp>(loc, loadXWI, loadHRI);
    if (hasPeepholes) {
      Value loadP =
          rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[0]);
      Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
      it = rewriter.create<AddFOp>(loc, it, PC);
    }
    if (hasBiasForInput) {
      Value loadWB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[0]);
      it = rewriter.create<AddFOp>(loc, it, loadWB);
      Value loadRB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[0]);
      it = rewriter.create<AddFOp>(loc, it, loadRB);
    }
    it = applyActivation(rewriter, loc, activationPack.f, it);

    // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
    Value loadXWF = rewriter.create<LoadOp>(loc, xwIOFC[2]);
    Value loadHRF = rewriter.create<LoadOp>(loc, hrIOFC[2]);
    Value ft = rewriter.create<AddFOp>(loc, loadXWF, loadHRF);
    if (hasPeepholes) {
      Value loadP =
          rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[2]);
      Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
      ft = rewriter.create<AddFOp>(loc, ft, PC);
    }
    if (hasBiasForInput) {
      Value loadWB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[2]);
      ft = rewriter.create<AddFOp>(loc, ft, loadWB);
      Value loadRB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[2]);
      ft = rewriter.create<AddFOp>(loc, ft, loadRB);
    }
    ft = applyActivation(rewriter, loc, activationPack.f, ft);

    // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    Value loadXWC = rewriter.create<LoadOp>(loc, xwIOFC[3]);
    Value loadHRC = rewriter.create<LoadOp>(loc, hrIOFC[3]);
    Value ct = rewriter.create<AddFOp>(loc, loadXWC, loadHRC);
    if (hasBiasForInput) {
      Value loadWB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[3]);
      ct = rewriter.create<AddFOp>(loc, ct, loadWB);
      Value loadRB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[3]);
      ct = rewriter.create<AddFOp>(loc, ct, loadRB);
    }
    ct = applyActivation(rewriter, loc, activationPack.g, ct);

    // Ct = ft (.) Ct-1 + it (.) ct
    Value FtCt1 = rewriter.create<MulFOp>(loc, ft, loadC);
    Value itct = rewriter.create<MulFOp>(loc, it, ct);
    Value Ct = rewriter.create<AddFOp>(loc, FtCt1, itct);
    rewriter.create<StoreOp>(loc, Ct, state.ct, cIVs);

    // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    Value loadXWO = rewriter.create<LoadOp>(loc, xwIOFC[1]);
    Value loadHRO = rewriter.create<LoadOp>(loc, hrIOFC[1]);
    Value ot = rewriter.create<AddFOp>(loc, loadXWO, loadHRO);
    if (hasPeepholes) {
      Value loadP =
          rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[1]);
      Value PC = rewriter.create<MulFOp>(loc, loadP, Ct);
      ot = rewriter.create<AddFOp>(loc, ot, PC);
    }
    if (hasBiasForInput) {
      Value loadWB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[1]);
      ot = rewriter.create<AddFOp>(loc, ot, loadWB);
      Value loadRB =
          rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[1]);
      ot = rewriter.create<AddFOp>(loc, ot, loadRB);
    }
    ot = applyActivation(rewriter, loc, activationPack.f, ot);

    // Ht = ot (.) h(Ct)
    Value hCt = applyActivation(rewriter, loc, activationPack.h, Ct);
    Value Ht = rewriter.create<MulFOp>(loc, ot, hCt);
    rewriter.create<StoreOp>(loc, Ht, state.ht, hIVs);

    // Store the current Ht if required.
    if (!isNoneType(state.allH)) {
      SmallVector<Value, 4> allHIVs{sequenceIV, directionIV, batchIV, hiddenIV};
      rewriter.create<StoreOp>(loc, Ht, state.allH, allHIVs);
    }

    // Deallocate the temporary results of matrix multiplications.
    for (Value v : xwIOFC)
      rewriter.create<DeallocOp>(loc, v);
    for (Value v : hrIOFC)
      rewriter.create<DeallocOp>(loc, v);
  }
}

template <>
void stateToOutput<ONNXLSTMOp, LstmState>(
    ONNXLSTMOp *op, LstmState state, std::vector<Value> &outputs) {
  Value noneValue;
  outputs.emplace_back((isNoneType(op->Y()) ? noneValue : state.allH));
  outputs.emplace_back((isNoneType(op->Y_h()) ? noneValue : state.ht));
  outputs.emplace_back((isNoneType(op->Y_c()) ? noneValue : state.ct));
}

void populateLoweringONNXLSTMOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXRNNOpLowering<ONNXLSTMOp, LstmState, LstmActivationPack>>(
      ctx);
}
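The registration above is the only LSTM-specific entry point; everything else lives in the five specialized templates. As a sketch of how another RNN op would plug into the same skeleton (hypothetical: GruState and GruActivationPack are illustrative names, not part of this patch):

void populateLoweringONNXGRUOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXRNNOpLowering<ONNXGRUOp, GruState, GruActivationPack>>(
      ctx);
}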
@@ -0,0 +1,73 @@ New file: src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp
//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines base functions for lowering the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"

using namespace mlir;

// Check whether a Value's type is NoneType or not.
bool isNoneType(Value val) { return val.getType().isa<NoneType>(); }

// Get a dimension of the tensor's shape.
int64_t dimAt(Value val, int index) {
  return val.getType().cast<ShapedType>().getShape()[index];
}

// Apply an activation function on a given scalar operand.
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
    RNNActivation activation, Value scalarOperand) {
  Value res;

  MemRefType scalarMemRefType =
      MemRefType::get({}, scalarOperand.getType(), {}, 0);
  Value alloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
  rewriter.create<StoreOp>(loc, scalarOperand, alloc);

  std::vector<mlir::NamedAttribute> attributes;
  if (activation.alpha) {
    attributes.emplace_back(
        rewriter.getNamedAttr("alpha", activation.alpha.getValue()));
  }
  if (activation.beta) {
    attributes.emplace_back(
        rewriter.getNamedAttr("beta", activation.beta.getValue()));
  }

  if (activation.name.equals_lower("relu"))
    res = rewriter.create<ONNXReluOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("tanh"))
    res = rewriter.create<ONNXTanhOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("sigmoid"))
    res = rewriter.create<ONNXSigmoidOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("affine"))
    emitError(loc, "Unsupported activation");
  else if (activation.name.equals_lower("leakyrelu"))
    res = rewriter.create<ONNXLeakyReluOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("thresholdedrelu"))
    res = rewriter.create<ONNXThresholdedReluOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("scaledtanh"))
    emitError(loc, "Unsupported activation");
  else if (activation.name.equals_lower("hardsigmoid"))
    res = rewriter.create<ONNXHardSigmoidOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("elu"))
    res = rewriter.create<ONNXEluOp>(loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("softsign"))
    res = rewriter.create<ONNXSoftsignOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("softplus"))
    res = rewriter.create<ONNXSoftplusOp>(loc, scalarMemRefType, alloc);
  else
    llvm_unreachable("Unsupported activation");

  Value result = rewriter.create<LoadOp>(loc, res);
  return result;
}
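The rank-0 memref in applyActivation is what the commit message calls "Use memreftype of rank 0 for calling activation functions": the existing ONNX activation lowerings consume memrefs, so the scalar is boxed, the ONNX op is created on the box, and the result is loaded back out. A minimal sketch of the pattern, assuming a float element type:

// Sketch only: box a scalar so the existing ONNX activation lowerings apply.
MemRefType scalarTy = MemRefType::get({}, elementType, {}, 0); // memref<f32>
Value box = rewriter.create<AllocOp>(loc, scalarTy);
rewriter.create<StoreOp>(loc, scalar, box);                    // scalar -> box
Value res = rewriter.create<ONNXTanhOp>(loc, scalarTy, box);   // reuse lowering
Value out = rewriter.create<LoadOp>(loc, res);                 // box -> scalar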
@@ -0,0 +1,144 @@ New file: src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines base functions for lowering the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AffineExpr.h"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

using namespace mlir;

static const StringRef FORWARD = "forward";
static const StringRef REVERSE = "reverse";
static const StringRef BIDIRECTIONAL = "bidirectional";

struct RNNActivation {
  StringRef name;
  Optional<FloatAttr> alpha;
  Optional<FloatAttr> beta;
};

// Check whether a Value's type is NoneType or not.
bool isNoneType(Value val);

// Get a dimension of the tensor's shape.
int64_t dimAt(Value val, int index);

// Apply an activation function on a given scalar operand.
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
    RNNActivation activation, Value scalarOperand);

// Override the following methods when lowering an RNN operation:
// - hasAllNoneOutput
// - getActivationPack
// - allocAndInitializeStates
// - calculateState
// - stateToOutput

// Check whether all outputs have NoneType or not.
template <typename RNNOp>
bool hasAllNoneOutput(RNNOp *op);

// Obtain activation functions for a specific operation.
template <typename RNNOp, typename A>
std::tuple<A, A> getActivationPack(RNNOp *op);

// Allocate memory for RNN states and initialize them.
template <typename RNNOp, typename S>
S allocAndInitializeStates(ConversionPatternRewriter &rewriter, Location loc,
    RNNOp *op, OperandAdaptor<RNNOp> operandAdaptor);

// Calculate new states from the current input and states.
template <typename RNNOp, typename S, typename A>
void calculateState(ConversionPatternRewriter &rewriter, Location loc,
    OperandAdaptor<RNNOp> operandAdaptor, S state, A activationSet,
    Value directionIV, Value sequenceIV);

// Write states to the RNN's outputs.
template <typename RNNOp, typename S>
void stateToOutput(RNNOp *op, S state, std::vector<Value> &outputs);

// A common template for lowering an RNN operation.
template <typename RNNOp, typename S, typename A>
struct ONNXRNNOpLowering : public ConversionPattern {
  ONNXRNNOpLowering(MLIRContext *ctx)
      : ConversionPattern(RNNOp::getOperationName(), 1, ctx) {}

  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    auto loc = op->getLoc();

    RNNOp rnnOp = llvm::dyn_cast<RNNOp>(op);
    OperandAdaptor<RNNOp> operandAdaptor(operands);

    if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
      rewriter.eraseOp(op);
      return success();
    }

    S state = allocAndInitializeStates<RNNOp, S>(
        rewriter, loc, &rnnOp, operandAdaptor);

    A activationForward, activationReverse;
    std::tie(activationForward, activationReverse) =
        getActivationPack<RNNOp, A>(&rnnOp);

    int64_t sequenceDimSize = dimAt(rnnOp.X(), 0);
    auto direction = rnnOp.direction();

    if (direction == FORWARD || direction == BIDIRECTIONAL) {
      BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
      sequenceLoops.createDefineAndOptimizeOp();
      sequenceLoops.pushBounds(0, sequenceDimSize);
      sequenceLoops.createIterateOp();

      auto ipSequenceLoops = rewriter.saveInsertionPoint();
      rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
      {
        Value directionIV =
            emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
        Value sequenceIV = sequenceLoops.getInductionVar(0);
        // Emit calculation for one RNN step.
        calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
            activationForward, directionIV, sequenceIV);
      }
      rewriter.restoreInsertionPoint(ipSequenceLoops);
    }

    if (direction == REVERSE || direction == BIDIRECTIONAL) {
      BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
      sequenceLoops.createDefineAndOptimizeOp();
      sequenceLoops.pushBounds(0, sequenceDimSize);
      sequenceLoops.createIterateOp();

      auto ipSequenceLoops = rewriter.saveInsertionPoint();
      rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
      {
        AffineMap reverseIVMap = AffineMap::get(1, 1,
            rewriter.getAffineSymbolExpr(0) - rewriter.getAffineDimExpr(0) - 1);
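        // Worked example (illustrative): with sequence length s0 = 4, the map
        // sends d0 = 0, 1, 2, 3 to 3, 2, 1, 0, so the same loop body visits
        // the sequence in reverse order.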
        Value directionIV = emitConstantOp(rewriter, loc,
            rewriter.getIndexType(), (direction == REVERSE) ? 0 : 1);
        Value reverseSequenceIV =
            rewriter.create<AffineApplyOp>(loc, reverseIVMap,
                ValueRange(std::vector<Value>{sequenceLoops.getInductionVar(0),
                    emitConstantOp(rewriter, loc, rewriter.getIndexType(),
                        sequenceDimSize)}));
        // Emit calculation for one RNN step.
        calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
            activationReverse, directionIV, reverseSequenceIV);
      }
      rewriter.restoreInsertionPoint(ipSequenceLoops);
    }

    std::vector<Value> outputs;
    stateToOutput<RNNOp, S>(&rnnOp, state, outputs);
    rewriter.replaceOp(op, outputs);
    return success();
  }
};
@@ -299,6 +299,112 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
   }
 }

+//===----------------------------------------------------------------------===//
+// Support function that infers shape for RNN operations.
+template <typename T>
+static bool RNNShapeInference(T *op) {
+  Value X = op->X();
+  Value W = op->W();
+  Value R = op->R();
+
+  if (!X.getType().isa<RankedTensorType>() ||
+      !W.getType().isa<RankedTensorType>() ||
+      !R.getType().isa<RankedTensorType>())
+    return false;
+
+  auto xTy = X.getType().cast<RankedTensorType>();
+  auto elementType = xTy.getElementType();
+
+  // xShape :: [seq_length, batch_size, input_size]
+  auto xShape = xTy.getShape();
+  // wShape :: [num_directions, 4*hidden_size, input_size]
+  auto wShape = W.getType().cast<RankedTensorType>().getShape();
+  // rShape :: [num_directions, 4*hidden_size, hidden_size]
+  auto rShape = R.getType().cast<RankedTensorType>().getShape();
+
+  if (xShape.size() != 3) {
+    op->emitError("The first input tensor must have rank 3");
+    return false;
+  }
+  if (wShape.size() != 3) {
+    op->emitError("The second input tensor must have rank 3");
+    return false;
+  }
+  if (rShape.size() != 3) {
+    op->emitError("The third input tensor must have rank 3");
+    return false;
+  }
+
+  // Get sequence length, batch size and input size.
+  auto sequenceLength = xShape[0];
+  auto batchSize = xShape[1];
+  auto inputSize = xShape[2];
+
+  // Get hidden size from the hidden_size attribute.
+  int64_t hiddenSize = -1;
+  if (op->hidden_size().hasValue()) {
+    hiddenSize = op->hidden_size().getValue().getSExtValue();
+  } else {
+    // Infer hidden_size from wShape and rShape if possible.
+    if (rShape[2] != -1)
+      hiddenSize = rShape[2];
+    else if (rShape[1] != -1)
+      hiddenSize = rShape[1] / 4;
+    else if (wShape[1] != -1)
+      hiddenSize = wShape[1] / 4;
+    // Update the hidden_size attribute.
+    if (hiddenSize != -1) {
+      auto builder = mlir::Builder(op->getContext());
+      op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize));
+    }
+  }
+
+  // Get direction.
+  int numDirection;
+  if ((op->direction() == "forward") || (op->direction() == "reverse"))
+    numDirection = 1;
+  else if (op->direction() == "bidirectional")
+    numDirection = 2;
+  else
+    numDirection = -1;
+  if (numDirection == -1) {
+    op->emitError("direction attribute must be one of the strings: forward, "
+                  "reverse, and bidirectional");
+    return false;
+  }
+
+  // Set result types.
+  unsigned numOfResults = op->getNumResults();
+  if (numOfResults > 0) {
+    // Y :: [seq_length, num_directions, batch_size, hidden_size]
+    Type yTy = op->getResults()[0].getType();
+    if (!yTy.isa<NoneType>()) {
+      yTy = RankedTensorType::get(
+          {sequenceLength, numDirection, batchSize, hiddenSize}, elementType);
+      op->getResults()[0].setType(yTy);
+    }
+  }
+  if (numOfResults > 1) {
+    // Y_h :: [num_directions, batch_size, hidden_size]
+    Type yhTy = op->getResults()[1].getType();
+    if (!yhTy.isa<NoneType>()) {
+      yhTy = RankedTensorType::get(
+          {numDirection, batchSize, hiddenSize}, elementType);
+      op->getResults()[1].setType(yhTy);
+    }
+  }
+  if (numOfResults > 2) {
+    // Y_c :: [num_directions, batch_size, hidden_size]
+    Type ycTy = op->getResults()[2].getType();
+    if (!ycTy.isa<NoneType>()) {
+      ycTy = RankedTensorType::get(
+          {numDirection, batchSize, hiddenSize}, elementType);
+      op->getResults()[2].setType(ycTy);
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ONNXOpsDialect
 //===----------------------------------------------------------------------===//
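As a sanity check on the arithmetic above, here is a small standalone sketch (plain C++, independent of MLIR) that reproduces the inferred output shapes for the tensors used in the MLIR test later in this patch:

#include <array>
#include <cstdio>

int main() {
  // X :: [seq_length, batch_size, input_size],
  // R :: [num_directions, 4*hidden_size, hidden_size].
  std::array<long, 3> xShape = {4, 3, 2};
  std::array<long, 3> rShape = {1, 12, 3};
  long hiddenSize = rShape[2];    // 3; falls back to rShape[1] / 4 = 3
  long numDirections = rShape[0]; // 1; derived from direction in the real code
  // Y :: [seq_length, num_directions, batch_size, hidden_size]
  std::printf("Y   : [%ld, %ld, %ld, %ld]\n", xShape[0], numDirections,
              xShape[1], hiddenSize);
  // Y_h and Y_c :: [num_directions, batch_size, hidden_size]
  std::printf("Y_h : [%ld, %ld, %ld]\n", numDirections, xShape[1], hiddenSize);
  return 0;
}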
@@ -1472,7 +1578,6 @@ bool ONNXConstantOp::inferShapes() {
   return true;
 }

-//===----------------------------------------------------------------------===//
 // Concat

 bool ONNXConcatOp::inferShapes() {

@@ -1537,6 +1642,21 @@ bool ONNXConcatOp::inferShapes() {
   return true;
 }

+//===----------------------------------------------------------------------===//
+// RNN
+
+bool ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
+
+//===----------------------------------------------------------------------===//
+// LSTM
+
+bool ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
+
+//===----------------------------------------------------------------------===//
+// GRU
+
+bool ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
+
 //===----------------------------------------------------------------------===//
 // Split
@@ -738,7 +738,7 @@ def ONNXFloorOp:ONNX_Op<"Floor",
 }

 def ONNXGRUOp:ONNX_Op<"GRU",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX GRU operation";
   let description = [{
   "Computes an one-layer GRU. This operator is usually supported via some custom"

@@ -1246,7 +1246,7 @@ def ONNXLRNOp:ONNX_Op<"LRN",
 }

 def ONNXLSTMOp:ONNX_Op<"LSTM",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX LSTM operation";
   let description = [{
   "Computes an one-layer LSTM. This operator is usually supported via some"

@@ -2086,7 +2086,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
 }

 def ONNXRNNOp:ONNX_Op<"RNN",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX RNN operation";
   let description = [{
   "Computes an one-layer simple RNN. This operator is usually supported"
@@ -126,6 +126,9 @@ public:
         op->getName().getStringRef() != "onnx.Concat" &&
         op->getName().getStringRef() != "onnx.Split" &&
         op->getName().getStringRef() != "onnx.Neg" &&
+        op->getName().getStringRef() != "onnx.RNN" &&
+        op->getName().getStringRef() != "onnx.LSTM" &&
+        op->getName().getStringRef() != "onnx.GRU" &&
         op->getName().getStringRef() != "onnx.Unsqueeze")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
@@ -356,6 +356,11 @@ test_to_enable = [
     "test_averagepool_2d_strides_cpu",
     "test_averagepool_3d_default_cpu",

+    # LSTM
+    "test_lstm_defaults_cpu",
+    "test_lstm_with_initial_bias_cpu",
+    "test_lstm_with_peepholes_cpu",
+
 ]

 # Extract name of all test cases.
@@ -1769,3 +1769,314 @@ func @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*xf32> {
   // CHECK: }
 }
+
+// -----
+
+func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-DAG: [[ACCESS_BY_OFFSET_MAP:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>
+  // CHECK-LABEL: @test_lstm_general_computation
+
+  // CHECK: [[CELL_STATE:%.+]] = alloc() : memref<1x3x3xf32>
+  // CHECK: [[HIDDEN_STATE:%.+]] = alloc() : memref<1x3x3xf32>
+  // CHECK: {{.*}} = constant unit
+
+  // CHECK: [[INITIAL_VALUE:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[INITIALIZE_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[INITIALIZE_OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[INITIALIZE_LOOPS]]#0, [[INITIALIZE_LOOPS]]#1, [[INITIALIZE_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[INITIALIZE_OPT_LOOPS]]#0, [[INITIALIZE_OPT_LOOPS]]#1, [[INITIALIZE_OPT_LOOPS]]#2) with ([[INITIALIZE_LOOPS]]#0 -> %arg3 = 0 to 1, [[INITIALIZE_LOOPS]]#1 -> %arg4 = 0 to 3, [[INITIALIZE_LOOPS]]#2 -> %arg5 = 0 to 3) {
+  // CHECK: store [[INITIAL_VALUE]], [[HIDDEN_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
+  // CHECK: store [[INITIAL_VALUE]], [[CELL_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
+  // CHECK: }
+
+  // CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
+  // CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
+  // CHECK: } : () -> !krnl.loop
+  // CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
+  // CHECK: {{.*}} = constant 0 : index
+  // CHECK: {{.*}} = constant 3 : index
+  // CHECK: {{.*}} = constant 0 : index
+  // CHECK: {{.*}} = constant 1 : index
+  // CHECK: {{.*}} = constant 2 : index
+  // CHECK: {{.*}} = constant 3 : index
+  // CHECK: {{.*}} = constant 4 : index
+  // CHECK: {{.*}} = constant 5 : index
+  // CHECK: {{.*}} = constant 6 : index
+  // CHECK: {{.*}} = constant 7 : index
+  // CHECK: [[DATA_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[DATA_OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DATA_LOOPS]]#0, [[DATA_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[DATA_OPT_LOOPS]]#0, [[DATA_OPT_LOOPS]]#1) with ([[DATA_LOOPS]]#0 -> %arg4 = 0 to 3, [[DATA_LOOPS]]#1 -> %arg5 = 0 to 3) {
+  // CHECK: [[hCt:%.+]] = alloc() : memref<f32>
+  // CHECK: [[Ot:%.+]] = alloc() : memref<f32>
+  // CHECK: [[ct:%.+]] = alloc() : memref<f32>
+  // CHECK: [[Ft:%.+]] = alloc() : memref<f32>
+  // CHECK: [[It:%.+]] = alloc() : memref<f32>
+  // CHECK: [[Ht1_LOAD:%.+]] = load [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
+  // CHECK: [[Ct1_LOAD:%.+]] = load [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
+
+  // CHECK: [[ZERO_FLOAT:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[XtWi_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Ri_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: [[XtWo_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[XtWo_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Ro_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[Ht1Ro_GEMM]][] : memref<f32>
+  // CHECK: [[XtWf_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Rf_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: [[XtWc_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Rc_GEMM:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ZERO_FLOAT]], [[Ht1Rc_GEMM]][] : memref<f32>
+
+  // CHECK: [[REDUCTION_LOOPS:%.+]] = krnl.define_loops 1
+  // CHECK: [[REDUCTION_OPT_LOOPS:%.+]] = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[REDUCTION_LOOPS]]
+  // CHECK: } : () -> !krnl.loop
+  // CHECK: krnl.iterate([[REDUCTION_OPT_LOOPS]]) with ([[REDUCTION_LOOPS]] -> %arg6 = 0 to 2) {
+  // CHECK: [[INPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c0_1, %c3]
+  // CHECK: [[OUTPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c1, %c3]
+  // CHECK: [[FORGET_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c2, %c3]
+  // CHECK: [[CELL_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c3_2, %c3]
+  // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, %arg4, %arg6] : memref<4x3x2xf32>
+
+  // CHECK: [[Wi_LOAD:%.+]] = load %arg1[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
+  // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
+  // CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
+
+  // CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
+  // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
+  // CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
+
+  // CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
+  // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
+  // CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
+
+  // CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
+  // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
+  // CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
+
+  // CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
+  // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
+  // CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
+
+  // CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
+  // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
+  // CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
+
+  // CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
+  // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
+  // CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
+
+  // CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
+  // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
+  // CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: }
+
+  // CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Ri_LOAD:%.+]] = load [[Ht1Ri_GEMM]][] : memref<f32>
+  // CHECK: [[It_OUTPUT:%.+]] = addf [[XtWi_LOAD]], [[Ht1Ri_LOAD]] : f32
+
+  // CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
+  // CHECK: krnl.define_loops 0
+  // CHECK: krnl.optimize_loops {
+  // CHECK: krnl.return_loops
+  // CHECK: } : () -> ()
+  // CHECK: krnl.iterate() with () {
+  // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = constant 1.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[It]][] : memref<f32>
+  // CHECK: }
+  // CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>
+
+  // CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Rf_LOAD:%.+]] = load [[Ht1Rf_GEMM]][] : memref<f32>
+  // CHECK: [[Ft_OUTPUT:%.+]] = addf [[XtWf_LOAD]], [[Ht1Rf_LOAD]] : f32
+
+  // CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
+  // CHECK: krnl.define_loops 0
+  // CHECK: krnl.optimize_loops {
+  // CHECK: krnl.return_loops
+  // CHECK: } : () -> ()
+  // CHECK: krnl.iterate() with () {
+  // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
+  // CHECK: {{.*}} = constant 0.000000e+00 : f32
+  // CHECK: {{.*}} = constant 1.000000e+00 : f32
+  // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
+  // CHECK: {{.*}} = exp {{.*}} : f32
+  // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
+  // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
+  // CHECK: store {{.*}}, [[Ft]][] : memref<f32>
+  // CHECK: }
+  // CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>
+
+  // CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
+  // CHECK: [[Ht1Rc_LOAD:%.+]] = load [[Ht1Rc_GEMM]][] : memref<f32>
+  // CHECK: [[ct_OUTPUT:%.+]] = addf [[XtWc_LOAD]], [[Ht1Rc_LOAD]] : f32
+
+  // CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
+  // CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
+  // CHECK: krnl.define_loops 0
+  // CHECK: krnl.optimize_loops {
+  // CHECK: krnl.return_loops
|
||||||
|
// CHECK: } : () -> ()
|
||||||
|
// CHECK: krnl.iterate() with () {
|
||||||
|
// CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
|
||||||
|
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: store {{.*}}, [[ct]][] : memref<f32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>
|
||||||
|
|
||||||
|
// CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
|
||||||
|
// CHECK: [[Itct:%.+]] = mulf [[It_LOAD]], [[ct_LOAD]] : f32
|
||||||
|
// CHECK: [[Ct:%.+]] = addf [[FtCt1]], [[Itct]] : f32
|
||||||
|
// CHECK: store [[Ct]], [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||||
|
|
||||||
|
// CHECK: [[XtWo_LOAD:%.+]] = load [[XtWo_GEMM]][] : memref<f32>
|
||||||
|
// CHECK: [[Ht1Ro_LOAD:%.+]] = load [[Ht1Ro_GEMM]][] : memref<f32>
|
||||||
|
// CHECK: [[Ot_OUTPUT:%.+]] = addf [[XtWo_LOAD]], [[Ht1Ro_LOAD]] : f32
|
||||||
|
|
||||||
|
// CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
|
||||||
|
// CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
|
||||||
|
// CHECK: krnl.define_loops 0
|
||||||
|
// CHECK: krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops
|
||||||
|
// CHECK: } : () -> ()
|
||||||
|
// CHECK: krnl.iterate() with () {
|
||||||
|
// CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
|
||||||
|
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: {{.*}} = constant 1.000000e+00 : f32
|
||||||
|
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
|
||||||
|
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: store {{.*}}, [[Ot]][] : memref<f32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>
|
||||||
|
|
||||||
|
// CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
|
||||||
|
// CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
|
||||||
|
// CHECK: krnl.define_loops 0
|
||||||
|
// CHECK: krnl.optimize_loops {
|
||||||
|
// CHECK: krnl.return_loops
|
||||||
|
// CHECK: } : () -> ()
|
||||||
|
// CHECK: krnl.iterate() with () {
|
||||||
|
// CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
|
||||||
|
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||||
|
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||||
|
// CHECK: store {{.*}}, [[hCt]][] : memref<f32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>
|
||||||
|
|
||||||
|
// CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32
|
||||||
|
// CHECK: store [[Ht]], [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||||
|
|
||||||
|
// CHECK: dealloc [[XtWi_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[XtWo_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[XtWf_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[XtWc_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ht1Ri_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ht1Ro_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ht1Rf_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ht1Rc_GEMM]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[It]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ft]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[ct]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[Ot]] : memref<f32>
|
||||||
|
// CHECK: dealloc [[hCt]] : memref<f32>
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK: dealloc [[CELL_STATE]] : memref<1x3x3xf32>
|
||||||
|
// CHECK: return [[HIDDEN_STATE]] : memref<1x3x3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
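Read together, the CHECK lines above trace one LSTM time step with the default activations (f = Sigmoid, g = h = Tanh). For reference, the recurrence they implement is the ONNX LSTM step, with the bias and peephole terms dropping out because those inputs are none in this test (an editorial reference, not part of the committed test file):

i_t = \sigma(X_t W_i^T + H_{t-1} R_i^T)
f_t = \sigma(X_t W_f^T + H_{t-1} R_f^T)
\tilde{c}_t = \tanh(X_t W_c^T + H_{t-1} R_c^T)
C_t = f_t \odot C_{t-1} + i_t \odot \tilde{c}_t
o_t = \sigma(X_t W_o^T + H_{t-1} R_o^T)
H_t = o_t \odot \tanh(C_t)

The subf/exp/addf/divf runs correspond to the scalar expansions \sigma(x) = 1 / (1 + e^{-x}) and \tanh(x) = (e^{x} - e^{-x}) / (e^{x} + e^{-x}); [[ct]] holds the candidate cell \tilde{c}_t.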
// -----
func @test_lstm_reverse_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "reverse"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
// CHECK-LABEL: @test_lstm_reverse_mode

// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
}
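The [[REVERSE_IV_MAP]] above is the whole of the reverse-direction handling: for sequence length s and loop index t it rewrites the access into

t' = s - t - 1

so with s = 4 the iteration reads X_3, X_2, X_1, X_0, exactly the affine_map<(d0)[s0] -> (-d0 + s0 - 1)> being checked.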
// -----
func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "bidirectional"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
// CHECK-LABEL: @test_lstm_bidirectional_mode

// CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, {{.*}}, {{.*}}] : memref<4x3x2xf32>

// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
// CHECK: } : () -> !krnl.loop
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
}
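In bidirectional mode the lowering emits both loop nests: the first reads X_t in sequence order, the second goes through [[REVERSE_IV_MAP]]. Per the ONNX spec the two passes fill the num_directions axis of the results,

Y[t, 0, :, :] = H_t^{forward}, \quad Y[t, 1, :, :] = H_t^{reverse}

though this test only checks that the two nests and the reversed load are emitted (the equation is a spec-level note, not a CHECK).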
@ -613,6 +613,222 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
// -----
func @test_rnn_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_rnn_all_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}
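The shape-inference cases that follow all exercise one rule from the ONNX spec: Y has shape [seq_length, num_directions, batch_size, hidden_size], and Y_h (plus Y_c for LSTM) has shape [num_directions, batch_size, hidden_size], with num_directions determined by the direction attribute. A minimal Python sketch of that rule (an illustration with a hypothetical helper name, not the repository's C++ inferShapes code):

def infer_rnn_output_shapes(x_shape, r_shape, hidden_size=None, direction="forward"):
    # x_shape: [seq_length, batch_size, input_size]; None marks an unknown dim.
    # r_shape: [num_directions, gates * hidden_size, hidden_size]
    seq_length, batch_size, _ = x_shape
    num_directions = 2 if direction == "bidirectional" else 1
    # hidden_size comes from the attribute when present, else from R's last dim.
    h = hidden_size if hidden_size is not None else r_shape[2]
    y = [seq_length, num_directions, batch_size, h]  # all intermediate hidden states
    y_h = [num_directions, batch_size, h]            # final hidden (and cell) state
    return y, y_h

# infer_rnn_output_shapes([4, 3, 2], [1, 12, 3], hidden_size=3)
# -> ([4, 1, 3, 3], [1, 3, 3]), matching tensor<4x1x3x3xf32> / tensor<1x3x3xf32> above.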
// -----

func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
return

// CHECK-LABEL: test_rnn_no_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
// CHECK: return
}

// -----

func @test_rnn_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_rnn_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_rnn_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
return

// CHECK-LABEL: test_rnn_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
// CHECK: return
}

// -----

func @test_rnn_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_rnn_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_rnn_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_rnn_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}

// -----

func @test_gru_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_gru_all_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_gru_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
return

// CHECK-LABEL: test_gru_no_results
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
// CHECK: return
}

// -----

func @test_gru_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_gru_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_gru_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
return

// CHECK-LABEL: test_gru_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
// CHECK: return
}

// -----

func @test_gru_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_gru_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_gru_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_gru_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}

// -----

func @test_lstm_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_lstm_all_results
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_lstm_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
return

// CHECK-LABEL: test_lstm_no_results
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
// CHECK: return
}

// -----

func @test_lstm_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_lstm_missing_first_result
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_lstm_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, none)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_lstm_missing_trailing_result
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, none)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_lstm_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_lstm_all_results_no_hidden_size
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
// CHECK: return [[RES]] : tensor<1x3x3xf32>
}

// -----

func @test_lstm_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
%cst = constant unit
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
return %Y_h : tensor<*xf32>

// CHECK-LABEL: test_lstm_all_results_unknown_dims
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>, tensor<1x?x?xf32>)
// CHECK: return [[RES]] : tensor<1x?x?xf32>
}

// -----

func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
%0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
"std.return"(%0) : (tensor<*xf32>) -> ()
@ -63,7 +63,8 @@ OpsWithShapeInference = [
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
'LSTM', 'GRU', 'Split'
]

# Operations supporting canonicalization.
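Adding 'RNN', 'LSTM', and 'GRU' to OpsWithShapeInference is what flags these ops as having shape inference in the generated op definitions; roughly, the generator performs a membership test over this list (a hypothetical sketch, not gen_doc.py's actual code):

def declares_shape_inference(op_name):
    # Ops named in OpsWithShapeInference get the shape-inference
    # interface attached to their generated ONNX op definition.
    return op_name in OpsWithShapeInference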