From e89e51699bac092899c5c4121a9c442bb13e2a1c Mon Sep 17 00:00:00 2001
From: "Tung D. Le"
Date: Tue, 21 Jan 2020 11:57:32 +0900
Subject: [PATCH] Lowering softmax (#14)

* Rebase

* Use max normalization

* Handle axis

* Add tests

* Update SharingWork.md

* Remove redundant spaces

* Format code

* Rebase

* Change from the use of Value* to Value

* Add end-to-end tests

Co-authored-by: Tian Jin
---
 SharingWork.md                      |   1 +
 src/dialect/onnx/gen_doc.py         |   2 +-
 src/dialect/onnx/onnx_ops.cpp       |   8 +
 src/dialect/onnx/onnxop.inc         |   2 +-
 src/pass/lower_frontend_to_krnl.cpp | 222 +++++++++++++++++++++++++++-
 src/pass/shape_inference_pass.cpp   |   3 +-
 test/backend/test.py                |   8 +
 test/mlir/onnx/onnx_lowering.mlir   |  46 ++++++
 8 files changed, 288 insertions(+), 4 deletions(-)

diff --git a/SharingWork.md b/SharingWork.md
index 6b6c063..fe43494 100644
--- a/SharingWork.md
+++ b/SharingWork.md
@@ -27,6 +27,7 @@ ONNX operations for which some work is needed.
 | Selu    | Tung | v | v |   |
 | Sigmoid | Tung | v | v |   |
 | Sinh    | Tung | v | v |   |
+| Softmax | Tung | v | v |   |
 | Sub     | Tung | v | v | M |
 | Sum     | Tung | v | v | M |
 | Tanh    | Tung | v | v |   |
diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py
index 6d986c2..4141556 100644
--- a/src/dialect/onnx/gen_doc.py
+++ b/src/dialect/onnx/gen_doc.py
@@ -267,7 +267,7 @@ def gen_schema(schema) :
         'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
         'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
         'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal',
-        'Identity', 'Cos', 'Log', 'Transpose']
+        'Identity', 'Cos', 'Log', 'Transpose', 'Softmax']
     CanonicalList=['Add', 'Identity']
 
     line_indent = '  '
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 7e1675d..53e463d 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -158,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() {
   getResult().setType(getOperand().getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Softmax
+/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
+/// the shape inference interface.
+void ONNXSoftmaxOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index fc2714e..e87a01a 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice",
 }
 
 def ONNXSoftmaxOp:ONNX_Op<"Softmax",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Softmax operation";
   let description = [{
     "The operator computes the softmax (normalized exponential) values for each layer in the batch"
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index a578479..3d899ee 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXSoftmaxOpLowering : public ConversionPattern {
+  ONNXSoftmaxOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // softmax(x) = let max_x = max(x) in
+    //                let exp_x = exp(x - max_x) in
+    //                  let sum = sum(exp_x) in
+    //                    exp_x / sum
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    int64_t rank = tensorType.getRank();
+    int64_t axis = op->getAttrOfType<IntegerAttr>("Softmax.axis").getInt();
+    axis = axis >= 0 ? axis : rank + axis;
+    assert(axis >= -rank && axis <= rank - 1);
+
+    auto loc = op->getLoc();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    auto elementType = memRefType.getElementType();
+
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    operands[0]);
+
+    // Shape of the result.
+    auto memRefShape = memRefType.getShape();
+
+    // Insert allocations and deallocations for sum and max.
+    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
+    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
+    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
+    Value zero =
+        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
+    Value negInfinity = rewriter.create<ConstantOp>(
+        loc,
+        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
+    std::vector<Value> originalLoops;
+    originalLoops.reserve(rank);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    // Define loop optimization.
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
+    std::vector<Value> optimizedLoops;
+    optimizedLoops.reserve(rank);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
+
+    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
+    // This coercing follows the softmax definition in ONNX:
+    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
+    // Here, we create an outer loop and an inner loop for handling the two
+    // dimensions. The outer loop is created only when `axis` is not zero.
+
+    // Define an outer loop with respect to axis.
+    std::vector<Value> outerLoops, optimizedOuterLoops;
+    outerLoops.reserve(axis);
+    optimizedOuterLoops.reserve(axis);
+    for (int i = 0; i < axis; ++i) {
+      outerLoops.push_back(originalLoops[i]);
+      optimizedOuterLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
+    for (int i = 0; i < axis; ++i) {
+      if (memRefShape[i] < 0) {
+        outerPack.pushConstantBound(0);
+        outerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+      } else {
+        outerPack.pushConstantBound(0);
+        outerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+    // Define an inner loop with respect to axis.
+    std::vector<Value> innerLoops, optimizedInnerLoops;
+    innerLoops.reserve(rank - axis);
+    optimizedInnerLoops.reserve(rank - axis);
+    for (int i = axis; i < rank; ++i) {
+      innerLoops.push_back(originalLoops[i]);
+      optimizedInnerLoops.push_back(optimizedLoops[i]);
+    }
+    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
+    for (int i = axis; i < rank; ++i) {
+      if (memRefShape[i] < 0) {
+        innerPack.pushConstantBound(0);
+        innerPack.pushOperandBound(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+      } else {
+        innerPack.pushConstantBound(0);
+        innerPack.pushConstantBound(memRefShape[i]);
+      }
+    }
+
+    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
+    SmallVector<Value, 4> outerLoopIVs;
+    if (axis != 0) {
+      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
+
+      // No optimization
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+      rewriter.setInsertionPoint(optimizedLoopsOp);
+
+      // Insert instructions inside the outer loop.
+      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
+      rewriter.setInsertionPointToStart(&outerIterationBlock);
+      for (auto arg : outerIterationBlock.getArguments())
+        outerLoopIVs.push_back(arg);
+
+      // Reset accumulators.
+      rewriter.create<StoreOp>(loc, zero, sumOp);
+      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
+
+      // Create an inner loop to compute max.
+      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute sum.
+      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute softmax.
+      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+    } else {
+      // Reset accumulators.
+      rewriter.create<StoreOp>(loc, zero, sumOp);
+      rewriter.create<StoreOp>(loc, negInfinity, maxOp);
+
+      // Create an inner loop to compute max.
+      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute sum.
+      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+      // Create an inner loop to compute softmax.
+      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
+
+      // No optimization
+      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+      rewriter.setInsertionPoint(optimizedLoopsOp);
+    }
+
+    // Insert instructions inside the max loop.
+    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&maxIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> maxLoopIVs;
+    for (auto arg : outerLoopIVs)
+      maxLoopIVs.push_back(arg);
+    for (auto arg : maxIterationBlock.getArguments())
+      maxLoopIVs.push_back(arg);
+
+    // Compute the max value.
+    Value max = rewriter.create<LoadOp>(loc, maxOp);
+    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
+    auto maxCond =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
+    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
+    rewriter.create<StoreOp>(loc, max, maxOp);
+
+    // Get the max.
+    rewriter.setInsertionPoint(sumIterateOp);
+    max = rewriter.create<LoadOp>(loc, maxOp);
+
+    // Insert instructions inside the sum loop.
+    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&sumIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> sumLoopIVs;
+    for (auto arg : outerLoopIVs)
+      sumLoopIVs.push_back(arg);
+    for (auto arg : sumIterationBlock.getArguments())
+      sumLoopIVs.push_back(arg);
+
+    // Sum up values.
+    Value sum = rewriter.create<LoadOp>(loc, sumOp);
+    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
+    Value sub = rewriter.create<SubFOp>(loc, next, max);
+    Value exp = rewriter.create<ExpOp>(loc, sub);
+    sum = rewriter.create<AddFOp>(loc, sum, exp);
+    rewriter.create<StoreOp>(loc, sum, sumOp);
+    // Store intermediate values in the result to avoid recomputation.
+    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);
+
+    // Get the sum.
+    rewriter.setInsertionPoint(softmaxIterateOp);
+    sum = rewriter.create<LoadOp>(loc, sumOp);
+
+    // Insert instructions inside the softmax loop.
+    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
+    rewriter.setInsertionPointToStart(&softmaxIterationBlock);
+
+    // Get induction variables.
+    SmallVector<Value, 4> softmaxLoopIVs;
+    for (auto arg : outerLoopIVs)
+      softmaxLoopIVs.push_back(arg);
+    for (auto arg : softmaxIterationBlock.getArguments())
+      softmaxLoopIVs.push_back(arg);
+
+    // Compute softmax.
+    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
+    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
+    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 struct ONNXReshapeOpLowering : public ConversionPattern {
   ONNXReshapeOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
@@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
         ONNXElementwiseVariadicOpLowering,
         ONNXElementwiseVariadicOpLowering,
         ONNXElementwiseVariadicOpLowering,
-        ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext());
+        ONNXReshapeOpLowering, ONNXEntryPointLowering,
+        ONNXSoftmaxOpLowering>(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index 5ccb9a4..3226f16 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -116,7 +116,8 @@ public:
         op->getName().getStringRef() != "onnx.Gemm" &&
         op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
-        op->getName().getStringRef() != "onnx.Transpose")
+        op->getName().getStringRef() != "onnx.Transpose" &&
+        op->getName().getStringRef() != "onnx.Softmax")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>();
diff --git a/test/backend/test.py b/test/backend/test.py
index a9072db..60ca4a8 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -130,6 +130,14 @@ test_to_enable = [
     "test_sigmoid_cpu",
     "test_sigmoid_example_cpu",
 
+    # Softmax Op:
+    "test_softmax_axis_0_cpu",
+    "test_softmax_axis_1_cpu",
+    "test_softmax_axis_2_cpu",
+    "test_softmax_default_axis_cpu",
+    "test_softmax_example_cpu",
+    "test_softmax_large_number_cpu",
+
     # Sum Op:
     #"test_sum_example_cpu", <- error
     "test_sum_one_input_cpu",
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index 123e6a1..3ffce9a 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor, %arg1 : tensor
   // CHECK: }
   // CHECK: return [[RES]] : memref
 }
+
+func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_softmax
+  // CHECK: [[MAX:%.+]] = alloc() : memref<f32>
+  // CHECK: [[SUM:%.+]] = alloc() : memref<f32>
+  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
+  // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, %3#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) {
+  // CHECK: store [[CST]], [[SUM]][] : memref<f32>
+  // CHECK: store [[CST_0]], [[MAX]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[SELECT]], [[MAX]][] : memref<f32>
+  // CHECK: }
+  // CHECK: %5 = load [[MAX]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[SUM]][] : memref<f32>
+  // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[SUB:%.+]] = subf [[LOAD2]], %5 : f32
+  // CHECK: [[EXP:%.+]] = exp [[SUB]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32
+  // CHECK: store [[ADD]], [[SUM]][] : memref<f32>
+  // CHECK: store %10, [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: %6 = load [[SUM]][] : memref<f32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: [[DIV:%.+]] = divf [[LOAD1]], %6 : f32
+  // CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: dealloc [[SUM]] : memref<f32>
+  // CHECK: dealloc [[MAX]] : memref<f32>
+  // CHECK: return [[RES]] : memref<10x10xf32>
+}
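
For reference, the loop nest emitted by ONNXSoftmaxOpLowering computes the max-normalized softmax sketched in the comment at the top of matchAndRewrite: subtract the per-slice maximum before exponentiating, accumulate the sum of exponentials, then divide. Below is a minimal NumPy sketch of that computation; the function name softmax_reference and the use of NumPy are illustrative assumptions, not code from this patch.

    # Reference only (not part of the patch): numerically stable softmax,
    # coerced to 2-D around `axis` the same way the lowering does.
    import numpy as np

    def softmax_reference(x, axis=1):
        # Collapse dims before `axis` into rows, dims from `axis` on into columns.
        x2d = x.reshape(int(np.prod(x.shape[:axis])), -1)
        max_x = np.max(x2d, axis=1, keepdims=True)    # per-row max (the "max" loop)
        exp_x = np.exp(x2d - max_x)                   # subtract max, exponentiate (the "sum" loop)
        sum_x = np.sum(exp_x, axis=1, keepdims=True)  # per-row sum of exponentials
        return (exp_x / sum_x).reshape(x.shape)       # divide (the "softmax" loop)

    # Example mirroring test_softmax above: a 10x10 input with axis = 1.
    y = softmax_reference(np.random.rand(10, 10).astype(np.float32), axis=1)
    assert np.allclose(y.sum(axis=1), 1.0, atol=1e-5)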