//===--------------- RNNBase.cpp - Lowering RNN Ops -----------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines base functions for lowering the ONNX RNN Operators.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"

using namespace mlir;

// Check whether a Value's type is NoneType.
bool isNoneType(Value val) { return val.getType().isa<NoneType>(); }
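
// For the ONNX RNN operators, absent optional inputs (e.g. sequence_lens or
// the initial hidden state) arrive as values of NoneType. A usage sketch,
// where `operandAdaptor.sequence_lens()` is an illustrative accessor not
// defined in this file:
//   if (isNoneType(operandAdaptor.sequence_lens())) { /* skip masking */ }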

// Get a dimension of the tensor's shape at a given index.
int64_t dimAt(Value val, int index) {
  return val.getType().cast<ShapedType>().getShape()[index];
}
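
// Usage sketch with hypothetical values: for an input X of type
// tensor<7x2x4xf32>, dimAt(X, 0) == 7 and dimAt(X, 2) == 4. A dynamic
// dimension comes back as ShapedType's dynamic-size sentinel (-1 in this
// MLIR version), so callers should check for it before relying on the value.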

// Apply an activation function on a given scalar operand.
Value applyActivation(ConversionPatternRewriter &rewriter, Location loc,
    RNNActivation activation, Value scalarOperand) {
  Value res;

  // Stage the scalar in a rank-0 MemRef so the ONNX activation op can be
  // applied to a memref value.
  MemRefType scalarMemRefType =
      MemRefType::get({}, scalarOperand.getType(), {}, 0);
  Value alloc = rewriter.create<AllocOp>(loc, scalarMemRefType);
  rewriter.create<StoreOp>(loc, scalarOperand, alloc);
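
  // The staging conceptually emits (sketch, assuming an f32 scalar):
  //   %0 = alloc() : memref<f32>
  //   store %scalarOperand, %0[] : memref<f32>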

  // Forward the optional alpha/beta attributes to the activations that take
  // them (LeakyRelu, ThresholdedRelu, HardSigmoid, and Elu below).
  std::vector<mlir::NamedAttribute> attributes;
  if (activation.alpha) {
    attributes.emplace_back(
        rewriter.getNamedAttr("alpha", activation.alpha.getValue()));
  }
  if (activation.beta) {
    attributes.emplace_back(
        rewriter.getNamedAttr("beta", activation.beta.getValue()));
  }
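
  // For example (sketch), with activation.alpha holding a FloatAttr of 0.01,
  // the LeakyRelu case below produces roughly:
  //   %1 = "onnx.LeakyRelu"(%0) {alpha = 1.000000e-02 : f32}
  //        : (memref<f32>) -> memref<f32>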

  // Dispatch on the (case-insensitive) activation name.
  if (activation.name.equals_lower("relu"))
    res = rewriter.create<ONNXReluOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("tanh"))
    res = rewriter.create<ONNXTanhOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("sigmoid"))
    res = rewriter.create<ONNXSigmoidOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("affine")) {
    // Not lowered yet; abort after the diagnostic instead of falling through
    // with a null `res`, which would crash the load below.
    emitError(loc, "Unsupported activation");
    llvm_unreachable("Unsupported activation");
  } else if (activation.name.equals_lower("leakyrelu"))
    res = rewriter.create<ONNXLeakyReluOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("thresholdedrelu"))
    res = rewriter.create<ONNXThresholdedReluOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("scaledtanh")) {
    // Not lowered yet either; see the affine case above.
    emitError(loc, "Unsupported activation");
    llvm_unreachable("Unsupported activation");
  } else if (activation.name.equals_lower("hardsigmoid"))
    res = rewriter.create<ONNXHardSigmoidOp>(
        loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("elu"))
    res = rewriter.create<ONNXEluOp>(loc, scalarMemRefType, alloc, attributes);
  else if (activation.name.equals_lower("softsign"))
    res = rewriter.create<ONNXSoftsignOp>(loc, scalarMemRefType, alloc);
  else if (activation.name.equals_lower("softplus"))
    res = rewriter.create<ONNXSoftplusOp>(loc, scalarMemRefType, alloc);
  else
    llvm_unreachable("Unsupported activation");

  // Load the scalar result back out of the rank-0 MemRef.
  Value result = rewriter.create<LoadOp>(loc, res);
  return result;
}
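
// A minimal caller sketch (hypothetical names; `hiddenVal` is an f32 scalar
// Value, and RNNActivation is the struct declared in RNNBase.hpp):
//
//   RNNActivation act;
//   act.name = "tanh";
//   Value next = applyActivation(rewriter, loc, act, hiddenVal);
//
// This stores the scalar into a fresh memref<f32>, emits onnx.Tanh on it,
// and loads the activated scalar back.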