Lowering softmax (#14)

* Rebase

* Use max normalization

* Handle axis

* Add tests

* Update SharingWork.md

* Remove redundant spaces

* Format code

* Rebase

* Change from the use of Value* to Value

* Add end-to-end tests

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Tung D. Le 2020-01-21 11:57:32 +09:00 committed by Tian Jin
parent 0aaab0d2d2
commit e89e51699b
8 changed files with 288 additions and 4 deletions

View File

@ -27,6 +27,7 @@ ONNX operations for which some work is needed.
| Selu | Tung | v | v | | | Selu | Tung | v | v | |
| Sigmoid | Tung | v | v | | | Sigmoid | Tung | v | v | |
| Sinh | Tung | v | v | | | Sinh | Tung | v | v | |
| Softmax | Tung | v | v | |
| Sub | Tung | v | v | M | | Sub | Tung | v | v | M |
| Sum | Tung | v | v | M | | Sum | Tung | v | v | M |
| Tanh | Tung | v | v | | | Tanh | Tung | v | v | |

View File

@ -267,7 +267,7 @@ def gen_schema(schema) :
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
'Identity', 'Cos', 'Log', 'Transpose'] 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
CanonicalList=['Add', 'Identity'] CanonicalList=['Add', 'Identity']
line_indent = ' ' line_indent = ' '

View File

@ -158,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
} }
//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the

View File

@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
} }
def ONNXSoftmaxOp:ONNX_Op<"Softmax", def ONNXSoftmaxOp:ONNX_Op<"Softmax",
[NoSideEffect]> { [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
let summary = "ONNX Softmax operation"; let summary = "ONNX Softmax operation";
let description = [{ let description = [{
"The operator computes the softmax (normalized exponential) values for each layer in the batch" "The operator computes the softmax (normalized exponential) values for each layer in the batch"

View File

@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
} }
}; };
struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// softmax(x) = let max_x = max(x) in
// let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in
// exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank();
int64_t axis = op->getAttrOfType<IntegerAttr>("Softmax.axis").getInt();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1);
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType();
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
operands[0]);
// Shape of the result
auto memRefShape = memRefType.getShape();
// Insert allocations and deallocations for sum and max.
MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero =
rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
Value negInfinity = rewriter.create<ConstantOp>(
loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
// Define loops.
auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
std::vector<Value> originalLoops;
originalLoops.reserve(rank);
for (auto result : loopsOp.getResults()) {
originalLoops.push_back(result);
}
// Define loop optimization.
auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
std::vector<Value> optimizedLoops;
optimizedLoops.reserve(rank);
for (auto result : optimizedLoopsOp.getResults()) {
optimizedLoops.push_back(result);
}
Block &optimizationBlock = optimizedLoopsOp.region().front();
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
// This coercing follows the softmax definition in ONNX:
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
// Here, we create an outer loop and inner loop for handling the two
// dimensions. The outer loop is only created once `axis` is not zero.
// Define an outer loop with respect to axis.
std::vector<Value> outerLoops, optimizedOuterLoops;
outerLoops.reserve(axis);
optimizedOuterLoops.reserve(axis);
for (int i = 0; i < axis; ++i) {
outerLoops.push_back(originalLoops[i]);
optimizedOuterLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < axis; ++i) {
if (memRefShape[i] < 0) {
outerPack.pushConstantBound(0);
outerPack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
outerPack.pushConstantBound(0);
outerPack.pushConstantBound(memRefShape[i]);
}
}
// Define an inner loop with respect to axis.
std::vector<Value> innerLoops, optimizedInnerLoops;
innerLoops.reserve(rank - axis);
optimizedInnerLoops.reserve(rank - axis);
for (int i = axis; i < rank; ++i) {
innerLoops.push_back(originalLoops[i]);
optimizedInnerLoops.push_back(optimizedLoops[i]);
}
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
for (int i = axis; i < rank; ++i) {
if (memRefShape[i] < 0) {
innerPack.pushConstantBound(0);
innerPack.pushOperandBound(
rewriter.create<DimOp>(loc, operands[0], i).getResult());
} else {
innerPack.pushConstantBound(0);
innerPack.pushConstantBound(memRefShape[i]);
}
}
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
SmallVector<Value, 4> outerLoopIVs;
if (axis != 0) {
outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
// No optimization
rewriter.setInsertionPointToEnd(&optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
// Insert instructions inside the outer loop.
Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&outerIterationBlock);
for (auto arg : outerIterationBlock.getArguments())
outerLoopIVs.push_back(arg);
// Reset accumulators.
rewriter.create<StoreOp>(loc, zero, sumOp);
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
// Create an inner loop to compute max.
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute sum.
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute softmax.
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
} else {
// Reset accumulators.
rewriter.create<StoreOp>(loc, zero, sumOp);
rewriter.create<StoreOp>(loc, negInfinity, maxOp);
// Create an inner loop to compute max.
maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute sum.
sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// Create an inner loop to compute softmax.
softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
// No optimization
rewriter.setInsertionPointToEnd(&optimizationBlock);
rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
rewriter.setInsertionPoint(optimizedLoopsOp);
}
// Insert instructions inside the max loop.
Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&maxIterationBlock);
// Get induction variables.
SmallVector<Value, 4> maxLoopIVs;
for (auto arg : outerLoopIVs)
maxLoopIVs.push_back(arg);
for (auto arg : maxIterationBlock.getArguments())
maxLoopIVs.push_back(arg);
// Compute the max value.
Value max = rewriter.create<LoadOp>(loc, maxOp);
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
auto maxCond =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
rewriter.create<StoreOp>(loc, max, maxOp);
// Get the max.
rewriter.setInsertionPoint(sumIterateOp);
max = rewriter.create<LoadOp>(loc, maxOp);
// Insert instructions inside the sum loop.
Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&sumIterationBlock);
// Get induction variables.
SmallVector<Value, 4> sumLoopIVs;
for (auto arg : outerLoopIVs)
sumLoopIVs.push_back(arg);
for (auto arg : sumIterationBlock.getArguments())
sumLoopIVs.push_back(arg);
// Sum up values.
Value sum = rewriter.create<LoadOp>(loc, sumOp);
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
Value sub = rewriter.create<SubFOp>(loc, next, max);
Value exp = rewriter.create<ExpOp>(loc, sub);
sum = rewriter.create<AddFOp>(loc, sum, exp);
rewriter.create<StoreOp>(loc, sum, sumOp);
// Store intermediate values in the result to avoid recomputation.
rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
// Get the sum.
rewriter.setInsertionPoint(softmaxIterateOp);
sum = rewriter.create<LoadOp>(loc, sumOp);
// Insert instructions inside the softmax loop.
Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
rewriter.setInsertionPointToStart(&softmaxIterationBlock);
// Get induction variables.
SmallVector<Value, 4> softmaxLoopIVs;
for (auto arg : outerLoopIVs)
softmaxLoopIVs.push_back(arg);
for (auto arg : softmaxIterationBlock.getArguments())
softmaxLoopIVs.push_back(arg);
// Compute softmax.
Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}
};
struct ONNXReshapeOpLowering : public ConversionPattern { struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx) ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>, ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext()); ONNXReshapeOpLowering, ONNXEntryPointLowering,
ONNXSoftmaxOpLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the // With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal` // conversion. The conversion will signal failure if any of our `illegal`

