Lower LSTMOp to Krnl dialect (#73)
* Support dilations and enable e2e tests * Fix allocating memory for dynamic shape * Edit comments * Do dilation by computing an offset from kernel index * Correct dilation formula, add an example of out-of-bound, and add a test for dilation * Import optional outputs as NoneType * Shape inference for ONNXLSTM * Edit ONNXLSTM::inferShape() * Shape inference for ONNXLSTMOp * Create a common function for inferring shape for RNN ops * CheckInsertDeallocation for a specific result * Allocate memory for LSTM * First round of lowering * Allocate memory for hidden and cell states * Test with custom Tanh * Fix an error in Ct's formula * Add E2E tests * Return outputs * Refactor the code * Enable E2E tests * Support reverse and bidirectional directions * Minor revision * Return all intermediate hidden states * Call existing activation functions * Structs for activation functions * Call existing activations in ONNX * Minor revision * Compare strings ignoring case * Use memreftype of rank 0 for calling activation functions * Fix getActivationPack() * Revise the code * Add one MLIR test * Add MLIR tests for reverse and bidirectional modes * Make the order of emiting instructions deterministic * Use OperandAdaptor instead of directly use an operand index * Use literal assignments * Change some variable names * Use literal assignments * Use literal assignments * Format the code Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
9a874007ce
commit
24343177b8
|
@ -9,6 +9,9 @@ add_library(OMONNXToKrnl
|
|||
NN/Conv.cpp
|
||||
NN/Normalization.cpp
|
||||
NN/Pooling.cpp
|
||||
RNN/RNNBase.cpp
|
||||
RNN/RNNBase.hpp
|
||||
RNN/LSTM.cpp
|
||||
Tensor/Identity.cpp
|
||||
Tensor/Reshape.cpp
|
||||
Tensor/PadConstantValuePad.cpp
|
||||
|
|
|
@ -103,6 +103,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() {
|
|||
populateLoweringONNXConvOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
|
||||
populateLoweringONNXPoolingOpPattern(patterns, &getContext());
|
||||
// Recurrent neural network
|
||||
populateLoweringONNXLSTMOpPattern(patterns, &getContext());
|
||||
// Entry point
|
||||
patterns.insert<ONNXEntryPointLowering>(&getContext());
|
||||
|
||||
|
|
|
@ -102,16 +102,14 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
// Determine if current function returns the result value of the
|
||||
// current op being lowered. If it does then dealloc should not be
|
||||
// inserted.
|
||||
bool checkInsertDealloc(Operation *currentOp) {
|
||||
bool checkInsertDealloc(Operation *currentOp, int resultIndex) {
|
||||
auto parentBlock = currentOp->getBlock();
|
||||
|
||||
bool insertDealloc = true;
|
||||
parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
|
||||
assert(currentOp->getNumResults() < 2 &&
|
||||
"No more than one result supported (for now).");
|
||||
parentBlock->walk([&insertDealloc, currentOp, resultIndex](ReturnOp op) {
|
||||
// If there is at least one result to investigate.
|
||||
if (currentOp->getNumResults() > 0) {
|
||||
auto result = currentOp->getResult(0);
|
||||
auto result = currentOp->getResult(resultIndex);
|
||||
for (const auto &operand : op.getOperands())
|
||||
if (operand == result)
|
||||
insertDealloc = false;
|
||||
|
@ -488,4 +486,4 @@ Value emitNegativeInfinityConstantOp(
|
|||
|
||||
int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
|
||||
return (a.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ Value insertAllocAndDealloc(MemRefType type, Location loc,
|
|||
// Determine if current function returns the result value of the
|
||||
// current op being lowered. If it does then dealloc should not be
|
||||
// inserted.
|
||||
bool checkInsertDealloc(Operation *currentOp);
|
||||
bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0);
|
||||
|
||||
// Create a mapping from result type's dimensions to input type's dimensions,
|
||||
// given that the result type is the result of a reduction op over the input
|
||||
|
@ -218,6 +218,10 @@ void populateLoweringONNXNormalizationOpPattern(
|
|||
void populateLoweringONNXPoolingOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
// `RNN` directory methods:
|
||||
void populateLoweringONNXLSTMOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
||||
|
||||
// `Tensor` directory methods:
|
||||
|
||||
void populateLoweringONNXUnsqueezeOpPattern(
|
||||
|
|
|
@ -0,0 +1,537 @@
|
|||
//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file lowers the ONNX LSTM Operators to Krnl dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
struct LstmState {
|
||||
Value allH;
|
||||
Value ht;
|
||||
Value ct;
|
||||
};
|
||||
|
||||
struct LstmActivationPack {
|
||||
RNNActivation f;
|
||||
RNNActivation g;
|
||||
RNNActivation h;
|
||||
};
|
||||
|
||||
template <>
|
||||
bool hasAllNoneOutput<ONNXLSTMOp>(ONNXLSTMOp *op) {
|
||||
return (
|
||||
isNoneType(op->Y()) && isNoneType(op->Y_h()) && isNoneType(op->Y_c()));
|
||||
}
|
||||
|
||||
template <>
|
||||
std::tuple<LstmActivationPack, LstmActivationPack>
|
||||
getActivationPack<ONNXLSTMOp, LstmActivationPack>(ONNXLSTMOp *op) {
|
||||
auto direction = op->direction();
|
||||
auto activations = op->activations();
|
||||
auto activationAlpha = op->activation_alpha();
|
||||
auto activationBeta = op->activation_beta();
|
||||
|
||||
LstmActivationPack activationForward, activationReverse;
|
||||
|
||||
// Get activation function name.
|
||||
// Default forward functions
|
||||
activationForward.f.name = "sigmoid";
|
||||
activationForward.g.name = "tanh";
|
||||
activationForward.h.name = "tanh";
|
||||
// Default backward functions
|
||||
activationReverse.f.name = "sigmoid";
|
||||
activationReverse.g.name = "tanh";
|
||||
activationReverse.h.name = "tanh";
|
||||
if (activations) {
|
||||
ArrayAttr activationArrAttr = activations.getValue();
|
||||
if (direction == FORWARD || direction == BIDIRECTIONAL) {
|
||||
// Forward activations.
|
||||
if (activationArrAttr.size() > 0) {
|
||||
activationForward.f.name =
|
||||
activationArrAttr[0].cast<StringAttr>().getValue();
|
||||
}
|
||||
if (activationArrAttr.size() > 1) {
|
||||
activationForward.g.name =
|
||||
activationArrAttr[1].cast<StringAttr>().getValue();
|
||||
}
|
||||
if (activationArrAttr.size() > 2) {
|
||||
activationForward.h.name =
|
||||
activationArrAttr[2].cast<StringAttr>().getValue();
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse activations.
|
||||
if (direction == REVERSE || direction == BIDIRECTIONAL) {
|
||||
int startIndex = (direction == REVERSE) ? 0 : 3;
|
||||
if (activationArrAttr.size() > startIndex) {
|
||||
activationReverse.f.name =
|
||||
activationArrAttr[startIndex].cast<StringAttr>().getValue();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 1) {
|
||||
activationReverse.g.name =
|
||||
activationArrAttr[startIndex + 1].cast<StringAttr>().getValue();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 2) {
|
||||
activationReverse.h.name =
|
||||
activationArrAttr[startIndex + 2].cast<StringAttr>().getValue();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get alpha attributes.
|
||||
if (activationAlpha) {
|
||||
ArrayAttr activationArrAttr = activationAlpha.getValue();
|
||||
if (direction == FORWARD || direction == BIDIRECTIONAL) {
|
||||
// Forward activations.
|
||||
if (activationArrAttr.size() > 0) {
|
||||
activationForward.f.alpha = activationArrAttr[0].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > 1) {
|
||||
activationForward.g.alpha = activationArrAttr[1].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > 2) {
|
||||
activationForward.h.alpha = activationArrAttr[2].cast<FloatAttr>();
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse activations.
|
||||
if (direction == REVERSE || direction == BIDIRECTIONAL) {
|
||||
int startIndex = (direction == REVERSE) ? 0 : 3;
|
||||
if (activationArrAttr.size() > startIndex) {
|
||||
activationReverse.f.alpha =
|
||||
activationArrAttr[startIndex].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 1) {
|
||||
activationReverse.g.alpha =
|
||||
activationArrAttr[startIndex + 1].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 2) {
|
||||
activationReverse.h.alpha =
|
||||
activationArrAttr[startIndex + 2].cast<FloatAttr>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get beta attributes.
|
||||
if (activationBeta) {
|
||||
ArrayAttr activationArrAttr = activationBeta.getValue();
|
||||
if (direction == FORWARD || direction == BIDIRECTIONAL) {
|
||||
// Forward activations.
|
||||
if (activationArrAttr.size() > 0) {
|
||||
activationForward.f.beta = activationArrAttr[0].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > 1) {
|
||||
activationForward.g.beta = activationArrAttr[1].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > 2) {
|
||||
activationForward.h.beta = activationArrAttr[2].cast<FloatAttr>();
|
||||
}
|
||||
}
|
||||
|
||||
// Reverse activations.
|
||||
if (direction == REVERSE || direction == BIDIRECTIONAL) {
|
||||
int startIndex = (direction == REVERSE) ? 0 : 3;
|
||||
if (activationArrAttr.size() > startIndex) {
|
||||
activationReverse.f.beta =
|
||||
activationArrAttr[startIndex].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 1) {
|
||||
activationReverse.g.beta =
|
||||
activationArrAttr[startIndex + 1].cast<FloatAttr>();
|
||||
}
|
||||
if (activationArrAttr.size() > startIndex + 2) {
|
||||
activationReverse.h.beta =
|
||||
activationArrAttr[startIndex + 2].cast<FloatAttr>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(activationForward, activationReverse);
|
||||
}
|
||||
|
||||
template <>
|
||||
LstmState allocAndInitializeStates<ONNXLSTMOp, LstmState>(
|
||||
ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op,
|
||||
OperandAdaptor<ONNXLSTMOp> operandAdaptor) {
|
||||
LstmState state;
|
||||
|
||||
// Insert allocation and deallocation for the results of this operation.
|
||||
if (!isNoneType(op->Y())) {
|
||||
auto yMemRefType = convertToMemRefType(op->Y().getType());
|
||||
if (hasAllConstantDimensions(yMemRefType))
|
||||
state.allH = insertAllocAndDealloc(yMemRefType, loc, rewriter,
|
||||
checkInsertDealloc(op->getOperation(), 0));
|
||||
else
|
||||
emitError(loc, "Unsupported dynamic dimensions.");
|
||||
} else {
|
||||
state.allH = op->Y();
|
||||
}
|
||||
|
||||
// Y_h :: [num_directions, batch_size, hidden_size]
|
||||
if (!isNoneType(op->Y_h())) {
|
||||
auto yhMemRefType = convertToMemRefType(op->Y_h().getType());
|
||||
if (hasAllConstantDimensions(yhMemRefType))
|
||||
state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter,
|
||||
checkInsertDealloc(op->getOperation(), 1));
|
||||
else
|
||||
emitError(loc, "Unsupported dynamic dimensions.");
|
||||
} else {
|
||||
auto yhMemRefType = MemRefType::get(
|
||||
{dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
|
||||
dimAt(operandAdaptor.R(), 2)},
|
||||
operandAdaptor.X().getType().cast<ShapedType>().getElementType());
|
||||
state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter, true);
|
||||
}
|
||||
|
||||
// Y_c :: [num_directions, batch_size, hidden_size]
|
||||
if (!isNoneType(op->Y_c())) {
|
||||
auto ycMemRefType = convertToMemRefType(op->Y_c().getType());
|
||||
if (hasAllConstantDimensions(ycMemRefType))
|
||||
state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter,
|
||||
checkInsertDealloc(op->getOperation(), 2));
|
||||
else
|
||||
emitError(loc, "Unsupported dynamic dimensions.");
|
||||
} else {
|
||||
auto ycMemRefType = MemRefType::get(
|
||||
{dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1),
|
||||
dimAt(operandAdaptor.R(), 2)},
|
||||
operandAdaptor.X().getType().cast<ShapedType>().getElementType());
|
||||
state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter, true);
|
||||
}
|
||||
|
||||
// Initialize ht and ct.
|
||||
Value zero = emitConstantOp(rewriter, loc,
|
||||
operandAdaptor.X().getType().cast<ShapedType>().getElementType(), 0);
|
||||
int nLoops = 3;
|
||||
BuildKrnlLoop initializationLoops(rewriter, loc, nLoops);
|
||||
initializationLoops.createDefineOptimizeAndIterateOp(state.ht);
|
||||
auto ipInitializationLoops = rewriter.saveInsertionPoint();
|
||||
rewriter.setInsertionPointToStart(initializationLoops.getIterateBlock());
|
||||
{
|
||||
SmallVector<Value, 4> IVs;
|
||||
for (int i = 0; i < nLoops; ++i)
|
||||
IVs.emplace_back(initializationLoops.getInductionVar(i));
|
||||
|
||||
Value hiddenVal = zero;
|
||||
if (!isNoneType(operandAdaptor.initial_h()))
|
||||
hiddenVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_h(), IVs);
|
||||
rewriter.create<StoreOp>(loc, hiddenVal, state.ht, IVs);
|
||||
|
||||
Value cellVal = zero;
|
||||
if (!isNoneType(operandAdaptor.initial_c()))
|
||||
cellVal = rewriter.create<LoadOp>(loc, operandAdaptor.initial_c(), IVs);
|
||||
rewriter.create<StoreOp>(loc, cellVal, state.ct, IVs);
|
||||
}
|
||||
rewriter.restoreInsertionPoint(ipInitializationLoops);
|
||||
return state;
|
||||
}
|
||||
|
||||
template <>
|
||||
void calculateState<ONNXLSTMOp, LstmState, LstmActivationPack>(
|
||||
ConversionPatternRewriter &rewriter, Location loc,
|
||||
OperandAdaptor<ONNXLSTMOp> operandAdaptor, LstmState state,
|
||||
LstmActivationPack activationPack, Value directionIV, Value sequenceIV) {
|
||||
|
||||
bool hasBiasForInput = false, hasPeepholes = false;
|
||||
if (!isNoneType(operandAdaptor.B()))
|
||||
hasBiasForInput = true;
|
||||
if (!isNoneType(operandAdaptor.P()))
|
||||
hasPeepholes = true;
|
||||
|
||||
// Prepare dimensions.
|
||||
auto batchDimSize = dimAt(operandAdaptor.X(), 1);
|
||||
auto inputDimSize = dimAt(operandAdaptor.X(), 2);
|
||||
auto hiddenDimSize = dimAt(operandAdaptor.R(), 2);
|
||||
Value hiddenDimVal =
|
||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(), hiddenDimSize);
|
||||
|
||||
auto elementType =
|
||||
operandAdaptor.X().getType().cast<ShapedType>().getElementType();
|
||||
|
||||
// Prepare AffineMap to access bias, peepholes tensors.
|
||||
AffineMap accessByOffsetMap;
|
||||
{
|
||||
AffineExpr iv = rewriter.getAffineDimExpr(0);
|
||||
AffineExpr index = rewriter.getAffineSymbolExpr(0);
|
||||
AffineExpr size = rewriter.getAffineSymbolExpr(1);
|
||||
AffineExpr accessByOffsetExpr = index * size + iv;
|
||||
accessByOffsetMap = AffineMap::get(1, 2, accessByOffsetExpr);
|
||||
}
|
||||
|
||||
// Prepare constant indices.
|
||||
SmallVector<Value, 4> constantIndices;
|
||||
for (int i = 0; i < 8; i++)
|
||||
constantIndices.emplace_back(
|
||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i));
|
||||
|
||||
// Equations for LSTM.
|
||||
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
|
||||
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
|
||||
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
|
||||
// Ct = ft (.) Ct-1 + it (.) ct
|
||||
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
|
||||
// Ht = ot (.) h(Ct)
|
||||
//
|
||||
// The following code will emit loops as follows:
|
||||
// for b in 0 .. BatchDimSize
|
||||
// for h in 0 .. HiddenDimSize
|
||||
// for i in 0 .. InputDimSize {
|
||||
// compute Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T),
|
||||
// Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T)
|
||||
// }
|
||||
// compute it, ft, ct, Ct, ot, Ht
|
||||
|
||||
BuildKrnlLoop stateLoops(rewriter, loc, 2);
|
||||
stateLoops.createDefineAndOptimizeOp();
|
||||
stateLoops.pushBounds(0, batchDimSize);
|
||||
stateLoops.pushBounds(0, hiddenDimSize);
|
||||
stateLoops.createIterateOp();
|
||||
|
||||
rewriter.setInsertionPointToStart(stateLoops.getIterateBlock());
|
||||
{
|
||||
auto batchIV = stateLoops.getInductionVar(0);
|
||||
auto hiddenIV = stateLoops.getInductionVar(1);
|
||||
|
||||
// IVs to access tensors.
|
||||
// IVs for the hidden and cell state tensors.
|
||||
SmallVector<Value, 4> hIVs, cIVs;
|
||||
// IVs for the bias tensors for W and R.
|
||||
SmallVector<SmallVector<Value, 4>, 4> wbIOFCIVs, rbIOFCIVs;
|
||||
// IVs for the peepholes.
|
||||
SmallVector<SmallVector<Value, 4>, 4> pIOFIVs;
|
||||
|
||||
{ // Compute IVs.
|
||||
// H :: [num_directions, batch_size, hidden_size]
|
||||
hIVs = {directionIV, batchIV, hiddenIV};
|
||||
// C :: [num_directions, batch_size, hidden_size]
|
||||
cIVs = {directionIV, batchIV, hiddenIV};
|
||||
|
||||
// Bias [Wb[iofc], Rb[iofc]] :: [num_directions, 8*hidden_size]
|
||||
if (hasBiasForInput) {
|
||||
// Wb[iofc]
|
||||
for (unsigned i = 0; i < 4; ++i) {
|
||||
Value wHiddenIV =
|
||||
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
|
||||
ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
|
||||
/*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
|
||||
wbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, wHiddenIV});
|
||||
}
|
||||
// Rb[iofc]
|
||||
for (unsigned i = 4; i < 8; ++i) {
|
||||
SmallVector<Value, 4> rbIVs;
|
||||
Value rHiddenIV =
|
||||
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
|
||||
ValueRange(std::vector<Value>{/*iv=*/hiddenIV,
|
||||
/*index=*/constantIndices[i], /*size=*/hiddenDimVal}));
|
||||
rbIOFCIVs.emplace_back(SmallVector<Value, 2>{directionIV, rHiddenIV});
|
||||
}
|
||||
}
|
||||
|
||||
// Peepholes P[iof] :: [num_directions, 3*hidden_size]
|
||||
if (hasPeepholes) {
|
||||
for (unsigned i = 0; i < 3; ++i) {
|
||||
SmallVector<Value, 4> pIVs;
|
||||
Value pHiddenIV =
|
||||
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
|
||||
ValueRange(std::vector<Value>{
|
||||
hiddenIV, constantIndices[i], hiddenDimVal}));
|
||||
pIOFIVs.emplace_back(SmallVector<Value, 2>{directionIV, pHiddenIV});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value loadH = rewriter.create<LoadOp>(loc, state.ht, hIVs);
|
||||
Value loadC = rewriter.create<LoadOp>(loc, state.ct, cIVs);
|
||||
|
||||
// Emit instructions for matrix multiplications:
|
||||
// Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T)
|
||||
// Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T)
|
||||
|
||||
// Allocate memory for storing matrix multiplication results.
|
||||
SmallVector<Value, 4> xwIOFC, hrIOFC;
|
||||
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
|
||||
for (unsigned i = 0; i < 4; ++i) {
|
||||
Value xwAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
|
||||
rewriter.create<StoreOp>(loc, zero, xwAlloc);
|
||||
Value hrAlloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
|
||||
rewriter.create<StoreOp>(loc, zero, hrAlloc);
|
||||
xwIOFC.emplace_back(xwAlloc);
|
||||
hrIOFC.emplace_back(hrAlloc);
|
||||
}
|
||||
|
||||
{ // Emit instructions for matrix multiplications.
|
||||
// input_size is the reduction dimension.
|
||||
BuildKrnlLoop reductionLoops(rewriter, loc, 1);
|
||||
reductionLoops.createDefineAndOptimizeOp();
|
||||
reductionLoops.pushBounds(0, inputDimSize);
|
||||
reductionLoops.createIterateOp();
|
||||
|
||||
auto ipReductionLoops = rewriter.saveInsertionPoint();
|
||||
rewriter.setInsertionPointToStart(reductionLoops.getIterateBlock());
|
||||
{
|
||||
auto reductionIV = reductionLoops.getInductionVar(0);
|
||||
// Prepare IVs for accessing the input tensor and parameters.
|
||||
SmallVector<Value, 4> xIVs;
|
||||
SmallVector<SmallVector<Value, 4>, 4> wIOFCIVs, rIOFCIVs;
|
||||
|
||||
// X :: [seq_length, batch_size, input_size]
|
||||
xIVs = {sequenceIV, batchIV, reductionIV};
|
||||
|
||||
// W[iofc] :: [num_directions, 4*hidden_size, input_size]
|
||||
// R[iofc] :: [num_directions, 4*hidden_size, input_size]
|
||||
for (unsigned i = 0; i < 4; ++i) {
|
||||
SmallVector<Value, 4> wIVs, rIVs;
|
||||
Value wHiddenIV =
|
||||
rewriter.create<AffineApplyOp>(loc, accessByOffsetMap,
|
||||
ValueRange(std::vector<Value>{
|
||||
hiddenIV, constantIndices[i], hiddenDimVal}));
|
||||
|
||||
wIVs = {directionIV, wHiddenIV, reductionIV};
|
||||
wIOFCIVs.emplace_back(wIVs);
|
||||
|
||||
rIVs = {directionIV, wHiddenIV, reductionIV};
|
||||
rIOFCIVs.emplace_back(rIVs);
|
||||
}
|
||||
|
||||
Value loadX = rewriter.create<LoadOp>(loc, operandAdaptor.X(), xIVs);
|
||||
for (unsigned i = 0; i < 4; ++i) {
|
||||
// Xt * Wiofc
|
||||
Value loadW =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.W(), wIOFCIVs[i]);
|
||||
Value xwVal = rewriter.create<MulFOp>(loc, loadX, loadW);
|
||||
Value loadXW = rewriter.create<LoadOp>(loc, xwIOFC[i]);
|
||||
Value nextXW = rewriter.create<AddFOp>(loc, loadXW, xwVal);
|
||||
rewriter.create<StoreOp>(loc, nextXW, xwIOFC[i]);
|
||||
// Ht-1 * Riofc
|
||||
Value loadR =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.R(), rIOFCIVs[i]);
|
||||
Value hrVal = rewriter.create<MulFOp>(loc, loadH, loadR);
|
||||
Value loadHR = rewriter.create<LoadOp>(loc, hrIOFC[i]);
|
||||
Value nextHR = rewriter.create<AddFOp>(loc, loadHR, hrVal);
|
||||
rewriter.create<StoreOp>(loc, nextHR, hrIOFC[i]);
|
||||
}
|
||||
}
|
||||
rewriter.restoreInsertionPoint(ipReductionLoops);
|
||||
}
|
||||
|
||||
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
|
||||
Value loadXWI = rewriter.create<LoadOp>(loc, xwIOFC[0]);
|
||||
Value loadHRI = rewriter.create<LoadOp>(loc, hrIOFC[0]);
|
||||
Value it = rewriter.create<AddFOp>(loc, loadXWI, loadHRI);
|
||||
if (hasPeepholes) {
|
||||
Value loadP =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[0]);
|
||||
Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
|
||||
it = rewriter.create<AddFOp>(loc, it, PC);
|
||||
}
|
||||
if (hasBiasForInput) {
|
||||
Value loadWB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[0]);
|
||||
it = rewriter.create<AddFOp>(loc, it, loadWB);
|
||||
Value loadRB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[0]);
|
||||
it = rewriter.create<AddFOp>(loc, it, loadRB);
|
||||
}
|
||||
it = applyActivation(rewriter, loc, activationPack.f, it);
|
||||
|
||||
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
|
||||
Value loadXWF = rewriter.create<LoadOp>(loc, xwIOFC[2]);
|
||||
Value loadHRF = rewriter.create<LoadOp>(loc, hrIOFC[2]);
|
||||
Value ft = rewriter.create<AddFOp>(loc, loadXWF, loadHRF);
|
||||
if (hasPeepholes) {
|
||||
Value loadP =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[2]);
|
||||
Value PC = rewriter.create<MulFOp>(loc, loadP, loadC);
|
||||
ft = rewriter.create<AddFOp>(loc, ft, PC);
|
||||
}
|
||||
if (hasBiasForInput) {
|
||||
Value loadWB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[2]);
|
||||
ft = rewriter.create<AddFOp>(loc, ft, loadWB);
|
||||
Value loadRB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[2]);
|
||||
ft = rewriter.create<AddFOp>(loc, ft, loadRB);
|
||||
}
|
||||
ft = applyActivation(rewriter, loc, activationPack.f, ft);
|
||||
|
||||
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
|
||||
Value loadXWC = rewriter.create<LoadOp>(loc, xwIOFC[3]);
|
||||
Value loadHRC = rewriter.create<LoadOp>(loc, hrIOFC[3]);
|
||||
Value ct = rewriter.create<AddFOp>(loc, loadXWC, loadHRC);
|
||||
if (hasBiasForInput) {
|
||||
Value loadWB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[3]);
|
||||
ct = rewriter.create<AddFOp>(loc, ct, loadWB);
|
||||
Value loadRB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[3]);
|
||||
ct = rewriter.create<AddFOp>(loc, ct, loadRB);
|
||||
}
|
||||
ct = applyActivation(rewriter, loc, activationPack.g, ct);
|
||||
|
||||
// Ct = ft (.) Ct-1 + it (.) ct
|
||||
Value FtCt1 = rewriter.create<MulFOp>(loc, ft, loadC);
|
||||
Value itct = rewriter.create<MulFOp>(loc, it, ct);
|
||||
Value Ct = rewriter.create<AddFOp>(loc, FtCt1, itct);
|
||||
rewriter.create<StoreOp>(loc, Ct, state.ct, cIVs);
|
||||
|
||||
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
|
||||
Value loadXWO = rewriter.create<LoadOp>(loc, xwIOFC[1]);
|
||||
Value loadHRO = rewriter.create<LoadOp>(loc, hrIOFC[1]);
|
||||
Value ot = rewriter.create<AddFOp>(loc, loadXWO, loadHRO);
|
||||
if (hasPeepholes) {
|
||||
Value loadP =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.P(), pIOFIVs[1]);
|
||||
Value PC = rewriter.create<MulFOp>(loc, loadP, Ct);
|
||||
ot = rewriter.create<AddFOp>(loc, ot, PC);
|
||||
}
|
||||
if (hasBiasForInput) {
|
||||
Value loadWB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), wbIOFCIVs[1]);
|
||||
ot = rewriter.create<AddFOp>(loc, ot, loadWB);
|
||||
Value loadRB =
|
||||
rewriter.create<LoadOp>(loc, operandAdaptor.B(), rbIOFCIVs[1]);
|
||||
ot = rewriter.create<AddFOp>(loc, ot, loadRB);
|
||||
}
|
||||
ot = applyActivation(rewriter, loc, activationPack.f, ot);
|
||||
|
||||
// Ht = ot (.) h(Ct)
|
||||
Value hCt = applyActivation(rewriter, loc, activationPack.h, Ct);
|
||||
Value Ht = rewriter.create<MulFOp>(loc, ot, hCt);
|
||||
rewriter.create<StoreOp>(loc, Ht, state.ht, hIVs);
|
||||
|
||||
// Store the current Ht if required.
|
||||
if (!isNoneType(state.allH)) {
|
||||
SmallVector<Value, 4> allHIVs{sequenceIV, directionIV, batchIV, hiddenIV};
|
||||
rewriter.create<StoreOp>(loc, Ht, state.allH, allHIVs);
|
||||
}
|
||||
|
||||
// Deallocate the temporary results of matrix multiplications.
|
||||
for (Value v : xwIOFC)
|
||||
rewriter.create<DeallocOp>(loc, v);
|
||||
for (Value v : hrIOFC)
|
||||
rewriter.create<DeallocOp>(loc, v);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void stateToOutput<ONNXLSTMOp, LstmState>(
|
||||
ONNXLSTMOp *op, LstmState state, std::vector<Value> &outputs) {
|
||||
Value noneValue;
|
||||
outputs.emplace_back((isNoneType(op->Y()) ? noneValue : state.allH));
|
||||
outputs.emplace_back((isNoneType(op->Y_h()) ? noneValue : state.ht));
|
||||
outputs.emplace_back((isNoneType(op->Y_c()) ? noneValue : state.ct));
|
||||
}
|
||||
|
||||
void populateLoweringONNXLSTMOpPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx) {
|
||||
patterns.insert<ONNXRNNOpLowering<ONNXLSTMOp, LstmState, LstmActivationPack>>(
|
||||
ctx);
|
||||
}
|
|
@ -0,0 +1,73 @@
|
|||
//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines base functions for lowerng the ONNX RNN Operators.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Check a Value's type is none or not.
|
||||
bool isNoneType(Value val) { return val.getType().isa<NoneType>(); }
|
||||
|
||||
// Get a dimension of the tensor's shape.
|
||||
int64_t dimAt(Value val, int index) {
|
||||
return val.getType().cast<ShapedType>().getShape()[index];
|
||||
}
|
||||
|
||||
// Apply an activation function on a given scalar operand.
|
||||
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RNNActivation activation, Value scalarOperand) {
|
||||
Value res;
|
||||
|
||||
MemRefType scalarMemRefType =
|
||||
MemRefType::get({}, scalarOperand.getType(), {}, 0);
|
||||
Value alloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
|
||||
rewriter.create<StoreOp>(loc, scalarOperand, alloc);
|
||||
|
||||
std::vector<mlir::NamedAttribute> attributes;
|
||||
if (activation.alpha) {
|
||||
attributes.emplace_back(
|
||||
rewriter.getNamedAttr("alpha", activation.alpha.getValue()));
|
||||
}
|
||||
if (activation.beta) {
|
||||
attributes.emplace_back(
|
||||
rewriter.getNamedAttr("beta", activation.beta.getValue()));
|
||||
}
|
||||
|
||||
if (activation.name.equals_lower("relu"))
|
||||
res = rewriter.create<ONNXReluOp>(loc, scalarMemRefType, alloc);
|
||||
else if (activation.name.equals_lower("tanh"))
|
||||
res = rewriter.create<ONNXTanhOp>(loc, scalarMemRefType, alloc);
|
||||
else if (activation.name.equals_lower("sigmoid"))
|
||||
res = rewriter.create<ONNXSigmoidOp>(loc, scalarMemRefType, alloc);
|
||||
else if (activation.name.equals_lower("affine"))
|
||||
emitError(loc, "Unsupported activation");
|
||||
else if (activation.name.equals_lower("leakyrelu"))
|
||||
res = rewriter.create<ONNXLeakyReluOp>(
|
||||
loc, scalarMemRefType, alloc, attributes);
|
||||
else if (activation.name.equals_lower("thresholdedrelu"))
|
||||
res = rewriter.create<ONNXThresholdedReluOp>(
|
||||
loc, scalarMemRefType, alloc, attributes);
|
||||
else if (activation.name.equals_lower("scaledtanh"))
|
||||
emitError(loc, "Unsupported activation");
|
||||
else if (activation.name.equals_lower("hardsigmoid"))
|
||||
res = rewriter.create<ONNXHardSigmoidOp>(
|
||||
loc, scalarMemRefType, alloc, attributes);
|
||||
else if (activation.name.equals_lower("elu"))
|
||||
res = rewriter.create<ONNXEluOp>(loc, scalarMemRefType, alloc, attributes);
|
||||
else if (activation.name.equals_lower("softsign"))
|
||||
res = rewriter.create<ONNXSoftsignOp>(loc, scalarMemRefType, alloc);
|
||||
else if (activation.name.equals_lower("softplus"))
|
||||
res = rewriter.create<ONNXSoftplusOp>(loc, scalarMemRefType, alloc);
|
||||
else
|
||||
llvm_unreachable("Unsupported activation");
|
||||
|
||||
Value result = rewriter.create<LoadOp>(loc, res);
|
||||
return result;
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines base functions for lowerng the ONNX RNN Operators.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static const StringRef FORWARD = "forward";
|
||||
static const StringRef REVERSE = "reverse";
|
||||
static const StringRef BIDIRECTIONAL = "bidirectional";
|
||||
|
||||
struct RNNActivation {
|
||||
StringRef name;
|
||||
Optional<FloatAttr> alpha;
|
||||
Optional<FloatAttr> beta;
|
||||
};
|
||||
|
||||
// Check a Value's type is none or not.
|
||||
bool isNoneType(Value val);
|
||||
|
||||
// Get a dimension of the tensor's shape.
|
||||
int64_t dimAt(Value val, int index);
|
||||
|
||||
// Apply an activation function on a given scalar operand.
|
||||
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RNNActivation activation, Value scalarOperand);
|
||||
|
||||
// Override the following methods when lowering an RNN operation:
|
||||
// - hasAllNoneOutput
|
||||
// - getActivationPack
|
||||
// - allocAndInitializeStates
|
||||
// - calculateState
|
||||
// - stateToOutput
|
||||
|
||||
// Check whether all outputs have NoneType or not.
|
||||
template <typename RNNOp>
|
||||
bool hasAllNoneOutput(RNNOp *op);
|
||||
|
||||
// Obtain activations functions for a specific operation.
|
||||
template <typename RNNOp, typename A>
|
||||
std::tuple<A, A> getActivationPack(RNNOp *op);
|
||||
|
||||
// Allocate memory for RNN states and initialize them.
|
||||
template <typename RNNOp, typename S>
|
||||
S allocAndInitializeStates(ConversionPatternRewriter &rewriter, Location loc,
|
||||
RNNOp *op, OperandAdaptor<RNNOp> operandAdaptor);
|
||||
|
||||
// Calculate new states from the current input and states.
|
||||
template <typename RNNOp, typename S, typename A>
|
||||
void calculateState(ConversionPatternRewriter &rewriter, Location loc,
|
||||
OperandAdaptor<RNNOp> operandAdaptor, S state, A activationSet,
|
||||
Value directionIV, Value sequenceIV);
|
||||
|
||||
// Write states to the RNN's outputs.
|
||||
template <typename RNNOp, typename S>
|
||||
void stateToOutput(RNNOp *op, S state, std::vector<Value> &outputs);
|
||||
|
||||
// A common template for lowering an RNN operation.
|
||||
template <typename RNNOp, typename S, typename A>
|
||||
struct ONNXRNNOpLowering : public ConversionPattern {
|
||||
ONNXRNNOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(RNNOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
RNNOp rnnOp = llvm::dyn_cast<RNNOp>(op);
|
||||
OperandAdaptor<RNNOp> operandAdaptor(operands);
|
||||
|
||||
if (hasAllNoneOutput<RNNOp>(&rnnOp)) {
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
S state = allocAndInitializeStates<RNNOp, S>(
|
||||
rewriter, loc, &rnnOp, operandAdaptor);
|
||||
|
||||
A activationForward, activationReverse;
|
||||
std::tie(activationForward, activationReverse) =
|
||||
getActivationPack<RNNOp, A>(&rnnOp);
|
||||
|
||||
int64_t sequenceDimSize = dimAt(rnnOp.X(), 0);
|
||||
auto direction = rnnOp.direction();
|
||||
|
||||
if (direction == FORWARD || direction == BIDIRECTIONAL) {
|
||||
BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
|
||||
sequenceLoops.createDefineAndOptimizeOp();
|
||||
sequenceLoops.pushBounds(0, sequenceDimSize);
|
||||
sequenceLoops.createIterateOp();
|
||||
|
||||
auto ipSequenceLoops = rewriter.saveInsertionPoint();
|
||||
rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
|
||||
{
|
||||
Value directionIV =
|
||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0);
|
||||
Value sequenceIV = sequenceLoops.getInductionVar(0);
|
||||
// Emit calculation for one RNN step.
|
||||
calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
|
||||
activationForward, directionIV, sequenceIV);
|
||||
}
|
||||
rewriter.restoreInsertionPoint(ipSequenceLoops);
|
||||
}
|
||||
|
||||
if (direction == REVERSE || direction == BIDIRECTIONAL) {
|
||||
BuildKrnlLoop sequenceLoops(rewriter, loc, 1);
|
||||
sequenceLoops.createDefineAndOptimizeOp();
|
||||
sequenceLoops.pushBounds(0, sequenceDimSize);
|
||||
sequenceLoops.createIterateOp();
|
||||
|
||||
auto ipSequenceLoops = rewriter.saveInsertionPoint();
|
||||
rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock());
|
||||
{
|
||||
AffineMap reverseIVMap = AffineMap::get(1, 1,
|
||||
rewriter.getAffineSymbolExpr(0) - rewriter.getAffineDimExpr(0) - 1);
|
||||
|
||||
Value directionIV = emitConstantOp(rewriter, loc,
|
||||
rewriter.getIndexType(), (direction == REVERSE) ? 0 : 1);
|
||||
Value reverseSequenceIV =
|
||||
rewriter.create<AffineApplyOp>(loc, reverseIVMap,
|
||||
ValueRange(std::vector<Value>{sequenceLoops.getInductionVar(0),
|
||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(),
|
||||
sequenceDimSize)}));
|
||||
// Emit calculation for one RNN step.
|
||||
calculateState<RNNOp, S, A>(rewriter, loc, operandAdaptor, state,
|
||||
activationReverse, directionIV, reverseSequenceIV);
|
||||
}
|
||||
rewriter.restoreInsertionPoint(ipSequenceLoops);
|
||||
}
|
||||
|
||||
std::vector<Value> outputs;
|
||||
stateToOutput<RNNOp, S>(&rnnOp, state, outputs);
|
||||
rewriter.replaceOp(op, outputs);
|
||||
return success();
|
||||
}
|
||||
};
|
|
@ -299,6 +299,112 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that infers shape for RNN operations.
|
||||
template <typename T>
|
||||
static bool RNNShapeInference(T *op) {
|
||||
Value X = op->X();
|
||||
Value W = op->W();
|
||||
Value R = op->R();
|
||||
|
||||
if (!X.getType().isa<RankedTensorType>() ||
|
||||
!W.getType().isa<RankedTensorType>() ||
|
||||
!R.getType().isa<RankedTensorType>())
|
||||
return false;
|
||||
|
||||
auto xTy = X.getType().cast<RankedTensorType>();
|
||||
auto elementType = xTy.getElementType();
|
||||
|
||||
// xShape :: [seq_length, batch_size, input_size]
|
||||
auto xShape = xTy.getShape();
|
||||
// wShape :: [num_directions, 4*hidden_size, input_size]
|
||||
auto wShape = W.getType().cast<RankedTensorType>().getShape();
|
||||
// rShape :: [num_directions, 4*hidden_size, hidden_size]
|
||||
auto rShape = R.getType().cast<RankedTensorType>().getShape();
|
||||
|
||||
if (xShape.size() != 3) {
|
||||
op->emitError("The first input tensor must have rank 3");
|
||||
return false;
|
||||
}
|
||||
if (wShape.size() != 3) {
|
||||
op->emitError("The second input tensor must have rank 3");
|
||||
return false;
|
||||
}
|
||||
if (rShape.size() != 3) {
|
||||
op->emitError("The third input tensor must have rank 3");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get sequence length, batch size and input size.
|
||||
auto sequenceLength = xShape[0];
|
||||
auto batchSize = xShape[1];
|
||||
auto inputSize = xShape[2];
|
||||
|
||||
// Get hidden size from hidden_size attribute.
|
||||
int64_t hiddenSize = -1;
|
||||
if (op->hidden_size().hasValue()) {
|
||||
hiddenSize = op->hidden_size().getValue().getSExtValue();
|
||||
} else {
|
||||
// Infer hidden_size from wShape and rShape if possible.
|
||||
if (rShape[2] != -1)
|
||||
hiddenSize = rShape[2];
|
||||
else if (rShape[1] != -1)
|
||||
hiddenSize = rShape[1] / 4;
|
||||
else if (wShape[1] != -1)
|
||||
hiddenSize = wShape[1] / 4;
|
||||
// Update hidden_size attribute.
|
||||
if (hiddenSize != -1) {
|
||||
auto builder = mlir::Builder(op->getContext());
|
||||
op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize));
|
||||
}
|
||||
}
|
||||
|
||||
// Get direction.
|
||||
int numDirection;
|
||||
if ((op->direction() == "forward") || (op->direction() == "reverse"))
|
||||
numDirection = 1;
|
||||
else if (op->direction() == "bidirectional")
|
||||
numDirection = 2;
|
||||
else
|
||||
numDirection = -1;
|
||||
if (numDirection == -1) {
|
||||
op->emitError("direction attribute muse be one of the strings: forward, "
|
||||
"reverse, and bidirectional");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Set result types.
|
||||
unsigned numOfResults = op->getNumResults();
|
||||
if (numOfResults > 0) {
|
||||
// Y :: [seq_length, num_directions, batch_size, hidden_size]
|
||||
Type yTy = op->getResults()[0].getType();
|
||||
if (!yTy.isa<NoneType>()) {
|
||||
yTy = RankedTensorType::get(
|
||||
{sequenceLength, numDirection, batchSize, hiddenSize}, elementType);
|
||||
op->getResults()[0].setType(yTy);
|
||||
}
|
||||
}
|
||||
if (numOfResults > 1) {
|
||||
// Y_h :: [num_directions, batch_size, hidden_size]
|
||||
Type yhTy = op->getResults()[1].getType();
|
||||
if (!yhTy.isa<NoneType>()) {
|
||||
yhTy = RankedTensorType::get(
|
||||
{numDirection, batchSize, hiddenSize}, elementType);
|
||||
op->getResults()[1].setType(yhTy);
|
||||
}
|
||||
}
|
||||
if (numOfResults > 2) {
|
||||
// Y_c :: [num_directions, batch_size, hidden_size]
|
||||
Type ycTy = op->getResults()[2].getType();
|
||||
if (!ycTy.isa<NoneType>()) {
|
||||
ycTy = RankedTensorType::get(
|
||||
{numDirection, batchSize, hiddenSize}, elementType);
|
||||
op->getResults()[2].setType(ycTy);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNXOpsDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1472,7 +1578,6 @@ bool ONNXConstantOp::inferShapes() {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Concat
|
||||
|
||||
bool ONNXConcatOp::inferShapes() {
|
||||
|
@ -1537,6 +1642,21 @@ bool ONNXConcatOp::inferShapes() {
|
|||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RNN
|
||||
|
||||
bool ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LSTM
|
||||
|
||||
bool ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GRU
|
||||
|
||||
bool ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Split
|
||||
|
||||
|
|
|
@ -738,7 +738,7 @@ def ONNXFloorOp:ONNX_Op<"Floor",
|
|||
}
|
||||
|
||||
def ONNXGRUOp:ONNX_Op<"GRU",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX GRU operation";
|
||||
let description = [{
|
||||
"Computes an one-layer GRU. This operator is usually supported via some custom"
|
||||
|
@ -1246,7 +1246,7 @@ def ONNXLRNOp:ONNX_Op<"LRN",
|
|||
}
|
||||
|
||||
def ONNXLSTMOp:ONNX_Op<"LSTM",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX LSTM operation";
|
||||
let description = [{
|
||||
"Computes an one-layer LSTM. This operator is usually supported via some"
|
||||
|
@ -2086,7 +2086,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
|
|||
}
|
||||
|
||||
def ONNXRNNOp:ONNX_Op<"RNN",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||
let summary = "ONNX RNN operation";
|
||||
let description = [{
|
||||
"Computes an one-layer simple RNN. This operator is usually supported"
|
||||
|
|
|
@ -126,6 +126,9 @@ public:
|
|||
op->getName().getStringRef() != "onnx.Concat" &&
|
||||
op->getName().getStringRef() != "onnx.Split" &&
|
||||
op->getName().getStringRef() != "onnx.Neg" &&
|
||||
op->getName().getStringRef() != "onnx.RNN" &&
|
||||
op->getName().getStringRef() != "onnx.LSTM" &&
|
||||
op->getName().getStringRef() != "onnx.GRU" &&
|
||||
op->getName().getStringRef() != "onnx.Unsqueeze")
|
||||
return false;
|
||||
return llvm::any_of(op->getResultTypes(), [](Type result_type) {
|
||||
|
|
|
@ -356,6 +356,11 @@ test_to_enable = [
|
|||
"test_averagepool_2d_strides_cpu",
|
||||
"test_averagepool_3d_default_cpu",
|
||||
|
||||
# LSTM
|
||||
"test_lstm_defaults_cpu",
|
||||
"test_lstm_with_initial_bias_cpu",
|
||||
"test_lstm_with_peepholes_cpu",
|
||||
|
||||
]
|
||||
|
||||
# Extract name of all test cases.
|
||||
|
|
|
@ -1769,3 +1769,314 @@ func @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*x
|
|||
// CHECK: }
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-DAG: [[ACCESS_BY_OFFSET_MAP:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)>
|
||||
// CHECK-LABEL: @test_lstm_general_computation
|
||||
|
||||
// CHECK: [[CELL_STATE:%.+]] = alloc() : memref<1x3x3xf32>
|
||||
// CHECK: [[HIDDEN_STATE:%.+]] = alloc() : memref<1x3x3xf32>
|
||||
// CHECK: {{.*}} = constant unit
|
||||
|
||||
// CHECK: [[INITIAL_VALUE:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[INITIALIZE_LOOPS:%.+]]:3 = krnl.define_loops 3
|
||||
// CHECK: [[INITIALIZE_OPT_LOOPS:%.+]]:3 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[INITIALIZE_LOOPS]]#0, [[INITIALIZE_LOOPS]]#1, [[INITIALIZE_LOOPS]]#2
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[INITIALIZE_OPT_LOOPS]]#0, [[INITIALIZE_OPT_LOOPS]]#1, [[INITIALIZE_OPT_LOOPS]]#2) with ([[INITIALIZE_LOOPS]]#0 -> %arg3 = 0 to 1, [[INITIALIZE_LOOPS]]#1 -> %arg4 = 0 to 3, [[INITIALIZE_LOOPS]]#2 -> %arg5 = 0 to 3) {
|
||||
// CHECK: store [[INITIAL_VALUE]], [[HIDDEN_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
// CHECK: store [[INITIAL_VALUE]], [[CELL_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
|
||||
// CHECK: {{.*}} = constant 0 : index
|
||||
// CHECK: {{.*}} = constant 3 : index
|
||||
// CHECK: {{.*}} = constant 0 : index
|
||||
// CHECK: {{.*}} = constant 1 : index
|
||||
// CHECK: {{.*}} = constant 2 : index
|
||||
// CHECK: {{.*}} = constant 3 : index
|
||||
// CHECK: {{.*}} = constant 4 : index
|
||||
// CHECK: {{.*}} = constant 5 : index
|
||||
// CHECK: {{.*}} = constant 6 : index
|
||||
// CHECK: {{.*}} = constant 7 : index
|
||||
// CHECK: [[DATA_LOOPS:%.+]]:2 = krnl.define_loops 2
|
||||
// CHECK: [[DATA_OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[DATA_LOOPS]]#0, [[DATA_LOOPS]]#1
|
||||
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
|
||||
// CHECK: krnl.iterate([[DATA_OPT_LOOPS]]#0, [[DATA_OPT_LOOPS]]#1) with ([[DATA_LOOPS]]#0 -> %arg4 = 0 to 3, [[DATA_LOOPS]]#1 -> %arg5 = 0 to 3) {
|
||||
// CHECK: [[hCt:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[Ot:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[ct:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[Ft:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[It:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: [[Ht1_LOAD:%.+]] = load [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
// CHECK: [[Ct1_LOAD:%.+]] = load [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
|
||||
// CHECK: [[ZERO_FLOAT:%.+]] = constant 0.000000e+00 : f32
|
||||
// CHECK: [[XtWi_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[XtWi_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Ri_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[Ht1Ri_GEMM]][] : memref<f32>
|
||||
// CHECK: [[XtWo_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[XtWo_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Ro_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[Ht1Ro_GEMM]][] : memref<f32>
|
||||
// CHECK: [[XtWf_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[XtWf_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Rf_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[Ht1Rf_GEMM]][] : memref<f32>
|
||||
// CHECK: [[XtWc_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[XtWc_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Rc_GEMM:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ZERO_FLOAT]], [[Ht1Rc_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[REDUCTION_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[REDUCTION_OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[REDUCTION_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[REDUCTION_OPT_LOOPS]]) with ([[REDUCTION_LOOPS]] -> %arg6 = 0 to 2) {
|
||||
// CHECK: [[INPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c0_1, %c3]
|
||||
// CHECK: [[OUTPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c1, %c3]
|
||||
// CHECK: [[FORGET_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c2, %c3]
|
||||
// CHECK: [[CELL_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c3_2, %c3]
|
||||
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, %arg4, %arg6] : memref<4x3x2xf32>
|
||||
|
||||
// CHECK: [[Wi_LOAD:%.+]] = load %arg1[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
|
||||
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %59, [[XtWi_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
|
||||
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %63, [[Ht1Ri_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
|
||||
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %67, [[XtWo_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
|
||||
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ro_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %71, [[Ht1Ro_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
|
||||
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %75, [[XtWf_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
|
||||
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %79, [[Ht1Rf_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32>
|
||||
// CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %83, [[XtWc_GEMM]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32>
|
||||
// CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32
|
||||
// CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref<f32>
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store %87, [[Ht1Rc_GEMM]][] : memref<f32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Ri_LOAD:%.+]] = load [[Ht1Ri_GEMM]][] : memref<f32>
|
||||
// CHECK: [[It_OUTPUT:%.+]] = addf [[XtWi_LOAD]], [[Ht1Ri_LOAD]] : f32
|
||||
|
||||
// CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref<f32>
|
||||
// CHECK: krnl.define_loops 0
|
||||
// CHECK: krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops
|
||||
// CHECK: } : () -> ()
|
||||
// CHECK: krnl.iterate() with () {
|
||||
// CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref<f32>
|
||||
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||
// CHECK: {{.*}} = constant 1.000000e+00 : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store {{.*}}, [[It]][] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref<f32>
|
||||
|
||||
// CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Rf_LOAD:%.+]] = load [[Ht1Rf_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ft_OUTPUT:%.+]] = addf [[XtWf_LOAD]], [[Ht1Rf_LOAD]] : f32
|
||||
|
||||
// CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref<f32>
|
||||
// CHECK: krnl.define_loops 0
|
||||
// CHECK: krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops
|
||||
// CHECK: } : () -> ()
|
||||
// CHECK: krnl.iterate() with () {
|
||||
// CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref<f32>
|
||||
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||
// CHECK: {{.*}} = constant 1.000000e+00 : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store {{.*}}, [[Ft]][] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref<f32>
|
||||
|
||||
// CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Rc_LOAD:%.+]] = load [[Ht1Rc_GEMM]][] : memref<f32>
|
||||
// CHECK: [[ct_OUTPUT:%.+]] = addf [[XtWc_LOAD]], [[Ht1Rc_LOAD]] : f32
|
||||
|
||||
// CHECK: [[TANH_CELL:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : memref<f32>
|
||||
// CHECK: krnl.define_loops 0
|
||||
// CHECK: krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops
|
||||
// CHECK: } : () -> ()
|
||||
// CHECK: krnl.iterate() with () {
|
||||
// CHECK: {{.*}} = load [[TANH_CELL]][] : memref<f32>
|
||||
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store {{.*}}, [[ct]][] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref<f32>
|
||||
|
||||
// CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32
|
||||
// CHECK: [[Itct:%.+]] = mulf [[It_LOAD]], [[ct_LOAD]] : f32
|
||||
// CHECK: [[Ct:%.+]] = addf [[FtCt1]], [[Itct]] : f32
|
||||
// CHECK: store [[Ct]], [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
|
||||
// CHECK: [[XtWo_LOAD:%.+]] = load [[XtWo_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ht1Ro_LOAD:%.+]] = load [[Ht1Ro_GEMM]][] : memref<f32>
|
||||
// CHECK: [[Ot_OUTPUT:%.+]] = addf [[XtWo_LOAD]], [[Ht1Ro_LOAD]] : f32
|
||||
|
||||
// CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref<f32>
|
||||
// CHECK: krnl.define_loops 0
|
||||
// CHECK: krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops
|
||||
// CHECK: } : () -> ()
|
||||
// CHECK: krnl.iterate() with () {
|
||||
// CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref<f32>
|
||||
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||
// CHECK: {{.*}} = constant 1.000000e+00 : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store {{.*}}, [[Ot]][] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref<f32>
|
||||
|
||||
// CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref<f32>
|
||||
// CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref<f32>
|
||||
// CHECK: krnl.define_loops 0
|
||||
// CHECK: krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops
|
||||
// CHECK: } : () -> ()
|
||||
// CHECK: krnl.iterate() with () {
|
||||
// CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref<f32>
|
||||
// CHECK: {{.*}} = constant 0.000000e+00 : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = exp {{.*}} : f32
|
||||
// CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32
|
||||
// CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32
|
||||
// CHECK: store {{.*}}, [[hCt]][] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref<f32>
|
||||
|
||||
// CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32
|
||||
// CHECK: store [[Ht]], [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32>
|
||||
|
||||
// CHECK: dealloc [[XtWi_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[XtWo_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[XtWf_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[XtWc_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[Ht1Ri_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[Ht1Ro_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[Ht1Rf_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[Ht1Rc_GEMM]] : memref<f32>
|
||||
// CHECK: dealloc [[It]] : memref<f32>
|
||||
// CHECK: dealloc [[Ft]] : memref<f32>
|
||||
// CHECK: dealloc [[ct]] : memref<f32>
|
||||
// CHECK: dealloc [[Ot]] : memref<f32>
|
||||
// CHECK: dealloc [[hCt]] : memref<f32>
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: dealloc [[CELL_STATE]] : memref<1x3x3xf32>
|
||||
// CHECK: return [[HIDDEN_STATE]] : memref<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_reverse_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "reverse"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
|
||||
// CHECK-LABEL: @test_lstm_reverse_mode
|
||||
|
||||
// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
|
||||
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
|
||||
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
|
||||
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "bidirectional"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)>
|
||||
// CHECK-LABEL: @test_lstm_bidirectional_mode
|
||||
|
||||
// CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[SEQUENCE_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
|
||||
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, {{.*}}, {{.*}}] : memref<4x3x2xf32>
|
||||
|
||||
// CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1
|
||||
// CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops {
|
||||
// CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]]
|
||||
// CHECK: } : () -> !krnl.loop
|
||||
// CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) {
|
||||
// CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index
|
||||
// CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}}
|
||||
// CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32>
|
||||
}
|
||||
|
|
|
@ -613,6 +613,222 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg
|
|||
|
||||
// -----
|
||||
|
||||
func @test_rnn_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_rnn_all_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
|
||||
return
|
||||
|
||||
// CHECK-LABEL: test_rnn_no_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_rnn_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_rnn_missing_first_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_rnn_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
|
||||
return
|
||||
|
||||
// CHECK-LABEL: test_rnn_missing_trailing_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_rnn_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_rnn_all_results_no_hidden_size
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_rnn_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_rnn_all_results_unknown_dims
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_gru_all_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
|
||||
return
|
||||
|
||||
// CHECK-LABEL: test_gru_no_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_gru_missing_first_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none)
|
||||
return
|
||||
|
||||
// CHECK-LABEL: test_gru_missing_trailing_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_gru_all_results_no_hidden_size
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_gru_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_gru_all_results_unknown_dims
|
||||
// CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_lstm_all_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
|
||||
return
|
||||
|
||||
// CHECK-LABEL: test_lstm_no_results
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
|
||||
// CHECK: return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_lstm_missing_first_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, none)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_lstm_missing_trailing_result
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, none)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_lstm_all_results_no_hidden_size
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x3x3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_lstm_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
|
||||
%cst = constant unit
|
||||
%Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
|
||||
return %Y_h : tensor<*xf32>
|
||||
|
||||
// CHECK-LABEL: test_lstm_all_results_unknown_dims
|
||||
// CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>, tensor<1x?x?xf32>)
|
||||
// CHECK: return [[RES]] : tensor<1x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
|
||||
%0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
|
|
@ -63,7 +63,8 @@ OpsWithShapeInference = [
|
|||
'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
|
||||
'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
|
||||
'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
|
||||
'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
|
||||
'LSTM', 'GRU', 'Split'
|
||||
]
|
||||
|
||||
# Operations supporting canonicalization.
|
||||
|
|
Loading…
Reference in New Issue