Lowering softmax (#14)
* Rebase
* Use max normalization
* Handle axis
* Add tests
* Update SharingWork.md
* Remove redundant spaces
* Format code
* Rebase
* Change from the use of Value* to Value
* Add end-to-end tests

Co-authored-by: Tian Jin <tjingrant@gmail.com>
parent 0aaab0d2d2
commit e89e51699b
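A brief note on the "max normalization" mentioned in the commit message: since softmax(x) = softmax(x - c) for any constant c, subtracting the per-slice maximum before exponentiation leaves the result unchanged while keeping exp() from overflowing on large inputs. A minimal C++ sketch of the idea (illustrative only, not code from this patch; the function name is made up):

// Minimal sketch of max-normalized softmax over one vector. The lowering in
// this patch emits the same max / exp-and-sum / divide steps as Krnl loops.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> softmax1D(const std::vector<float> &x) {
  float maxVal = -std::numeric_limits<float>::infinity();
  for (float v : x)                         // pass 1: running maximum
    maxVal = std::max(maxVal, v);
  std::vector<float> e(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {   // pass 2: exp(x - max) and its sum
    e[i] = std::exp(x[i] - maxVal);
    sum += e[i];
  }
  for (float &v : e)                        // pass 3: normalize
    v /= sum;
  return e;
}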
@@ -27,6 +27,7 @@ ONNX operations for which some work is needed.
 | Selu | Tung | v | v | |
 | Sigmoid | Tung | v | v | |
 | Sinh | Tung | v | v | |
+| Softmax | Tung | v | v | |
 | Sub | Tung | v | v | M |
 | Sum | Tung | v | v | M |
 | Tanh | Tung | v | v | |
@@ -267,7 +267,7 @@ def gen_schema(schema) :
         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
         'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
         'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-        'Identity', 'Cos', 'Log', 'Transpose']
+        'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
     CanonicalList=['Add', 'Identity']
     line_indent = ' '
 
@@ -158,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Softmax
+/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
+/// the shape inference interface.
+void ONNXSoftmaxOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
@@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
 }
 
 def ONNXSoftmaxOp:ONNX_Op<"Softmax",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Softmax operation";
   let description = [{
     "The operator computes the softmax (normalized exponential) values for each layer in the batch"
@@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXSoftmaxOpLowering : public ConversionPattern {
+  ONNXSoftmaxOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // softmax(x) = let max_x = max(x) in
+    //                let exp_x = exp(x - max_x) in
+    //                  let sum = sum(exp_x) in
+    //                    exp_x / sum
+    auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
+    int64_t rank = tensorType.getRank();
+    int64_t axis = op->getAttrOfType<IntegerAttr>("Softmax.axis").getInt();
+    axis = axis >= 0 ? axis : rank + axis;
+    assert(axis >= -rank && axis <= rank - 1);
+
+    auto loc = op->getLoc();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    auto elementType = memRefType.getElementType();
+
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    operands[0]);
+
+    // Shape of the result
+    auto memRefShape = memRefType.getShape();
+
+    // Insert allocations and deallocations for sum and max.
+    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
+    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
+    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
+    Value zero =
+        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+    Value negInfinity = rewriter.create<ConstantOp>(
+        loc,
+        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
+    std::vector<Value> originalLoops;
+    originalLoops.reserve(rank);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    // Define loop optimization.
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
+    std::vector<Value> optimizedLoops;
+    optimizedLoops.reserve(rank);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
+
+    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
+    // This coercing follows the softmax definition in ONNX:
+    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
+    // Here, we create an outer loop and inner loop for handling the two
+    // dimensions. The outer loop is only created once `axis` is not zero.
+
+    // Define an outer loop with respect to axis.
+    std::vector<Value> outerLoops, optimizedOuterLoops;
+    outerLoops.reserve(axis);
+    optimizedOuterLoops.reserve(axis);
+    for (int i = 0; i < axis; ++i) {
+      outerLoops.push_back(originalLoops[i]);
+      optimizedOuterLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
+    for (int i = 0; i < axis; ++i) {
+      if (memRefShape[i] < 0) {
+        outerPack.pushConstantBound(0);
+        outerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+      } else {
+        outerPack.pushConstantBound(0);
+        outerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+    // Define an inner loop with respect to axis.
+    std::vector<Value> innerLoops, optimizedInnerLoops;
+    innerLoops.reserve(rank - axis);
+    optimizedInnerLoops.reserve(rank - axis);
+    for (int i = axis; i < rank; ++i) {
+      innerLoops.push_back(originalLoops[i]);
+      optimizedInnerLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
+    for (int i = axis; i < rank; ++i) {
+      if (memRefShape[i] < 0) {
+        innerPack.pushConstantBound(0);
+        innerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+      } else {
+        innerPack.pushConstantBound(0);
+        innerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+
+    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
+    SmallVector<Value, 4> outerLoopIVs;
+    if (axis != 0) {
+      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
+
+      // No optimization
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+      rewriter.setInsertionPoint(optimizedLoopsOp);
+
+      // Insert instructions inside the outer loop.
+      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
+      rewriter.setInsertionPointToStart(&outerIterationBlock);
+      for (auto arg : outerIterationBlock.getArguments())
+        outerLoopIVs.push_back(arg);
+
+      // Reset accumulators.
+      rewriter.create<StoreOp>(loc, zero, sumOp);
+      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
+
+      // Create an inner loop to compute max.
+      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute sum.
+      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute softmax.
+      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+    } else {
+      // Reset accumulators.
+      rewriter.create<StoreOp>(loc, zero, sumOp);
+      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
+
+      // Create an inner loop to compute max.
+      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute sum.
+      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute softmax.
+      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+
+      // No optimization
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+      rewriter.setInsertionPoint(optimizedLoopsOp);
+    }
+
+    // Insert instructions inside the max loop.
+    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&maxIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> maxLoopIVs;
+    for (auto arg : outerLoopIVs)
+      maxLoopIVs.push_back(arg);
+    for (auto arg : maxIterationBlock.getArguments())
+      maxLoopIVs.push_back(arg);
+
+    // Compute the max value.
+    Value max = rewriter.create<LoadOp>(loc, maxOp);
+    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
+    auto maxCond =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
+    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
+    rewriter.create<StoreOp>(loc, max, maxOp);
+
+    // Get the max.
+    rewriter.setInsertionPoint(sumIterateOp);
+    max = rewriter.create<LoadOp>(loc, maxOp);
+
+    // Insert instructions inside the sum loop.
+    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&sumIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> sumLoopIVs;
+    for (auto arg : outerLoopIVs)
+      sumLoopIVs.push_back(arg);
+    for (auto arg : sumIterationBlock.getArguments())
+      sumLoopIVs.push_back(arg);
+
+    // Sum up values.
+    Value sum = rewriter.create<LoadOp>(loc, sumOp);
+    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
+    Value sub = rewriter.create<SubFOp>(loc, next, max);
+    Value exp = rewriter.create<ExpOp>(loc, sub);
+    sum = rewriter.create<AddFOp>(loc, sum, exp);
+    rewriter.create<StoreOp>(loc, sum, sumOp);
+    // Store intermediate values in the result to avoid recomputation.
+    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
+
+    // Get the sum.
+    rewriter.setInsertionPoint(softmaxIterateOp);
+    sum = rewriter.create<LoadOp>(loc, sumOp);
+
+    // Insert instructions inside the softmax loop.
+    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&softmaxIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> softmaxLoopIVs;
+    for (auto arg : outerLoopIVs)
+      softmaxLoopIVs.push_back(arg);
+    for (auto arg : softmaxIterationBlock.getArguments())
+      softmaxLoopIVs.push_back(arg);
+
+    // Compute softmax.
+    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
+    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
+    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 struct ONNXReshapeOpLowering : public ConversionPattern {
   ONNXReshapeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
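For reference, the loop nest emitted by the pattern above corresponds to the following plain C++ sketch (illustrative only, not code from this patch; the function name and the flat row-major buffer layout are assumptions). As described in the coercion comment, the input of shape d0 x ... x d(r-1) is viewed as a 2-D tensor whose outer extent is the product of the dimensions before `axis` and whose inner extent is the product of the remaining dimensions; each outer slice then gets the max, exp-and-sum, and divide passes.

// Hypothetical reference implementation of the coerced 2-D softmax that the
// generated Krnl loops compute.
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

void softmaxRef(const float *x, float *y, const std::vector<int64_t> &shape,
                int64_t axis) {
  int64_t rank = static_cast<int64_t>(shape.size());
  axis = axis >= 0 ? axis : rank + axis; // same normalization as the lowering
  int64_t outer = 1, inner = 1;
  for (int64_t i = 0; i < axis; ++i)
    outer *= shape[i];                   // dims before `axis` -> outer loop
  for (int64_t i = axis; i < rank; ++i)
    inner *= shape[i];                   // dims from `axis` on -> inner loops
  for (int64_t o = 0; o < outer; ++o) {
    const float *in = x + o * inner;
    float *out = y + o * inner;
    float maxVal = -std::numeric_limits<float>::infinity();
    for (int64_t i = 0; i < inner; ++i)   // inner loop 1: max
      maxVal = in[i] > maxVal ? in[i] : maxVal;
    float sum = 0.0f;
    for (int64_t i = 0; i < inner; ++i) { // inner loop 2: exp and sum
      out[i] = std::exp(in[i] - maxVal);  // intermediate kept in the result
      sum += out[i];
    }
    for (int64_t i = 0; i < inner; ++i)   // inner loop 3: divide
      out[i] /= sum;
  }
}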
@@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
                   ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
-                  ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
+                  ONNXReshapeOpLowering, ONNXEntryPointLowering,
+                  ONNXSoftmaxOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -116,7 +116,8 @@ public:
         op->getName().getStringRef() != "onnx.Gemm" &&
         op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
-        op->getName().getStringRef() != "onnx.Transpose")
+        op->getName().getStringRef() != "onnx.Transpose" &&
+        op->getName().getStringRef() != "onnx.Softmax")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>();
@@ -130,6 +130,14 @@ test_to_enable = [
     "test_sigmoid_cpu",
     "test_sigmoid_example_cpu",
 
+    # Softmax Op:
+    "test_softmax_axis_0_cpu",
+    "test_softmax_axis_1_cpu",
+    "test_softmax_axis_2_cpu",
+    "test_softmax_default_axis_cpu",
+    "test_softmax_example_cpu",
+    "test_softmax_large_number_cpu",
+
     # Sum Op:
     #"test_sum_example_cpu", <- error
     "test_sum_one_input_cpu",
@@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
   // CHECK: }
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
+
+func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_softmax
+  // CHECK: [[MAX:%.+]] = alloc() : memref<f32>
+  // CHECK: [[SUM:%.+]] = alloc() : memref<f32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, %3#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) {
+  // CHECK: store [[CST]], [[SUM]][] : memref<f32>
+  // CHECK: store [[CST_0]], [[MAX]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[SELECT]], [[MAX]][] : memref<f32>
+  // CHECK: }
+  // CHECK: %5 = load [[MAX]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1]] = load [[SUM]][] : memref<f32>
+  // CHECK: [[LOAD2]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[SUB:%.+]] = subf [[LOAD2]], %5 : f32
+  // CHECK: [[EXP:%.+]] = exp [[SUB]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32
+  // CHECK: store [[ADD]], [[SUM]][] : memref<f32>
+  // CHECK: store %10, [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: %6 = load [[SUM]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[DIV:%.+]] = divf [[LOAD1]], %6 : f32
+  // CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: dealloc [[SUM]] : memref<f32>
+  // CHECK: dealloc [[MAX]] : memref<f32>
+  // CHECK: return [[RES]] : memref<10x10xf32>
+}