View File

@ -116,7 +116,8 @@ public:
op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.Gemm" &&
op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.GemmNoBias" &&
op->getName().getStringRef() != "onnx.Reshape" && op->getName().getStringRef() != "onnx.Reshape" &&
op->getName().getStringRef() != "onnx.Transpose") op->getName().getStringRef() != "onnx.Transpose" &&
op->getName().getStringRef() != "onnx.Softmax")
return false; return false;
return llvm::any_of(op->getResultTypes(), [](Type result_type) { return llvm::any_of(op->getResultTypes(), [](Type result_type) {
return !result_type.isa<RankedTensorType>(); return !result_type.isa<RankedTensorType>();

View File

@ -130,6 +130,14 @@ test_to_enable = [
"test_sigmoid_cpu", "test_sigmoid_cpu",
"test_sigmoid_example_cpu", "test_sigmoid_example_cpu",
# Softmax Op:
"test_softmax_axis_0_cpu",
"test_softmax_axis_1_cpu",
"test_softmax_axis_2_cpu",
"test_softmax_default_axis_cpu",
"test_softmax_example_cpu",
"test_softmax_large_number_cpu",
# Sum Op: # Sum Op:
#"test_sum_example_cpu", <- error #"test_sum_example_cpu", <- error
"test_sum_one_input_cpu", "test_sum_one_input_cpu",

View File

@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>
// CHECK: } // CHECK: }
// CHECK: return [[RES]] : memref<?x10xf32> // CHECK: return [[RES]] : memref<?x10xf32>
} }
func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_softmax
// CHECK: [[MAX:%.+]] = alloc() : memref<f32>
// CHECK: [[SUM:%.+]] = alloc() : memref<f32>
// CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
// CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
// CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, %3#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) {
// CHECK: store [[CST]], [[SUM]][] : memref<f32>
// CHECK: store [[CST_0]], [[MAX]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref<f32>
// CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[SELECT]], [[MAX]][] : memref<f32>
// CHECK: }
// CHECK: %5 = load [[MAX]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD1]] = load [[SUM]][] : memref<f32>
// CHECK: [[LOAD2]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[SUB:%.+]] = subf [[LOAD2]], %5 : f32
// CHECK: [[EXP:%.+]] = exp [[SUB]] : f32
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32
// CHECK: store [[ADD]], [[SUM]][] : memref<f32>
// CHECK: store %10, [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: }
// CHECK: %6 = load [[SUM]][] : memref<f32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD1]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: [[DIV:%.+]] = divf [[LOAD1]], %6 : f32
// CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
// CHECK: }
// CHECK: }
// CHECK: dealloc [[SUM]] : memref<f32>
// CHECK: dealloc [[MAX]] : memref<f32>
// CHECK: return [[RES]] : memref<10x10xf32>
}