//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file defines base functions for lowerng the ONNX RNN Operators. // //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp" using namespace mlir; // Check a Value's type is none or not. bool isNoneType(Value val) { return val.getType().isa(); } // Get a dimension of the tensor's shape. int64_t dimAt(Value val, int index) { return val.getType().cast().getShape()[index]; } // Apply an activation function on a given scalar operand. Value applyActivation(ConversionPatternRewriter &rewriter, Location loc, RNNActivation activation, Value scalarOperand) { Value res; MemRefType scalarMemRefType = MemRefType::get({}, scalarOperand.getType(), {}, 0); Value alloc = rewriter.create(loc, scalarMemRefType); rewriter.create(loc, scalarOperand, alloc, ArrayRef{}); std::vector attributes; if (activation.alpha) { attributes.emplace_back( rewriter.getNamedAttr("alpha", activation.alpha.getValue())); } if (activation.beta) { attributes.emplace_back( rewriter.getNamedAttr("beta", activation.beta.getValue())); } if (activation.name.equals_lower("relu")) res = rewriter.create(loc, scalarMemRefType, alloc); else if (activation.name.equals_lower("tanh")) res = rewriter.create(loc, scalarMemRefType, alloc); else if (activation.name.equals_lower("sigmoid")) res = rewriter.create(loc, scalarMemRefType, alloc); else if (activation.name.equals_lower("affine")) llvm_unreachable("Unsupported activation"); else if (activation.name.equals_lower("leakyrelu")) res = rewriter.create( loc, scalarMemRefType, alloc, attributes); else if (activation.name.equals_lower("thresholdedrelu")) res = rewriter.create( loc, scalarMemRefType, alloc, attributes); else if (activation.name.equals_lower("scaledtanh")) llvm_unreachable("Unsupported activation"); else if (activation.name.equals_lower("hardsigmoid")) res = rewriter.create( loc, scalarMemRefType, alloc, attributes); else if (activation.name.equals_lower("elu")) res = rewriter.create(loc, scalarMemRefType, alloc, attributes); else if (activation.name.equals_lower("softsign")) res = rewriter.create(loc, scalarMemRefType, alloc); else if (activation.name.equals_lower("softplus")) res = rewriter.create(loc, scalarMemRefType, alloc); else llvm_unreachable("Unsupported activation"); Value result = rewriter.create(loc, res); return result; }