diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index 5d4cbe9..f147f8f 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -9,6 +9,9 @@ add_library(OMONNXToKrnl NN/Conv.cpp NN/Normalization.cpp NN/Pooling.cpp + RNN/RNNBase.cpp + RNN/RNNBase.hpp + RNN/LSTM.cpp Tensor/Identity.cpp Tensor/Reshape.cpp Tensor/PadConstantValuePad.cpp diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 95750a5..f529a1a 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -103,6 +103,8 @@ void FrontendToKrnlLoweringPass::runOnOperation() { populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); populateLoweringONNXPoolingOpPattern(patterns, &getContext()); + // Recurrent neural network + populateLoweringONNXLSTMOpPattern(patterns, &getContext()); // Entry point patterns.insert(&getContext()); diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp index 5a7ae4e..aac7dc9 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp @@ -102,16 +102,14 @@ Value insertAllocAndDealloc(MemRefType type, Location loc, // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be // inserted. -bool checkInsertDealloc(Operation *currentOp) { +bool checkInsertDealloc(Operation *currentOp, int resultIndex) { auto parentBlock = currentOp->getBlock(); bool insertDealloc = true; - parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) { - assert(currentOp->getNumResults() < 2 && - "No more than one result supported (for now)."); + parentBlock->walk([&insertDealloc, currentOp, resultIndex](ReturnOp op) { // If there is at least one result to investigate. if (currentOp->getNumResults() > 0) { - auto result = currentOp->getResult(0); + auto result = currentOp->getResult(resultIndex); for (const auto &operand : op.getOperands()) if (operand == result) insertDealloc = false; @@ -488,4 +486,4 @@ Value emitNegativeInfinityConstantOp( int64_t ArrayAttrIntVal(ArrayAttr a, int i) { return (a.getValue()[i]).cast().getInt(); -} \ No newline at end of file +} diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index bb4c618..74da9f2 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -46,7 +46,7 @@ Value insertAllocAndDealloc(MemRefType type, Location loc, // Determine if current function returns the result value of the // current op being lowered. If it does then dealloc should not be // inserted. 
-bool checkInsertDealloc(Operation *currentOp); +bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0); // Create a mapping from result type's dimensions to input type's dimensions, // given that the result type is the result of a reduction op over the input @@ -218,6 +218,10 @@ void populateLoweringONNXNormalizationOpPattern( void populateLoweringONNXPoolingOpPattern( OwningRewritePatternList &patterns, MLIRContext *ctx); +// `RNN` directory methods: +void populateLoweringONNXLSTMOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx); + // `Tensor` directory methods: void populateLoweringONNXUnsqueezeOpPattern( diff --git a/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp new file mode 100644 index 0000000..e4b985e --- /dev/null +++ b/src/Conversion/ONNXToKrnl/RNN/LSTM.cpp @@ -0,0 +1,537 @@ +//===--------------- LSTM.cpp - Lowering LSTM Op --------------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX LSTM Operators to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" + +using namespace mlir; + +struct LstmState { + Value allH; + Value ht; + Value ct; +}; + +struct LstmActivationPack { + RNNActivation f; + RNNActivation g; + RNNActivation h; +}; + +template <> +bool hasAllNoneOutput(ONNXLSTMOp *op) { + return ( + isNoneType(op->Y()) && isNoneType(op->Y_h()) && isNoneType(op->Y_c())); +} + +template <> +std::tuple +getActivationPack(ONNXLSTMOp *op) { + auto direction = op->direction(); + auto activations = op->activations(); + auto activationAlpha = op->activation_alpha(); + auto activationBeta = op->activation_beta(); + + LstmActivationPack activationForward, activationReverse; + + // Get activation function name. + // Default forward functions + activationForward.f.name = "sigmoid"; + activationForward.g.name = "tanh"; + activationForward.h.name = "tanh"; + // Default backward functions + activationReverse.f.name = "sigmoid"; + activationReverse.g.name = "tanh"; + activationReverse.h.name = "tanh"; + if (activations) { + ArrayAttr activationArrAttr = activations.getValue(); + if (direction == FORWARD || direction == BIDIRECTIONAL) { + // Forward activations. + if (activationArrAttr.size() > 0) { + activationForward.f.name = + activationArrAttr[0].cast().getValue(); + } + if (activationArrAttr.size() > 1) { + activationForward.g.name = + activationArrAttr[1].cast().getValue(); + } + if (activationArrAttr.size() > 2) { + activationForward.h.name = + activationArrAttr[2].cast().getValue(); + } + } + + // Reverse activations. + if (direction == REVERSE || direction == BIDIRECTIONAL) { + int startIndex = (direction == REVERSE) ? 0 : 3; + if (activationArrAttr.size() > startIndex) { + activationReverse.f.name = + activationArrAttr[startIndex].cast().getValue(); + } + if (activationArrAttr.size() > startIndex + 1) { + activationReverse.g.name = + activationArrAttr[startIndex + 1].cast().getValue(); + } + if (activationArrAttr.size() > startIndex + 2) { + activationReverse.h.name = + activationArrAttr[startIndex + 2].cast().getValue(); + } + } + } + + // Get alpha attributes. + if (activationAlpha) { + ArrayAttr activationArrAttr = activationAlpha.getValue(); + if (direction == FORWARD || direction == BIDIRECTIONAL) { + // Forward activations. 
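For reference, the ONNX `activations`, `activation_alpha` and `activation_beta` attributes list one entry per activation function (f, g, h for LSTM) and per direction; the index arithmetic used above for the names, and below for alpha/beta, starts the reverse entries at index 3 only when the op is bidirectional. A minimal Python sketch of that layout (the helper name is illustrative, not part of the patch):

```python
# Illustrative sketch: how the ONNX "activations" attribute is split into
# forward/reverse packs for LSTM (defaults are sigmoid, tanh, tanh).
def split_lstm_activations(activations, direction):
    defaults = ["sigmoid", "tanh", "tanh"]   # f, g, h
    forward, reverse = defaults[:], defaults[:]
    if direction in ("forward", "bidirectional"):
        for i, name in enumerate(activations[:3]):
            forward[i] = name
    if direction in ("reverse", "bidirectional"):
        start = 0 if direction == "reverse" else 3
        for i, name in enumerate(activations[start:start + 3]):
            reverse[i] = name
    return forward, reverse

# A bidirectional LSTM that only overrides the forward f activation.
print(split_lstm_activations(["hardsigmoid"], "bidirectional"))
```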
+ if (activationArrAttr.size() > 0) { + activationForward.f.alpha = activationArrAttr[0].cast(); + } + if (activationArrAttr.size() > 1) { + activationForward.g.alpha = activationArrAttr[1].cast(); + } + if (activationArrAttr.size() > 2) { + activationForward.h.alpha = activationArrAttr[2].cast(); + } + } + + // Reverse activations. + if (direction == REVERSE || direction == BIDIRECTIONAL) { + int startIndex = (direction == REVERSE) ? 0 : 3; + if (activationArrAttr.size() > startIndex) { + activationReverse.f.alpha = + activationArrAttr[startIndex].cast(); + } + if (activationArrAttr.size() > startIndex + 1) { + activationReverse.g.alpha = + activationArrAttr[startIndex + 1].cast(); + } + if (activationArrAttr.size() > startIndex + 2) { + activationReverse.h.alpha = + activationArrAttr[startIndex + 2].cast(); + } + } + } + + // Get beta attributes. + if (activationBeta) { + ArrayAttr activationArrAttr = activationBeta.getValue(); + if (direction == FORWARD || direction == BIDIRECTIONAL) { + // Forward activations. + if (activationArrAttr.size() > 0) { + activationForward.f.beta = activationArrAttr[0].cast(); + } + if (activationArrAttr.size() > 1) { + activationForward.g.beta = activationArrAttr[1].cast(); + } + if (activationArrAttr.size() > 2) { + activationForward.h.beta = activationArrAttr[2].cast(); + } + } + + // Reverse activations. + if (direction == REVERSE || direction == BIDIRECTIONAL) { + int startIndex = (direction == REVERSE) ? 0 : 3; + if (activationArrAttr.size() > startIndex) { + activationReverse.f.beta = + activationArrAttr[startIndex].cast(); + } + if (activationArrAttr.size() > startIndex + 1) { + activationReverse.g.beta = + activationArrAttr[startIndex + 1].cast(); + } + if (activationArrAttr.size() > startIndex + 2) { + activationReverse.h.beta = + activationArrAttr[startIndex + 2].cast(); + } + } + } + + return std::make_tuple(activationForward, activationReverse); +} + +template <> +LstmState allocAndInitializeStates( + ConversionPatternRewriter &rewriter, Location loc, ONNXLSTMOp *op, + OperandAdaptor operandAdaptor) { + LstmState state; + + // Insert allocation and deallocation for the results of this operation. 
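The allocations that follow use static shapes taken from the operands: Y is [seq_length, num_directions, batch_size, hidden_size], while Y_h and Y_c are [num_directions, batch_size, hidden_size]; when Y_h or Y_c is not requested, an intermediate state buffer of the same shape is still created from W's dim 0, X's dim 1 and R's dim 2. A small sketch of that shape computation (assuming the usual ONNX operand layouts; the helper is illustrative):

```python
# Sketch of the state-buffer shapes allocated below (illustrative helper).
def lstm_state_shapes(x_shape, w_shape, r_shape):
    seq_length, batch_size, _ = x_shape  # X :: [seq_length, batch_size, input_size]
    num_directions = w_shape[0]          # W :: [num_directions, 4*hidden_size, input_size]
    hidden_size = r_shape[2]             # R :: [num_directions, 4*hidden_size, hidden_size]
    return {
        "Y": (seq_length, num_directions, batch_size, hidden_size),
        "Y_h": (num_directions, batch_size, hidden_size),
        "Y_c": (num_directions, batch_size, hidden_size),
    }

# Matches the lowering tests below: X=4x3x2, W=1x12x2, R=1x12x3 -> Y_h=1x3x3.
assert lstm_state_shapes((4, 3, 2), (1, 12, 2), (1, 12, 3))["Y_h"] == (1, 3, 3)
```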
+ if (!isNoneType(op->Y())) { + auto yMemRefType = convertToMemRefType(op->Y().getType()); + if (hasAllConstantDimensions(yMemRefType)) + state.allH = insertAllocAndDealloc(yMemRefType, loc, rewriter, + checkInsertDealloc(op->getOperation(), 0)); + else + emitError(loc, "Unsupported dynamic dimensions."); + } else { + state.allH = op->Y(); + } + + // Y_h :: [num_directions, batch_size, hidden_size] + if (!isNoneType(op->Y_h())) { + auto yhMemRefType = convertToMemRefType(op->Y_h().getType()); + if (hasAllConstantDimensions(yhMemRefType)) + state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter, + checkInsertDealloc(op->getOperation(), 1)); + else + emitError(loc, "Unsupported dynamic dimensions."); + } else { + auto yhMemRefType = MemRefType::get( + {dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1), + dimAt(operandAdaptor.R(), 2)}, + operandAdaptor.X().getType().cast().getElementType()); + state.ht = insertAllocAndDealloc(yhMemRefType, loc, rewriter, true); + } + + // Y_c :: [num_directions, batch_size, hidden_size] + if (!isNoneType(op->Y_c())) { + auto ycMemRefType = convertToMemRefType(op->Y_c().getType()); + if (hasAllConstantDimensions(ycMemRefType)) + state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter, + checkInsertDealloc(op->getOperation(), 2)); + else + emitError(loc, "Unsupported dynamic dimensions."); + } else { + auto ycMemRefType = MemRefType::get( + {dimAt(operandAdaptor.W(), 0), dimAt(operandAdaptor.X(), 1), + dimAt(operandAdaptor.R(), 2)}, + operandAdaptor.X().getType().cast().getElementType()); + state.ct = insertAllocAndDealloc(ycMemRefType, loc, rewriter, true); + } + + // Initialize ht and ct. + Value zero = emitConstantOp(rewriter, loc, + operandAdaptor.X().getType().cast().getElementType(), 0); + int nLoops = 3; + BuildKrnlLoop initializationLoops(rewriter, loc, nLoops); + initializationLoops.createDefineOptimizeAndIterateOp(state.ht); + auto ipInitializationLoops = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(initializationLoops.getIterateBlock()); + { + SmallVector IVs; + for (int i = 0; i < nLoops; ++i) + IVs.emplace_back(initializationLoops.getInductionVar(i)); + + Value hiddenVal = zero; + if (!isNoneType(operandAdaptor.initial_h())) + hiddenVal = rewriter.create(loc, operandAdaptor.initial_h(), IVs); + rewriter.create(loc, hiddenVal, state.ht, IVs); + + Value cellVal = zero; + if (!isNoneType(operandAdaptor.initial_c())) + cellVal = rewriter.create(loc, operandAdaptor.initial_c(), IVs); + rewriter.create(loc, cellVal, state.ct, IVs); + } + rewriter.restoreInsertionPoint(ipInitializationLoops); + return state; +} + +template <> +void calculateState( + ConversionPatternRewriter &rewriter, Location loc, + OperandAdaptor operandAdaptor, LstmState state, + LstmActivationPack activationPack, Value directionIV, Value sequenceIV) { + + bool hasBiasForInput = false, hasPeepholes = false; + if (!isNoneType(operandAdaptor.B())) + hasBiasForInput = true; + if (!isNoneType(operandAdaptor.P())) + hasPeepholes = true; + + // Prepare dimensions. + auto batchDimSize = dimAt(operandAdaptor.X(), 1); + auto inputDimSize = dimAt(operandAdaptor.X(), 2); + auto hiddenDimSize = dimAt(operandAdaptor.R(), 2); + Value hiddenDimVal = + emitConstantOp(rewriter, loc, rewriter.getIndexType(), hiddenDimSize); + + auto elementType = + operandAdaptor.X().getType().cast().getElementType(); + + // Prepare AffineMap to access bias, peepholes tensors. 
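The affine map built just below encodes a flat offset of the form d0 + s0 * s1, i.e. hidden_iv + gate_index * hidden_size, which is how the gate-stacked B and P tensors (and later W and R) are indexed. The same arithmetic in a tiny Python sketch (illustrative only):

```python
# Access-by-offset map (d0)[s0, s1] -> (d0 + s0 * s1): flat index into a
# gate-stacked tensor = gate_index * hidden_size + hidden_iv.
def access_by_offset(hidden_iv, gate_index, hidden_size):
    return gate_index * hidden_size + hidden_iv

# With hidden_size = 3 and gate order [i, o, f, c], the forget-gate rows
# (gate_index = 2) of a [*, 4*hidden_size, *] tensor are offsets 6..8.
assert [access_by_offset(h, 2, 3) for h in range(3)] == [6, 7, 8]
```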
+ AffineMap accessByOffsetMap; + { + AffineExpr iv = rewriter.getAffineDimExpr(0); + AffineExpr index = rewriter.getAffineSymbolExpr(0); + AffineExpr size = rewriter.getAffineSymbolExpr(1); + AffineExpr accessByOffsetExpr = index * size + iv; + accessByOffsetMap = AffineMap::get(1, 2, accessByOffsetExpr); + } + + // Prepare constant indices. + SmallVector constantIndices; + for (int i = 0; i < 8; i++) + constantIndices.emplace_back( + emitConstantOp(rewriter, loc, rewriter.getIndexType(), i)); + + // Equations for LSTM. + // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) + // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) + // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + // Ct = ft (.) Ct-1 + it (.) ct + // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) + // Ht = ot (.) h(Ct) + // + // The following code will emit loops as follows: + // for b in 0 .. BatchDimSize + // for h in 0 .. HiddenDimSize + // for i in 0 .. InputDimSize { + // compute Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T), + // Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T) + // } + // compute it, ft, ct, Ct, ot, Ht + + BuildKrnlLoop stateLoops(rewriter, loc, 2); + stateLoops.createDefineAndOptimizeOp(); + stateLoops.pushBounds(0, batchDimSize); + stateLoops.pushBounds(0, hiddenDimSize); + stateLoops.createIterateOp(); + + rewriter.setInsertionPointToStart(stateLoops.getIterateBlock()); + { + auto batchIV = stateLoops.getInductionVar(0); + auto hiddenIV = stateLoops.getInductionVar(1); + + // IVs to access tensors. + // IVs for the hidden and cell state tensors. + SmallVector hIVs, cIVs; + // IVs for the bias tensors for W and R. + SmallVector, 4> wbIOFCIVs, rbIOFCIVs; + // IVs for the peepholes. + SmallVector, 4> pIOFIVs; + + { // Compute IVs. + // H :: [num_directions, batch_size, hidden_size] + hIVs = {directionIV, batchIV, hiddenIV}; + // C :: [num_directions, batch_size, hidden_size] + cIVs = {directionIV, batchIV, hiddenIV}; + + // Bias [Wb[iofc], Rb[iofc]] :: [num_directions, 8*hidden_size] + if (hasBiasForInput) { + // Wb[iofc] + for (unsigned i = 0; i < 4; ++i) { + Value wHiddenIV = + rewriter.create(loc, accessByOffsetMap, + ValueRange(std::vector{/*iv=*/hiddenIV, + /*index=*/constantIndices[i], /*size=*/hiddenDimVal})); + wbIOFCIVs.emplace_back(SmallVector{directionIV, wHiddenIV}); + } + // Rb[iofc] + for (unsigned i = 4; i < 8; ++i) { + SmallVector rbIVs; + Value rHiddenIV = + rewriter.create(loc, accessByOffsetMap, + ValueRange(std::vector{/*iv=*/hiddenIV, + /*index=*/constantIndices[i], /*size=*/hiddenDimVal})); + rbIOFCIVs.emplace_back(SmallVector{directionIV, rHiddenIV}); + } + } + + // Peepholes P[iof] :: [num_directions, 3*hidden_size] + if (hasPeepholes) { + for (unsigned i = 0; i < 3; ++i) { + SmallVector pIVs; + Value pHiddenIV = + rewriter.create(loc, accessByOffsetMap, + ValueRange(std::vector{ + hiddenIV, constantIndices[i], hiddenDimVal})); + pIOFIVs.emplace_back(SmallVector{directionIV, pHiddenIV}); + } + } + } + + Value loadH = rewriter.create(loc, state.ht, hIVs); + Value loadC = rewriter.create(loc, state.ct, cIVs); + + // Emit instructions for matrix multiplications: + // Xt*(Wi^T), Xt*(Wo^T), Xt*(Wf^t), Xt*(Wc^T) + // Ht-1*(Ri^T), Ht-1*(Ro^T), Ht-1*(Rf^t), Ht-1*(Rc^T) + + // Allocate memory for storing matrix multiplication results. 
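Before the scalar accumulation buffers are allocated below, it may help to see one whole time step of the equations above in plain NumPy (a reference sketch assuming NumPy, default activations f = sigmoid, g = h = tanh, and no bias or peepholes; this is not the code that gets emitted):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, R):
    # W :: [4*hidden, input], R :: [4*hidden, hidden], stacked in [i, o, f, c] order.
    hidden = h_prev.shape[-1]
    Wi, Wo, Wf, Wc = (W[k * hidden:(k + 1) * hidden] for k in range(4))
    Ri, Ro, Rf, Rc = (R[k * hidden:(k + 1) * hidden] for k in range(4))
    i_t = sigmoid(x_t @ Wi.T + h_prev @ Ri.T)      # it
    f_t = sigmoid(x_t @ Wf.T + h_prev @ Rf.T)      # ft
    c_tilde = np.tanh(x_t @ Wc.T + h_prev @ Rc.T)  # ct
    c_t = f_t * c_prev + i_t * c_tilde             # Ct
    o_t = sigmoid(x_t @ Wo.T + h_prev @ Ro.T)      # ot
    h_t = o_t * np.tanh(c_t)                       # Ht
    return h_t, c_t

# Shapes from the tests below: batch 3, input 2, hidden 3 (so 4*hidden = 12).
h, c = lstm_step(np.zeros((3, 2)), np.zeros((3, 3)), np.zeros((3, 3)),
                 np.zeros((12, 2)), np.zeros((12, 3)))
assert h.shape == (3, 3) and c.shape == (3, 3)
```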
+ SmallVector xwIOFC, hrIOFC; + Value zero = emitConstantOp(rewriter, loc, elementType, 0); + MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); + for (unsigned i = 0; i < 4; ++i) { + Value xwAlloc = rewriter.create(loc, scalarMemRefType); + rewriter.create(loc, zero, xwAlloc); + Value hrAlloc = rewriter.create(loc, scalarMemRefType); + rewriter.create(loc, zero, hrAlloc); + xwIOFC.emplace_back(xwAlloc); + hrIOFC.emplace_back(hrAlloc); + } + + { // Emit instructions for matrix multiplications. + // input_size is the reduction dimension. + BuildKrnlLoop reductionLoops(rewriter, loc, 1); + reductionLoops.createDefineAndOptimizeOp(); + reductionLoops.pushBounds(0, inputDimSize); + reductionLoops.createIterateOp(); + + auto ipReductionLoops = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(reductionLoops.getIterateBlock()); + { + auto reductionIV = reductionLoops.getInductionVar(0); + // Prepare IVs for accessing the input tensor and parameters. + SmallVector xIVs; + SmallVector, 4> wIOFCIVs, rIOFCIVs; + + // X :: [seq_length, batch_size, input_size] + xIVs = {sequenceIV, batchIV, reductionIV}; + + // W[iofc] :: [num_directions, 4*hidden_size, input_size] + // R[iofc] :: [num_directions, 4*hidden_size, input_size] + for (unsigned i = 0; i < 4; ++i) { + SmallVector wIVs, rIVs; + Value wHiddenIV = + rewriter.create(loc, accessByOffsetMap, + ValueRange(std::vector{ + hiddenIV, constantIndices[i], hiddenDimVal})); + + wIVs = {directionIV, wHiddenIV, reductionIV}; + wIOFCIVs.emplace_back(wIVs); + + rIVs = {directionIV, wHiddenIV, reductionIV}; + rIOFCIVs.emplace_back(rIVs); + } + + Value loadX = rewriter.create(loc, operandAdaptor.X(), xIVs); + for (unsigned i = 0; i < 4; ++i) { + // Xt * Wiofc + Value loadW = + rewriter.create(loc, operandAdaptor.W(), wIOFCIVs[i]); + Value xwVal = rewriter.create(loc, loadX, loadW); + Value loadXW = rewriter.create(loc, xwIOFC[i]); + Value nextXW = rewriter.create(loc, loadXW, xwVal); + rewriter.create(loc, nextXW, xwIOFC[i]); + // Ht-1 * Riofc + Value loadR = + rewriter.create(loc, operandAdaptor.R(), rIOFCIVs[i]); + Value hrVal = rewriter.create(loc, loadH, loadR); + Value loadHR = rewriter.create(loc, hrIOFC[i]); + Value nextHR = rewriter.create(loc, loadHR, hrVal); + rewriter.create(loc, nextHR, hrIOFC[i]); + } + } + rewriter.restoreInsertionPoint(ipReductionLoops); + } + + // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) + Value loadXWI = rewriter.create(loc, xwIOFC[0]); + Value loadHRI = rewriter.create(loc, hrIOFC[0]); + Value it = rewriter.create(loc, loadXWI, loadHRI); + if (hasPeepholes) { + Value loadP = + rewriter.create(loc, operandAdaptor.P(), pIOFIVs[0]); + Value PC = rewriter.create(loc, loadP, loadC); + it = rewriter.create(loc, it, PC); + } + if (hasBiasForInput) { + Value loadWB = + rewriter.create(loc, operandAdaptor.B(), wbIOFCIVs[0]); + it = rewriter.create(loc, it, loadWB); + Value loadRB = + rewriter.create(loc, operandAdaptor.B(), rbIOFCIVs[0]); + it = rewriter.create(loc, it, loadRB); + } + it = applyActivation(rewriter, loc, activationPack.f, it); + + // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) 
Ct-1 + Wbf + Rbf) + Value loadXWF = rewriter.create(loc, xwIOFC[2]); + Value loadHRF = rewriter.create(loc, hrIOFC[2]); + Value ft = rewriter.create(loc, loadXWF, loadHRF); + if (hasPeepholes) { + Value loadP = + rewriter.create(loc, operandAdaptor.P(), pIOFIVs[2]); + Value PC = rewriter.create(loc, loadP, loadC); + ft = rewriter.create(loc, ft, PC); + } + if (hasBiasForInput) { + Value loadWB = + rewriter.create(loc, operandAdaptor.B(), wbIOFCIVs[2]); + ft = rewriter.create(loc, ft, loadWB); + Value loadRB = + rewriter.create(loc, operandAdaptor.B(), rbIOFCIVs[2]); + ft = rewriter.create(loc, ft, loadRB); + } + ft = applyActivation(rewriter, loc, activationPack.f, ft); + + // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + Value loadXWC = rewriter.create(loc, xwIOFC[3]); + Value loadHRC = rewriter.create(loc, hrIOFC[3]); + Value ct = rewriter.create(loc, loadXWC, loadHRC); + if (hasBiasForInput) { + Value loadWB = + rewriter.create(loc, operandAdaptor.B(), wbIOFCIVs[3]); + ct = rewriter.create(loc, ct, loadWB); + Value loadRB = + rewriter.create(loc, operandAdaptor.B(), rbIOFCIVs[3]); + ct = rewriter.create(loc, ct, loadRB); + } + ct = applyActivation(rewriter, loc, activationPack.g, ct); + + // Ct = ft (.) Ct-1 + it (.) ct + Value FtCt1 = rewriter.create(loc, ft, loadC); + Value itct = rewriter.create(loc, it, ct); + Value Ct = rewriter.create(loc, FtCt1, itct); + rewriter.create(loc, Ct, state.ct, cIVs); + + // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) + Value loadXWO = rewriter.create(loc, xwIOFC[1]); + Value loadHRO = rewriter.create(loc, hrIOFC[1]); + Value ot = rewriter.create(loc, loadXWO, loadHRO); + if (hasPeepholes) { + Value loadP = + rewriter.create(loc, operandAdaptor.P(), pIOFIVs[1]); + Value PC = rewriter.create(loc, loadP, Ct); + ot = rewriter.create(loc, ot, PC); + } + if (hasBiasForInput) { + Value loadWB = + rewriter.create(loc, operandAdaptor.B(), wbIOFCIVs[1]); + ot = rewriter.create(loc, ot, loadWB); + Value loadRB = + rewriter.create(loc, operandAdaptor.B(), rbIOFCIVs[1]); + ot = rewriter.create(loc, ot, loadRB); + } + ot = applyActivation(rewriter, loc, activationPack.f, ot); + + // Ht = ot (.) h(Ct) + Value hCt = applyActivation(rewriter, loc, activationPack.h, Ct); + Value Ht = rewriter.create(loc, ot, hCt); + rewriter.create(loc, Ht, state.ht, hIVs); + + // Store the current Ht if required. + if (!isNoneType(state.allH)) { + SmallVector allHIVs{sequenceIV, directionIV, batchIV, hiddenIV}; + rewriter.create(loc, Ht, state.allH, allHIVs); + } + + // Deallocate the temporary results of matrix multiplications. + for (Value v : xwIOFC) + rewriter.create(loc, v); + for (Value v : hrIOFC) + rewriter.create(loc, v); + } +} + +template <> +void stateToOutput( + ONNXLSTMOp *op, LstmState state, std::vector &outputs) { + Value noneValue; + outputs.emplace_back((isNoneType(op->Y()) ? noneValue : state.allH)); + outputs.emplace_back((isNoneType(op->Y_h()) ? noneValue : state.ht)); + outputs.emplace_back((isNoneType(op->Y_c()) ? noneValue : state.ct)); +} + +void populateLoweringONNXLSTMOpPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + patterns.insert>( + ctx); +} diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp new file mode 100644 index 0000000..2ce79c2 --- /dev/null +++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.cpp @@ -0,0 +1,73 @@ +//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===// +// +// Copyright 2019 The IBM Research Authors. 
+//
+// =============================================================================
+//
+// This file defines base functions for lowering the ONNX RNN Operators.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
+
+using namespace mlir;
+
+// Check a Value's type is none or not.
+bool isNoneType(Value val) { return val.getType().isa<NoneType>(); }
+
+// Get a dimension of the tensor's shape.
+int64_t dimAt(Value val, int index) {
+  return val.getType().cast<ShapedType>().getShape()[index];
+}
+
+// Apply an activation function on a given scalar operand.
+Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
+    RNNActivation activation, Value scalarOperand) {
+  Value res;
+
+  MemRefType scalarMemRefType =
+      MemRefType::get({}, scalarOperand.getType(), {}, 0);
+  Value alloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
+  rewriter.create<StoreOp>(loc, scalarOperand, alloc);
+
+  std::vector<NamedAttribute> attributes;
+  if (activation.alpha) {
+    attributes.emplace_back(
+        rewriter.getNamedAttr("alpha", activation.alpha.getValue()));
+  }
+  if (activation.beta) {
+    attributes.emplace_back(
+        rewriter.getNamedAttr("beta", activation.beta.getValue()));
+  }
+
+  if (activation.name.equals_lower("relu"))
+    res = rewriter.create<ONNXReluOp>(loc, scalarMemRefType, alloc);
+  else if (activation.name.equals_lower("tanh"))
+    res = rewriter.create<ONNXTanhOp>(loc, scalarMemRefType, alloc);
+  else if (activation.name.equals_lower("sigmoid"))
+    res = rewriter.create<ONNXSigmoidOp>(loc, scalarMemRefType, alloc);
+  else if (activation.name.equals_lower("affine"))
+    emitError(loc, "Unsupported activation");
+  else if (activation.name.equals_lower("leakyrelu"))
+    res = rewriter.create<ONNXLeakyReluOp>(
+        loc, scalarMemRefType, alloc, attributes);
+  else if (activation.name.equals_lower("thresholdedrelu"))
+    res = rewriter.create<ONNXThresholdedReluOp>(
+        loc, scalarMemRefType, alloc, attributes);
+  else if (activation.name.equals_lower("scaledtanh"))
+    emitError(loc, "Unsupported activation");
+  else if (activation.name.equals_lower("hardsigmoid"))
+    res = rewriter.create<ONNXHardSigmoidOp>(
+        loc, scalarMemRefType, alloc, attributes);
+  else if (activation.name.equals_lower("elu"))
+    res = rewriter.create<ONNXEluOp>(loc, scalarMemRefType, alloc, attributes);
+  else if (activation.name.equals_lower("softsign"))
+    res = rewriter.create<ONNXSoftsignOp>(loc, scalarMemRefType, alloc);
+  else if (activation.name.equals_lower("softplus"))
+    res = rewriter.create<ONNXSoftplusOp>(loc, scalarMemRefType, alloc);
+  else
+    llvm_unreachable("Unsupported activation");
+
+  Value result = rewriter.create<LoadOp>(loc, res);
+  return result;
+}
diff --git a/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
new file mode 100644
index 0000000..29abceb
--- /dev/null
+++ b/src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp
@@ -0,0 +1,144 @@
+//===--------------- RNNBase.hpp - Lowering RNN Ops -----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file defines base functions for lowering the ONNX RNN Operators.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AffineExpr.h" +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" + +using namespace mlir; + +static const StringRef FORWARD = "forward"; +static const StringRef REVERSE = "reverse"; +static const StringRef BIDIRECTIONAL = "bidirectional"; + +struct RNNActivation { + StringRef name; + Optional alpha; + Optional beta; +}; + +// Check a Value's type is none or not. +bool isNoneType(Value val); + +// Get a dimension of the tensor's shape. +int64_t dimAt(Value val, int index); + +// Apply an activation function on a given scalar operand. +Value applyActivation(ConversionPatternRewriter &rewriter, Location loc, + RNNActivation activation, Value scalarOperand); + +// Override the following methods when lowering an RNN operation: +// - hasAllNoneOutput +// - getActivationPack +// - allocAndInitializeStates +// - calculateState +// - stateToOutput + +// Check whether all outputs have NoneType or not. +template +bool hasAllNoneOutput(RNNOp *op); + +// Obtain activations functions for a specific operation. +template +std::tuple getActivationPack(RNNOp *op); + +// Allocate memory for RNN states and initialize them. +template +S allocAndInitializeStates(ConversionPatternRewriter &rewriter, Location loc, + RNNOp *op, OperandAdaptor operandAdaptor); + +// Calculate new states from the current input and states. +template +void calculateState(ConversionPatternRewriter &rewriter, Location loc, + OperandAdaptor operandAdaptor, S state, A activationSet, + Value directionIV, Value sequenceIV); + +// Write states to the RNN's outputs. +template +void stateToOutput(RNNOp *op, S state, std::vector &outputs); + +// A common template for lowering an RNN operation. +template +struct ONNXRNNOpLowering : public ConversionPattern { + ONNXRNNOpLowering(MLIRContext *ctx) + : ConversionPattern(RNNOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + auto loc = op->getLoc(); + + RNNOp rnnOp = llvm::dyn_cast(op); + OperandAdaptor operandAdaptor(operands); + + if (hasAllNoneOutput(&rnnOp)) { + rewriter.eraseOp(op); + return success(); + } + + S state = allocAndInitializeStates( + rewriter, loc, &rnnOp, operandAdaptor); + + A activationForward, activationReverse; + std::tie(activationForward, activationReverse) = + getActivationPack(&rnnOp); + + int64_t sequenceDimSize = dimAt(rnnOp.X(), 0); + auto direction = rnnOp.direction(); + + if (direction == FORWARD || direction == BIDIRECTIONAL) { + BuildKrnlLoop sequenceLoops(rewriter, loc, 1); + sequenceLoops.createDefineAndOptimizeOp(); + sequenceLoops.pushBounds(0, sequenceDimSize); + sequenceLoops.createIterateOp(); + + auto ipSequenceLoops = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock()); + { + Value directionIV = + emitConstantOp(rewriter, loc, rewriter.getIndexType(), 0); + Value sequenceIV = sequenceLoops.getInductionVar(0); + // Emit calculation for one RNN step. 
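The calculateState call emitted at this point is the per-step hook; at a high level the driver loops over the sequence once per direction and hands each (direction, timestep) pair to that hook. A rough Python sketch of the control flow (hook names are illustrative stand-ins for the template specializations; the hasAllNoneOutput early exit is omitted):

```python
# Structural sketch of the ONNXRNNOpLowering driver; the hooks are passed in
# as callables here, mirroring the template functions specialized per RNN op.
def lower_rnn(seq_len, direction, init_states, activation_packs,
              calculate_state, state_to_output):
    state = init_states()
    act_fwd, act_rev = activation_packs()
    if direction in ("forward", "bidirectional"):
        for t in range(seq_len):
            calculate_state(state, act_fwd, direction_iv=0, sequence_iv=t)
    if direction in ("reverse", "bidirectional"):
        d = 0 if direction == "reverse" else 1
        for t in range(seq_len):
            # Reverse iteration mirrors the affine map (d0)[s0] -> (s0 - d0 - 1).
            calculate_state(state, act_rev, direction_iv=d,
                            sequence_iv=seq_len - 1 - t)
    return state_to_output(state)

# Record which (direction_iv, sequence_iv) pairs a bidirectional run visits.
visited = []
lower_rnn(4, "bidirectional", dict, lambda: ("fwd", "rev"),
          lambda s, a, direction_iv, sequence_iv: visited.append(
              (direction_iv, sequence_iv)),
          lambda s: visited)
assert visited[-1] == (1, 0)
```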
+ calculateState(rewriter, loc, operandAdaptor, state, + activationForward, directionIV, sequenceIV); + } + rewriter.restoreInsertionPoint(ipSequenceLoops); + } + + if (direction == REVERSE || direction == BIDIRECTIONAL) { + BuildKrnlLoop sequenceLoops(rewriter, loc, 1); + sequenceLoops.createDefineAndOptimizeOp(); + sequenceLoops.pushBounds(0, sequenceDimSize); + sequenceLoops.createIterateOp(); + + auto ipSequenceLoops = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointToStart(sequenceLoops.getIterateBlock()); + { + AffineMap reverseIVMap = AffineMap::get(1, 1, + rewriter.getAffineSymbolExpr(0) - rewriter.getAffineDimExpr(0) - 1); + + Value directionIV = emitConstantOp(rewriter, loc, + rewriter.getIndexType(), (direction == REVERSE) ? 0 : 1); + Value reverseSequenceIV = + rewriter.create(loc, reverseIVMap, + ValueRange(std::vector{sequenceLoops.getInductionVar(0), + emitConstantOp(rewriter, loc, rewriter.getIndexType(), + sequenceDimSize)})); + // Emit calculation for one RNN step. + calculateState(rewriter, loc, operandAdaptor, state, + activationReverse, directionIV, reverseSequenceIV); + } + rewriter.restoreInsertionPoint(ipSequenceLoops); + } + + std::vector outputs; + stateToOutput(&rnnOp, state, outputs); + rewriter.replaceOp(op, outputs); + return success(); + } +}; diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 1bb3a2c..738c45c 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -299,6 +299,112 @@ static void insertConvSpatialDim(SmallVector *outputDims, } } +//===----------------------------------------------------------------------===// +// Support function that infers shape for RNN operations. +template +static bool RNNShapeInference(T *op) { + Value X = op->X(); + Value W = op->W(); + Value R = op->R(); + + if (!X.getType().isa() || + !W.getType().isa() || + !R.getType().isa()) + return false; + + auto xTy = X.getType().cast(); + auto elementType = xTy.getElementType(); + + // xShape :: [seq_length, batch_size, input_size] + auto xShape = xTy.getShape(); + // wShape :: [num_directions, 4*hidden_size, input_size] + auto wShape = W.getType().cast().getShape(); + // rShape :: [num_directions, 4*hidden_size, hidden_size] + auto rShape = R.getType().cast().getShape(); + + if (xShape.size() != 3) { + op->emitError("The first input tensor must have rank 3"); + return false; + } + if (wShape.size() != 3) { + op->emitError("The second input tensor must have rank 3"); + return false; + } + if (rShape.size() != 3) { + op->emitError("The third input tensor must have rank 3"); + return false; + } + + // Get sequence length, batch size and input size. + auto sequenceLength = xShape[0]; + auto batchSize = xShape[1]; + auto inputSize = xShape[2]; + + // Get hidden size from hidden_size attribute. + int64_t hiddenSize = -1; + if (op->hidden_size().hasValue()) { + hiddenSize = op->hidden_size().getValue().getSExtValue(); + } else { + // Infer hidden_size from wShape and rShape if possible. + if (rShape[2] != -1) + hiddenSize = rShape[2]; + else if (rShape[1] != -1) + hiddenSize = rShape[1] / 4; + else if (wShape[1] != -1) + hiddenSize = wShape[1] / 4; + // Update hidden_size attribute. + if (hiddenSize != -1) { + auto builder = mlir::Builder(op->getContext()); + op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize)); + } + } + + // Get direction. 
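The fallback above divides the stacked weight dimension by 4 because LSTM stacks four gates along that axis; the direction attribute, handled just below, only decides whether num_directions is 1 or 2. A compact sketch of the inference rule (using -1 for unknown dimensions, as in the shape arrays above):

```python
# Sketch of the hidden_size / num_directions inference in RNNShapeInference.
def infer_hidden_size(hidden_size_attr, w_shape, r_shape):
    if hidden_size_attr is not None:
        return hidden_size_attr
    if r_shape[2] != -1:
        return r_shape[2]       # R :: [num_directions, 4*hidden_size, hidden_size]
    if r_shape[1] != -1:
        return r_shape[1] // 4
    if w_shape[1] != -1:
        return w_shape[1] // 4  # W :: [num_directions, 4*hidden_size, input_size]
    return -1                   # stays unknown

def num_directions(direction):
    return 2 if direction == "bidirectional" else 1

# From the shape-inference tests: R = 1x12x3 gives hidden_size = 3.
assert infer_hidden_size(None, (1, 12, 2), (1, 12, 3)) == 3
```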
+ int numDirection; + if ((op->direction() == "forward") || (op->direction() == "reverse")) + numDirection = 1; + else if (op->direction() == "bidirectional") + numDirection = 2; + else + numDirection = -1; + if (numDirection == -1) { + op->emitError("direction attribute muse be one of the strings: forward, " + "reverse, and bidirectional"); + return false; + } + + // Set result types. + unsigned numOfResults = op->getNumResults(); + if (numOfResults > 0) { + // Y :: [seq_length, num_directions, batch_size, hidden_size] + Type yTy = op->getResults()[0].getType(); + if (!yTy.isa()) { + yTy = RankedTensorType::get( + {sequenceLength, numDirection, batchSize, hiddenSize}, elementType); + op->getResults()[0].setType(yTy); + } + } + if (numOfResults > 1) { + // Y_h :: [num_directions, batch_size, hidden_size] + Type yhTy = op->getResults()[1].getType(); + if (!yhTy.isa()) { + yhTy = RankedTensorType::get( + {numDirection, batchSize, hiddenSize}, elementType); + op->getResults()[1].setType(yhTy); + } + } + if (numOfResults > 2) { + // Y_c :: [num_directions, batch_size, hidden_size] + Type ycTy = op->getResults()[2].getType(); + if (!ycTy.isa()) { + ycTy = RankedTensorType::get( + {numDirection, batchSize, hiddenSize}, elementType); + op->getResults()[2].setType(ycTy); + } + } + return true; +} + //===----------------------------------------------------------------------===// // ONNXOpsDialect //===----------------------------------------------------------------------===// @@ -1472,7 +1578,6 @@ bool ONNXConstantOp::inferShapes() { return true; } -//===----------------------------------------------------------------------===// // Concat bool ONNXConcatOp::inferShapes() { @@ -1537,6 +1642,21 @@ bool ONNXConcatOp::inferShapes() { return true; } +//===----------------------------------------------------------------------===// +// RNN + +bool ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); } + +//===----------------------------------------------------------------------===// +// LSTM + +bool ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); } + +//===----------------------------------------------------------------------===// +// GRU + +bool ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); } + //===----------------------------------------------------------------------===// // Split diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 3b2dc86..f047ae3 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -738,7 +738,7 @@ def ONNXFloorOp:ONNX_Op<"Floor", } def ONNXGRUOp:ONNX_Op<"GRU", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX GRU operation"; let description = [{ "Computes an one-layer GRU. This operator is usually supported via some custom" @@ -1246,7 +1246,7 @@ def ONNXLRNOp:ONNX_Op<"LRN", } def ONNXLSTMOp:ONNX_Op<"LSTM", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX LSTM operation"; let description = [{ "Computes an one-layer LSTM. This operator is usually supported via some" @@ -2086,7 +2086,7 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", } def ONNXRNNOp:ONNX_Op<"RNN", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX RNN operation"; let description = [{ "Computes an one-layer simple RNN. 
This operator is usually supported" diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 97f9794..9b81db7 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -126,6 +126,9 @@ public: op->getName().getStringRef() != "onnx.Concat" && op->getName().getStringRef() != "onnx.Split" && op->getName().getStringRef() != "onnx.Neg" && + op->getName().getStringRef() != "onnx.RNN" && + op->getName().getStringRef() != "onnx.LSTM" && + op->getName().getStringRef() != "onnx.GRU" && op->getName().getStringRef() != "onnx.Unsqueeze") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { diff --git a/test/backend/test.py b/test/backend/test.py index 36c37df..8914f7f 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -356,6 +356,11 @@ test_to_enable = [ "test_averagepool_2d_strides_cpu", "test_averagepool_3d_default_cpu", + # LSTM + "test_lstm_defaults_cpu", + "test_lstm_with_initial_bias_cpu", + "test_lstm_with_peepholes_cpu", + ] # Extract name of all test cases. diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index a9eff20..89cb551 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1769,3 +1769,314 @@ func @test_maxpool_pooling_operation(%arg0 : tensor<1x3x32x32xf32>) -> tensor<*x // CHECK: } } +// ----- + +func @test_lstm_general_computation(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none) + return %Y_h : tensor<*xf32> + + // CHECK-DAG: [[ACCESS_BY_OFFSET_MAP:#.+]] = affine_map<(d0)[s0, s1] -> (d0 + s0 * s1)> + // CHECK-LABEL: @test_lstm_general_computation + + // CHECK: [[CELL_STATE:%.+]] = alloc() : memref<1x3x3xf32> + // CHECK: [[HIDDEN_STATE:%.+]] = alloc() : memref<1x3x3xf32> + // CHECK: {{.*}} = constant unit + + // CHECK: [[INITIAL_VALUE:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[INITIALIZE_LOOPS:%.+]]:3 = krnl.define_loops 3 + // CHECK: [[INITIALIZE_OPT_LOOPS:%.+]]:3 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[INITIALIZE_LOOPS]]#0, [[INITIALIZE_LOOPS]]#1, [[INITIALIZE_LOOPS]]#2 + // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[INITIALIZE_OPT_LOOPS]]#0, [[INITIALIZE_OPT_LOOPS]]#1, [[INITIALIZE_OPT_LOOPS]]#2) with ([[INITIALIZE_LOOPS]]#0 -> %arg3 = 0 to 1, [[INITIALIZE_LOOPS]]#1 -> %arg4 = 0 to 3, [[INITIALIZE_LOOPS]]#2 -> %arg5 = 0 to 3) { + // CHECK: store [[INITIAL_VALUE]], [[HIDDEN_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32> + // CHECK: store [[INITIAL_VALUE]], [[CELL_STATE]][%arg3, %arg4, %arg5] : memref<1x3x3xf32> + // CHECK: } + + // CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[SEQUENCE_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) { + // CHECK: {{.*}} = constant 0 : index + // CHECK: {{.*}} = constant 3 : index + // CHECK: {{.*}} = constant 0 : index + // CHECK: {{.*}} = constant 1 : index + // CHECK: {{.*}} = constant 2 : index + // CHECK: {{.*}} = constant 3 : index + // CHECK: {{.*}} = constant 4 : index + // CHECK: {{.*}} = constant 
5 : index + // CHECK: {{.*}} = constant 6 : index + // CHECK: {{.*}} = constant 7 : index + // CHECK: [[DATA_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[DATA_OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DATA_LOOPS]]#0, [[DATA_LOOPS]]#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[DATA_OPT_LOOPS]]#0, [[DATA_OPT_LOOPS]]#1) with ([[DATA_LOOPS]]#0 -> %arg4 = 0 to 3, [[DATA_LOOPS]]#1 -> %arg5 = 0 to 3) { + // CHECK: [[hCt:%.+]] = alloc() : memref + // CHECK: [[Ot:%.+]] = alloc() : memref + // CHECK: [[ct:%.+]] = alloc() : memref + // CHECK: [[Ft:%.+]] = alloc() : memref + // CHECK: [[It:%.+]] = alloc() : memref + // CHECK: [[Ht1_LOAD:%.+]] = load [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32> + // CHECK: [[Ct1_LOAD:%.+]] = load [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32> + + // CHECK: [[ZERO_FLOAT:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[XtWi_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[XtWi_GEMM]][] : memref + // CHECK: [[Ht1Ri_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[Ht1Ri_GEMM]][] : memref + // CHECK: [[XtWo_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[XtWo_GEMM]][] : memref + // CHECK: [[Ht1Ro_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[Ht1Ro_GEMM]][] : memref + // CHECK: [[XtWf_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[XtWf_GEMM]][] : memref + // CHECK: [[Ht1Rf_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[Ht1Rf_GEMM]][] : memref + // CHECK: [[XtWc_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[XtWc_GEMM]][] : memref + // CHECK: [[Ht1Rc_GEMM:%.+]] = alloc() : memref + // CHECK: store [[ZERO_FLOAT]], [[Ht1Rc_GEMM]][] : memref + + // CHECK: [[REDUCTION_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[REDUCTION_OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[REDUCTION_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: krnl.iterate([[REDUCTION_OPT_LOOPS]]) with ([[REDUCTION_LOOPS]] -> %arg6 = 0 to 2) { + // CHECK: [[INPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c0_1, %c3] + // CHECK: [[OUTPUT_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c1, %c3] + // CHECK: [[FORGET_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c2, %c3] + // CHECK: [[CELL_HIDDEN_INDEX:%.+]] = affine.apply #{{.*}}(%arg5)[%c3_2, %c3] + // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, %arg4, %arg6] : memref<4x3x2xf32> + + // CHECK: [[Wi_LOAD:%.+]] = load %arg1[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32> + // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wi_LOAD]] : f32 + // CHECK: {{.*}} = load [[XtWi_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %59, [[XtWi_GEMM]][] : memref + + // CHECK: [[Ri_LOAD:%.+]] = load %arg2[%c0, [[INPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32> + // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Ri_LOAD]] : f32 + // CHECK: {{.*}} = load [[Ht1Ri_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %63, [[Ht1Ri_GEMM]][] : memref + + // CHECK: [[Wo_LOAD:%.+]] = load %arg1[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32> + // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wo_LOAD]] : f32 + // CHECK: {{.*}} = load [[XtWo_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %67, [[XtWo_GEMM]][] : memref + + // CHECK: [[Ro_LOAD:%.+]] = load %arg2[%c0, [[OUTPUT_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32> + // CHECK: {{.*}} = mulf 
[[Ht1_LOAD]], [[Ro_LOAD]] : f32 + // CHECK: {{.*}} = load [[Ht1Ro_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %71, [[Ht1Ro_GEMM]][] : memref + + // CHECK: [[Wf_LOAD:%.+]] = load %arg1[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32> + // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wf_LOAD]] : f32 + // CHECK: {{.*}} = load [[XtWf_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %75, [[XtWf_GEMM]][] : memref + + // CHECK: [[Rf_LOAD:%.+]] = load %arg2[%c0, [[FORGET_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32> + // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rf_LOAD]] : f32 + // CHECK: {{.*}} = load [[Ht1Rf_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %79, [[Ht1Rf_GEMM]][] : memref + + // CHECK: [[Wc_LOAD:%.+]] = load %arg1[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x2xf32> + // CHECK: {{.*}} = mulf [[Xt_LOAD]], [[Wc_LOAD]] : f32 + // CHECK: {{.*}} = load [[XtWc_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %83, [[XtWc_GEMM]][] : memref + + // CHECK: [[Rc_LOAD:%.+]] = load %arg2[%c0, [[CELL_HIDDEN_INDEX]], %arg6] : memref<1x12x3xf32> + // CHECK: {{.*}} = mulf [[Ht1_LOAD]], [[Rc_LOAD]] : f32 + // CHECK: {{.*}} = load [[Ht1Rc_GEMM]][] : memref + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: store %87, [[Ht1Rc_GEMM]][] : memref + // CHECK: } + + // CHECK: [[XtWi_LOAD:%.+]] = load [[XtWi_GEMM]][] : memref + // CHECK: [[Ht1Ri_LOAD:%.+]] = load [[Ht1Ri_GEMM]][] : memref + // CHECK: [[It_OUTPUT:%.+]] = addf [[XtWi_LOAD]], [[Ht1Ri_LOAD]] : f32 + + // CHECK: [[SIGMOID_INPUT:%.+]] = alloc() : memref + // CHECK: store [[It_OUTPUT]], [[SIGMOID_INPUT]][] : memref + // CHECK: krnl.define_loops 0 + // CHECK: krnl.optimize_loops { + // CHECK: krnl.return_loops + // CHECK: } : () -> () + // CHECK: krnl.iterate() with () { + // CHECK: {{.*}} = load [[SIGMOID_INPUT]][] : memref + // CHECK: {{.*}} = constant 0.000000e+00 : f32 + // CHECK: {{.*}} = constant 1.000000e+00 : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32 + // CHECK: store {{.*}}, [[It]][] : memref + // CHECK: } + // CHECK: [[It_LOAD:%.+]] = load [[It]][] : memref + + // CHECK: [[XtWf_LOAD:%.+]] = load [[XtWf_GEMM]][] : memref + // CHECK: [[Ht1Rf_LOAD:%.+]] = load [[Ht1Rf_GEMM]][] : memref + // CHECK: [[Ft_OUTPUT:%.+]] = addf [[XtWf_LOAD]], [[Ht1Rf_LOAD]] : f32 + + // CHECK: [[SIGMOID_FORGET:%.+]] = alloc() : memref + // CHECK: store [[Ft_OUTPUT]], [[SIGMOID_FORGET]][] : memref + // CHECK: krnl.define_loops 0 + // CHECK: krnl.optimize_loops { + // CHECK: krnl.return_loops + // CHECK: } : () -> () + // CHECK: krnl.iterate() with () { + // CHECK: {{.*}} = load [[SIGMOID_FORGET]][] : memref + // CHECK: {{.*}} = constant 0.000000e+00 : f32 + // CHECK: {{.*}} = constant 1.000000e+00 : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32 + // CHECK: store {{.*}}, [[Ft]][] : memref + // CHECK: } + // CHECK: [[Ft_LOAD:%.+]] = load [[Ft]][] : memref + + // CHECK: [[XtWc_LOAD:%.+]] = load [[XtWc_GEMM]][] : memref + // CHECK: [[Ht1Rc_LOAD:%.+]] = load [[Ht1Rc_GEMM]][] : memref + // CHECK: [[ct_OUTPUT:%.+]] = addf [[XtWc_LOAD]], [[Ht1Rc_LOAD]] : f32 + + // CHECK: [[TANH_CELL:%.+]] = alloc() : memref + // CHECK: store [[ct_OUTPUT]], [[TANH_CELL]][] : 
memref + // CHECK: krnl.define_loops 0 + // CHECK: krnl.optimize_loops { + // CHECK: krnl.return_loops + // CHECK: } : () -> () + // CHECK: krnl.iterate() with () { + // CHECK: {{.*}} = load [[TANH_CELL]][] : memref + // CHECK: {{.*}} = constant 0.000000e+00 : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32 + // CHECK: store {{.*}}, [[ct]][] : memref + // CHECK: } + // CHECK: [[ct_LOAD:%.+]] = load [[ct]][] : memref + + // CHECK: [[FtCt1:%.+]] = mulf [[Ft_LOAD]], [[Ct1_LOAD]] : f32 + // CHECK: [[Itct:%.+]] = mulf [[It_LOAD]], [[ct_LOAD]] : f32 + // CHECK: [[Ct:%.+]] = addf [[FtCt1]], [[Itct]] : f32 + // CHECK: store [[Ct]], [[CELL_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32> + + // CHECK: [[XtWo_LOAD:%.+]] = load [[XtWo_GEMM]][] : memref + // CHECK: [[Ht1Ro_LOAD:%.+]] = load [[Ht1Ro_GEMM]][] : memref + // CHECK: [[Ot_OUTPUT:%.+]] = addf [[XtWo_LOAD]], [[Ht1Ro_LOAD]] : f32 + + // CHECK: [[SIGMOID_OUTPUT:%.+]] = alloc() : memref + // CHECK: store [[Ot_OUTPUT]], [[SIGMOID_OUTPUT]][] : memref + // CHECK: krnl.define_loops 0 + // CHECK: krnl.optimize_loops { + // CHECK: krnl.return_loops + // CHECK: } : () -> () + // CHECK: krnl.iterate() with () { + // CHECK: {{.*}} = load [[SIGMOID_OUTPUT]][] : memref + // CHECK: {{.*}} = constant 0.000000e+00 : f32 + // CHECK: {{.*}} = constant 1.000000e+00 : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}}: f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32 + // CHECK: store {{.*}}, [[Ot]][] : memref + // CHECK: } + // CHECK: [[Ot_LOAD:%.+]] = load [[Ot]][] : memref + + // CHECK: [[TANH_HIDDEN:%.+]] = alloc() : memref + // CHECK: store [[Ct]], [[TANH_HIDDEN]][] : memref + // CHECK: krnl.define_loops 0 + // CHECK: krnl.optimize_loops { + // CHECK: krnl.return_loops + // CHECK: } : () -> () + // CHECK: krnl.iterate() with () { + // CHECK: {{.*}} = load [[TANH_HIDDEN]][] : memref + // CHECK: {{.*}} = constant 0.000000e+00 : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = exp {{.*}} : f32 + // CHECK: {{.*}} = subf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = addf {{.*}}, {{.*}} : f32 + // CHECK: {{.*}} = divf {{.*}}, {{.*}} : f32 + // CHECK: store {{.*}}, [[hCt]][] : memref + // CHECK: } + // CHECK: [[hCt_LOAD:%.+]] = load [[hCt]][] : memref + + // CHECK: [[Ht:%.+]] = mulf [[Ot_LOAD]], [[hCt_LOAD]] : f32 + // CHECK: store [[Ht]], [[HIDDEN_STATE]][%c0, %arg4, %arg5] : memref<1x3x3xf32> + + // CHECK: dealloc [[XtWi_GEMM]] : memref + // CHECK: dealloc [[XtWo_GEMM]] : memref + // CHECK: dealloc [[XtWf_GEMM]] : memref + // CHECK: dealloc [[XtWc_GEMM]] : memref + // CHECK: dealloc [[Ht1Ri_GEMM]] : memref + // CHECK: dealloc [[Ht1Ro_GEMM]] : memref + // CHECK: dealloc [[Ht1Rf_GEMM]] : memref + // CHECK: dealloc [[Ht1Rc_GEMM]] : memref + // CHECK: dealloc [[It]] : memref + // CHECK: dealloc [[Ft]] : memref + // CHECK: dealloc [[ct]] : memref + // CHECK: dealloc [[Ot]] : memref + // CHECK: dealloc [[hCt]] : memref + // CHECK: } + // CHECK: } + // CHECK: dealloc [[CELL_STATE]] : memref<1x3x3xf32> + // CHECK: return [[HIDDEN_STATE]] : memref<1x3x3xf32> +} + +// ----- + +func @test_lstm_reverse_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant 
unit + %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "reverse"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none) + return %Y_h : tensor<*xf32> + + // CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)> + // CHECK-LABEL: @test_lstm_reverse_mode + + // CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) { + // CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index + // CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}} + // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32> +} + +// ----- + +func @test_lstm_bidirectional_mode(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64, direction = "bidirectional"} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, none) + return %Y_h : tensor<*xf32> + + // CHECK: [[REVERSE_IV_MAP:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 1)> + // CHECK-LABEL: @test_lstm_bidirectional_mode + + // CHECK: [[SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[SEQUENCE_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: krnl.iterate([[SEQUENCE_OPT_LOOPS]]) with ([[SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) { + // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%arg3, {{.*}}, {{.*}}] : memref<4x3x2xf32> + + // CHECK: [[REVERSE_SEQUENCE_LOOPS:%.+]] = krnl.define_loops 1 + // CHECK: [[REVERSE_SEQUENCE_OPT_LOOPS:%.+]] = krnl.optimize_loops { + // CHECK: krnl.return_loops [[REVERSE_SEQUENCE_LOOPS]] + // CHECK: } : () -> !krnl.loop + // CHECK: krnl.iterate([[REVERSE_SEQUENCE_OPT_LOOPS]]) with ([[REVERSE_SEQUENCE_LOOPS]] -> %arg3 = 0 to 4) { + // CHECK: %[[SEQUENCE_LEN:.+]] = constant 4 : index + // CHECK: %[[REVERSE_SEQUENCE_IV:.+]] = affine.apply [[REVERSE_IV_MAP]](%arg3)[%[[SEQUENCE_LEN]]{{]}} + // CHECK: [[Xt_LOAD:%.+]] = load %arg0[%[[REVERSE_SEQUENCE_IV]], {{.*}}, {{.*}}] : memref<4x3x2xf32> +} diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 8bee657..faed4ee 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -613,6 +613,222 @@ func @test_concat_3(%arg0 : tensor<5x1x32xf32>, %arg1 : tensor<5x3x32xf32>, %arg // ----- +func @test_rnn_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_rnn_all_results + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, 
tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_rnn_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none) + return + + // CHECK-LABEL: test_rnn_no_results + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none) + // CHECK: return +} + +// ----- + +func @test_rnn_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_rnn_missing_first_result + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_rnn_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none) + return + + // CHECK-LABEL: test_rnn_missing_trailing_result + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none) + // CHECK: return +} + +// ----- + +func @test_rnn_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_rnn_all_results_no_hidden_size + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_rnn_all_results_unknown_dims(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor, tensor, tensor, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_rnn_all_results_unknown_dims + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.RNN"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor, tensor, tensor, none, none, none) -> (tensor, tensor<1x?x?xf32>) + // CHECK: return [[RES]] : tensor<1x?x?xf32> +} + +// ----- + +func @test_gru_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, 
%arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_gru_all_results + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_gru_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none) + return + + // CHECK-LABEL: test_gru_no_results + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, none) + // CHECK: return +} + +// ----- + +func @test_gru_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_gru_missing_first_result + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (none, tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_gru_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, none) + return + + // CHECK-LABEL: test_gru_missing_trailing_result + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, none) + // CHECK: return +} + +// ----- + +func @test_gru_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // CHECK-LABEL: test_gru_all_results_no_hidden_size + // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>) + // CHECK: return [[RES]] : tensor<1x3x3xf32> +} + +// ----- + +func @test_gru_all_results_unknown_dims(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<*xf32> { + %cst = constant unit + %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor, tensor, tensor, none, none, none) -> (tensor<*xf32>, tensor<*xf32>) + return %Y_h : tensor<*xf32> + + // 
+func @test_gru_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<*xf32>, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_gru_all_results_unknown_dims
+  // CHECK: %{{.*}}, [[RES:%.+]] = "onnx.GRU"(%arg0, %arg1, %arg2, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>)
+  // CHECK: return [[RES]] : tensor<1x?x?xf32>
+}
+
+// -----
+
+func @test_lstm_all_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_lstm_all_results
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
+  // CHECK: return [[RES]] : tensor<1x3x3xf32>
+}
+
+// -----
+
+func @test_lstm_no_results(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> () {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
+  return
+
+  // CHECK-LABEL: test_lstm_no_results
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, none, none)
+  // CHECK: return
+}
+
+// -----
+
+func @test_lstm_missing_first_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<*xf32>, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_lstm_missing_first_result
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (none, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
+  // CHECK: return [[RES]] : tensor<1x3x3xf32>
+}
+
+// -----
+
+func @test_lstm_missing_trailing_result(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, none)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_lstm_missing_trailing_result
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, none)
+  // CHECK: return [[RES]] : tensor<1x3x3xf32>
+}
+
+// -----
+
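+// In the following test, hidden_size is not given as an attribute; shape inference is
+// expected to derive it from the weight tensors and materialize it on the op, as the
+// CHECK line below verifies.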
+func @test_lstm_all_results_no_hidden_size(%arg0: tensor<4x3x2xf32>, %arg1: tensor<1x12x2xf32>, %arg2: tensor<1x12x3xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_lstm_all_results_no_hidden_size
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) {hidden_size = 3 : i64} : (tensor<4x3x2xf32>, tensor<1x12x2xf32>, tensor<1x12x3xf32>, none, none, none, none, none) -> (tensor<4x1x3x3xf32>, tensor<1x3x3xf32>, tensor<1x3x3xf32>)
+  // CHECK: return [[RES]] : tensor<1x3x3xf32>
+}
+
+// -----
+
+func @test_lstm_all_results_unknown_dims(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>, %arg2: tensor<?x?x?xf32>) -> tensor<*xf32> {
+  %cst = constant unit
+  %Y, %Y_h, %Y_c = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>)
+  return %Y_h : tensor<*xf32>
+
+  // CHECK-LABEL: test_lstm_all_results_unknown_dims
+  // CHECK: %{{.*}}, [[RES:%.+]], %{{.*}} = "onnx.LSTM"(%arg0, %arg1, %arg2, %cst, %cst, %cst, %cst, %cst) : (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, none, none, none, none, none) -> (tensor<?x1x?x?xf32>, tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+  // CHECK: return [[RES]] : tensor<1x?x?xf32>
+}
+
+// -----
+
 func @test_split_1(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> {
   %0, %1 = "onnx.Split"(%arg0) { axis = 1 } : (tensor<16x32x64xf32>) -> (tensor<*xf32>, tensor<*xf32>)
   "std.return"(%0) : (tensor<*xf32>) -> ()
diff --git a/utils/gen_doc.py b/utils/gen_doc.py
index 1f7fe65..46afb5d 100644
--- a/utils/gen_doc.py
+++ b/utils/gen_doc.py
@@ -63,7 +63,8 @@ OpsWithShapeInference = [
     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin',
     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze',
-    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'Split'
+    'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN',
+    'LSTM', 'GRU', 'Split'
 ]
 
 # Operations supporting canonicalization.