//===----------------- Softmax.cpp - Softmax Op ---------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file lowers the ONNX Softmax operator to the Krnl dialect.
//
//===----------------------------------------------------------------------===//

#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"

using namespace mlir;

struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
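  // Note: the `1` above is the pattern benefit, i.e., the priority the
  // dialect conversion driver uses when several patterns match the same op.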
  LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const final {
    // softmax(x) = let max_x = max(x) in
    //              let exp_x = exp(x - max_x) in
    //              let sum = sum(exp_x) in
    //              exp_x / sum
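    // Subtracting max_x first is the standard numerically stable formulation:
    // e.g., for x = [1000, 1001], exp(1000) overflows in floating point,
    // whereas exp(x - 1001) = [exp(-1), exp(0)] is well-behaved and yields
    // the same result after normalization.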
    auto memRefType = convertToMemRefType(*op->result_type_begin());
    int64_t rank = memRefType.getRank();
    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    axis = axis >= 0 ? axis : rank + axis;
    assert(axis >= 0 && axis < rank && "Softmax axis out of bounds");
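    // A negative ONNX axis counts from the back; e.g., for a rank-3 input,
    // axis == -1 normalizes to 2 and axis == -3 normalizes to 0.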

    auto loc = op->getLoc();
    ONNXSoftmaxOpAdaptor operandAdaptor(operands);
    Value input = operandAdaptor.input();
    // Insert an allocation and deallocation for the result of this operation.
    auto elementType = memRefType.getElementType();

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(
          memRefType, loc, rewriter, insertDealloc, input);
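    // In the dynamic-shape branch above, `input` is passed so the helper can
    // derive the unknown dimensions of the result from the input memref.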

    // Insert allocations and deallocations for sum and max.
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
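    // These rank-0 (scalar) memrefs hold the running max and running sum for
    // one reduction slice; they are reinitialized before each slice.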
    Value zero = emitConstantOp(rewriter, loc, elementType, 0);
    Value negInfinity = rewriter.create<ConstantOp>(loc,
        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
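    // -infinity is the identity of the max reduction. Note this assumes a
    // floating-point element type, consistent with the ONNX Softmax spec.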

    // Define loops.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock =
        defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
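    // defineLoops emits the krnl.define_loops/krnl.optimize_loops pair;
    // `optimizationBlock` is where a schedule (tiling, permutation, ...)
    // could be expressed. Sketch of the shape for rank 2 (illustrative only;
    // exact syntax depends on the Krnl dialect version):
    //   %i0, %i1 = krnl.define_loops 2
    //   %o0, %o1 = krnl.optimize_loops { ... krnl.return_loops %i0, %i1 }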

    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
    // This coercing follows the softmax definition in ONNX:
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
    // Here, we create an outer loop and an inner loop for handling the two
    // dimensions. The outer loop is only created when `axis` is not zero.
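    // For example, an input of shape 3x4x5 with axis = 1 is treated as a
    // 3x20 matrix: softmax is computed independently over each row of 20
    // elements (the flattened trailing dimensions).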

    // Define an outer loop with respect to axis.
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(axis);
    optimizedOuterLoops.reserve(axis);
    for (int i = 0; i < axis; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
    for (int i = 0; i < axis; ++i)
      addDimensionToPack(rewriter, loc, outerPack, input, i);
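    // addDimensionToPack appends the bounds [0, dim(input, i)) for loop i,
    // materializing a DimOp when that dimension is dynamic.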

    // Define an inner loop with respect to axis.
    std::vector<Value> innerLoops, optimizedInnerLoops;
    innerLoops.reserve(rank - axis);
    optimizedInnerLoops.reserve(rank - axis);
    for (int i = axis; i < rank; ++i) {
      innerLoops.push_back(originalLoops[i]);
      optimizedInnerLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
    for (int i = axis; i < rank; ++i)
      addDimensionToPack(rewriter, loc, innerPack, input, i);

    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
    SmallVector<Value, 4> outerLoopIVs;
    if (axis != 0) {
      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
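      // Returning the loops unchanged selects the identity schedule: no
      // tiling, unrolling, or loop permutation is applied here.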

      // Insert instructions inside the outer loop.
      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&outerIterationBlock);
      for (auto arg : outerIterationBlock.getArguments())
        outerLoopIVs.push_back(arg);

      // Reset accumulators.
      rewriter.create<AffineStoreOp>(loc, zero, sumOp, ArrayRef<Value>{});
      rewriter.create<AffineStoreOp>(
          loc, negInfinity, maxOp, ArrayRef<Value>{});
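      // The reset happens inside the outer loop so that every outer iteration
      // computes an independent softmax over its inner-dimension slice.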

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
    } else {
      // Reset accumulators.
      rewriter.create<AffineStoreOp>(loc, zero, sumOp, ArrayRef<Value>{});
      rewriter.create<AffineStoreOp>(
          loc, negInfinity, maxOp, ArrayRef<Value>{});

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);

      // No optimization
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    }

    // Insert instructions inside the max loop.
    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&maxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> maxLoopIVs;
    for (auto arg : outerLoopIVs)
      maxLoopIVs.push_back(arg);
    for (auto arg : maxIterationBlock.getArguments())
      maxLoopIVs.push_back(arg);

    // Compute the max value.
    Value max = rewriter.create<AffineLoadOp>(loc, maxOp);
    Value nextMax = rewriter.create<AffineLoadOp>(loc, input, maxLoopIVs);
    auto maxCond =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
    rewriter.create<AffineStoreOp>(loc, max, maxOp, ArrayRef<Value>{});
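    // This is the classic load/compare/select/store running-max idiom. The
    // generated IR is roughly (illustrative only; exact syntax varies by
    // MLIR version):
    //   %m = affine.load %maxOp[]
    //   %x = affine.load %input[%i0, %i1, ...]
    //   %c = cmpf "ogt", %m, %x
    //   %r = select %c, %m, %x
    //   affine.store %r, %maxOp[]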

    // Get the max.
    rewriter.setInsertionPoint(sumIterateOp);
    max = rewriter.create<AffineLoadOp>(loc, maxOp);

    // Insert instructions inside the sum loop.
    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&sumIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> sumLoopIVs;
    for (auto arg : outerLoopIVs)
      sumLoopIVs.push_back(arg);
    for (auto arg : sumIterationBlock.getArguments())
      sumLoopIVs.push_back(arg);

    // Sum up values.
    Value sum = rewriter.create<AffineLoadOp>(loc, sumOp);
    Value next = rewriter.create<AffineLoadOp>(loc, input, sumLoopIVs);
    Value sub = rewriter.create<SubFOp>(loc, next, max);
    Value exp = rewriter.create<ExpOp>(loc, sub);
    sum = rewriter.create<AddFOp>(loc, sum, exp);
    rewriter.create<AffineStoreOp>(loc, sum, sumOp, ArrayRef<Value>{});
    // Store intermediate values in the result to avoid recomputation.
    rewriter.create<AffineStoreOp>(loc, exp, alloc, sumLoopIVs);
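    // The result buffer doubles as scratch space: it holds exp(x - max) after
    // this loop, and the softmax loop below rescales it in place by 1/sum.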

    // Get the sum.
    rewriter.setInsertionPoint(softmaxIterateOp);
    sum = rewriter.create<AffineLoadOp>(loc, sumOp);

    // Insert instructions inside the softmax loop.
    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&softmaxIterationBlock);

    // Get induction variables.
    SmallVector<Value, 4> softmaxLoopIVs;
    for (auto arg : outerLoopIVs)
      softmaxLoopIVs.push_back(arg);
    for (auto arg : softmaxIterationBlock.getArguments())
      softmaxLoopIVs.push_back(arg);

    // Compute softmax.
    Value expLoadedVal =
        rewriter.create<AffineLoadOp>(loc, alloc, softmaxLoopIVs);
    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
    rewriter.create<AffineStoreOp>(loc, result, alloc, softmaxLoopIVs);

    rewriter.replaceOp(op, alloc);

    return success();
  }
};

void populateLoweringONNXSoftmaxOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXSoftmaxOpLowering>(ctx);
}
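
// Example (illustrative sketch, not part of this file's API): a conversion
// pass would register this pattern together with the other ONNX-to-Krnl
// patterns, e.g.:
//
//   OwningRewritePatternList patterns;
//   populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
//   // ... add remaining patterns, then run applyPartialConversion(...).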