From 027245152188bf6c842972474fe222e99d9afb95 Mon Sep 17 00:00:00 2001
From: Gheorghe-Teodor Bercea
Date: Fri, 7 Feb 2020 16:51:32 -0500
Subject: [PATCH] Lower convolution to KRNL dialect. (#65)

* Ensure data shape is at least 4.

* First version of convolution.

* Simplify code for KRNL lowering.

* Add test without padding or strides.

* Refactor code for lowering frontend operations to KRNL dialect.

* Add test for conv with no bias and no padding.

* Add test with group greater than one.

* Address comment.
---
 src/dialect/onnx/onnx_ops.cpp       |   4 +
 src/pass/lower_frontend_to_krnl.cpp | 524 +++++++++++++++++++---------
 test/backend/test.py                |   3 +
 test/mlir/onnx/onnx_lowering.mlir   | 100 +++++-
 4 files changed, 457 insertions(+), 174 deletions(-)

diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 8d9b52b..0946a65 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -628,6 +628,10 @@ void ONNXConvNoBiasOp::inferShapes() {
   auto dataShape = dataTy.getShape();
   auto weightShape = weightTy.getShape();
 
+  // Lowest ranked input supported is of shape (N x C x H x W).
+  if (dataShape.size() < 4)
+    emitError("Data input shape must be at least (NxCxHxW).");
+
   // Check that shape of weight and data have same length.
   if (dataShape.size() != weightShape.size())
     emitError("Weight size not compatible with data size.");
diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp
index 7f13a0f..040b1be 100644
--- a/src/pass/lower_frontend_to_krnl.cpp
+++ b/src/pass/lower_frontend_to_krnl.cpp
@@ -130,6 +130,78 @@ static bool checkInsertDealloc(Operation *currentOp) {
   return insertDealloc;
 }
 
+// Add bounds associated with the op operand to the KRNL iteration pack.
+// Dynamic dimensions are supported.
+static void addDimensionToPack(ConversionPatternRewriter &rewriter,
+    Location loc, KrnlIterateOperandPack &pack, Value operand, int index) {
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+  if (shape[index] < 0) {
+    pack.pushConstantBound(0);
+    pack.pushOperandBound(
+        rewriter.create<DimOp>(loc, operand, index).getResult());
+  } else {
+    pack.pushConstantBound(0);
+    pack.pushConstantBound(shape[index]);
+  }
+}
+
+// Function that defines the KRNL dialect loops and their respective
+// optimized version.
+static KrnlOptimizeLoopsOp emitOptimizedLoops(
+    ConversionPatternRewriter &rewriter, Location loc,
+    std::vector<Value> &loops, std::vector<Value> &optimizedLoops,
+    int64_t numLoops) {
+  // Define loops.
+  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
+  loops.reserve(numLoops);
+  for (auto result : loopsOp.getResults())
+    loops.push_back(result);
+
+  // Define optimized version of the loops.
+  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
+  optimizedLoops.reserve(numLoops);
+  for (auto result : optimizedLoopsOp.getResults())
+    optimizedLoops.push_back(result);
+
+  return optimizedLoopsOp;
+}
+
+// Function that emits the loops and their optimized version.
+// The function returns a reference to the inner optimization block.
+static Block* defineLoops(ConversionPatternRewriter &rewriter,
+    Location loc, std::vector<Value> &loops,
+    std::vector<Value> &optimizedLoops, int64_t numLoops) {
+  KrnlOptimizeLoopsOp optimizedLoopsOp = emitOptimizedLoops(
+      rewriter, loc, loops, optimizedLoops, numLoops);
+  return &optimizedLoopsOp.region().front();
+}
+
+// Function that emits a basic set of loops and optimized loops
+// for a given operation argument. A reference to the loop optimization
+// block is returned in the last argument of the function.
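+//
+// Illustrative call site (a minimal sketch mirroring the refactored
+// elementwise lowerings below; rewriter, loc and operands come from the
+// enclosing ConversionPattern):
+//
+//   std::vector<Value> originalLoops;
+//   KrnlOptimizeLoopsOp optimizedLoopsOp;
+//   KrnlIterateOp iterateOp;
+//   emitKrnlLoopsAndIterationForOperand(
+//       rewriter, loc, operands[0], originalLoops,
+//       optimizedLoopsOp, iterateOp);
+//
+// One loop per operand dimension is emitted; dynamic dimensions get
+// their bounds from DimOp results via addDimensionToPack above.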
+static void emitKrnlLoopsAndIterationForOperand(
+    ConversionPatternRewriter &rewriter, Location loc,
+    Value operand, std::vector<Value> &originalLoops,
+    KrnlOptimizeLoopsOp &optimizedLoopsOp, KrnlIterateOp &iterateOp) {
+  // Operand shape.
+  auto shape = operand.getType().cast<MemRefType>().getShape();
+
+  // Number of loops.
+  int64_t rank = shape.size();
+
+  // Define loops and optimized loops.
+  std::vector<Value> optimizedLoops;
+  optimizedLoopsOp = emitOptimizedLoops(rewriter, loc, originalLoops,
+      optimizedLoops, rank);
+
+  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+  // Iterate over the loop nest.
+  for (int i = 0; i < rank; ++i)
+    addDimensionToPack(rewriter, loc, pack, operand, i);
+
+  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+}
+
 unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
   auto elementType = memRefType.getElementType();
@@ -749,55 +821,21 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                     {operands[0]});
 
-    // Number of loops
-    auto memRefShape = memRefType.getShape();
-    int64_t rank = memRefShape.size();
-
-    // Define loops.
-    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
     std::vector<Value> originalLoops;
-    originalLoops.reserve(rank);
-    for (auto result : loopsOp.getResults()) {
-      originalLoops.push_back(result);
-    }
-
-    // Define loop optimization.
-    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
-    std::vector<Value> optimizedLoops;
-    optimizedLoops.reserve(rank);
-    for (auto result : optimizedLoopsOp.getResults()) {
-      optimizedLoops.push_back(result);
-    }
+    KrnlOptimizeLoopsOp optimizedLoopsOp;
+    KrnlIterateOp iterateOp;
+    emitKrnlLoopsAndIterationForOperand(
+        rewriter, loc, operands[0], originalLoops,
+        optimizedLoopsOp, iterateOp);
     Block &optimizationBlock = optimizedLoopsOp.region().front();
-
-    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-    // Iterate over the loop nest.
-    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
-    // to KrnlIterateOp instead.
-    for (int i = 0; i < rank; ++i) {
-      if (memRefShape[i] < 0) {
-        pack.pushConstantBound(0);
-        pack.pushOperandBound(
-            rewriter.create<DimOp>(loc, operands[0], i).getResult());
-      } else {
-        pack.pushConstantBound(0);
-        pack.pushConstantBound(memRefShape[i]);
-      }
-    }
-
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block &iterationBlock = iterateOp.bodyRegion().front();
 
-    // Now perform the insertions into the body of the
-    // just generated instructions:
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
     rewriter.setInsertionPointToEnd(&optimizationBlock);
     // Return from KrnlOptimizeLoopsOp body.
     // When no optimizations are present we just return the loops
     // unchanged.
     rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-    rewriter.setInsertionPoint(optimizedLoopsOp);
 
     // 2. Insert instructions inside the KernelIterateOp body.
     rewriter.setInsertionPointToStart(&iterationBlock);
@@ -851,59 +889,25 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                     operands);
 
-    // Number of loops
-    auto memRefShape = memRefType.getShape();
-    int64_t rank = memRefShape.size();
-
-    // Define loops.
-    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
-    std::vector<Value> originalLoops;
-    originalLoops.reserve(rank);
-    for (auto result : loopsOp.getResults()) {
-      originalLoops.push_back(result);
-    }
-
-    // Define loop optimization.
-    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
-    std::vector<Value> optimizedLoops;
-    optimizedLoops.reserve(rank);
-    for (auto result : optimizedLoopsOp.getResults()) {
-      optimizedLoops.push_back(result);
-    }
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
-
-    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
-    // Iterate over the loop nest.
-    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
-    // to KrnlIterateOp instead.
-    for (int i = 0; i < rank; ++i) {
-      if (memRefShape[i] < 0) {
-        pack.pushConstantBound(0);
-        pack.pushOperandBound(
-            rewriter.create<DimOp>(loc, alloc, i).getResult());
-      } else {
-        pack.pushConstantBound(0);
-        pack.pushConstantBound(memRefShape[i]);
-      }
-    }
-
     // Get run-time dimension information for unknown dimensions used for
     // broadcasting.
     std::map<int, std::map<int, Value>> broadcastedDimInfo =
         getBroadcastedDimInfo(loc, rewriter, memRefType, operands);
 
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    std::vector<Value> originalLoops;
+    KrnlOptimizeLoopsOp optimizedLoopsOp;
+    KrnlIterateOp iterateOp;
+    emitKrnlLoopsAndIterationForOperand(
+        rewriter, loc, alloc, originalLoops,
+        optimizedLoopsOp, iterateOp);
+    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();
 
-    // Now perform the insertions into the body of the
-    // just generated instructions:
-
-    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
     rewriter.setInsertionPointToEnd(&optimizationBlock);
     // Return from KrnlOptimizeLoopsOp body.
     // When no optimizations are present we just return the loops unchanged.
     rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-    rewriter.setInsertionPoint(optimizedLoopsOp);
 
     // 2. Insert instructions inside the KernelIterateOp body.
     rewriter.setInsertionPointToStart(&iterationBlock);
@@ -978,21 +982,10 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
         FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
 
     // Define loops.
-    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
     std::vector<Value> originalLoops;
-    originalLoops.reserve(rank);
-    for (auto result : loopsOp.getResults()) {
-      originalLoops.push_back(result);
-    }
-
-    // Define loop optimization.
-    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
     std::vector<Value> optimizedLoops;
-    optimizedLoops.reserve(rank);
-    for (auto result : optimizedLoopsOp.getResults()) {
-      optimizedLoops.push_back(result);
-    }
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
+    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
+        optimizedLoops, rank);
 
     // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
     // This coercing follows the softmax definition in ONNX:
@@ -1009,16 +1002,9 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
       optimizedOuterLoops.push_back(optimizedLoops[i]);
     }
     KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
-    for (int i = 0; i < axis; ++i) {
-      if (memRefShape[i] < 0) {
-        outerPack.pushConstantBound(0);
-        outerPack.pushOperandBound(
-            rewriter.create<DimOp>(loc, operands[0], i).getResult());
-      } else {
-        outerPack.pushConstantBound(0);
-        outerPack.pushConstantBound(memRefShape[i]);
-      }
-    }
+    for (int i = 0; i < axis; ++i)
+      addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
+
     // Define an inner loop with respect to axis.
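+    // (Worked example, assuming a hypothetical 10x20x30 input with
+    //  axis = 1: the outer pack above spans dim 0 and the inner pack
+    //  below spans dims 1 and 2, so each outer iteration computes a
+    //  softmax over 20 * 30 = 600 elements.)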
     std::vector<Value> innerLoops, optimizedInnerLoops;
     innerLoops.reserve(rank - axis);
@@ -1028,16 +1014,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
       optimizedInnerLoops.push_back(optimizedLoops[i]);
     }
     KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
-    for (int i = axis; i < rank; ++i) {
-      if (memRefShape[i] < 0) {
-        innerPack.pushConstantBound(0);
-        innerPack.pushOperandBound(
-            rewriter.create<DimOp>(loc, operands[0], i).getResult());
-      } else {
-        innerPack.pushConstantBound(0);
-        innerPack.pushConstantBound(memRefShape[i]);
-      }
-    }
+    for (int i = axis; i < rank; ++i)
+      addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
 
     KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
     SmallVector<Value, 4> outerLoopIVs;
     {
       outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
 
       // No optimization
-      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.setInsertionPointToEnd(optimizationBlock);
       rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-      rewriter.setInsertionPoint(optimizedLoopsOp);
 
       // Insert instructions inside the outer loop.
       Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
@@ -1078,9 +1055,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
       softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
 
       // No optimization
-      rewriter.setInsertionPointToEnd(&optimizationBlock);
+      rewriter.setInsertionPointToEnd(optimizationBlock);
       rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-      rewriter.setInsertionPoint(optimizedLoopsOp);
     }
 
     // Insert instructions inside the max loop.
@@ -1291,20 +1267,10 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     int64_t numLoops = 3;
 
     // Define loops.
-    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
     std::vector<Value> originalLoops;
-    originalLoops.reserve(numLoops);
-    for (auto result : loopsOp.getResults()) {
-      originalLoops.push_back(result);
-    }
-
-    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
     std::vector<Value> optimizedLoops;
-    optimizedLoops.reserve(numLoops);
-    for (auto result : optimizedLoopsOp.getResults()) {
-      optimizedLoops.push_back(result);
-    }
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
+    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
+        optimizedLoops, numLoops);
 
     // We have two Krnl loops:
     // - Outer loop iterates over the output matrix dimensions, and
@@ -1321,16 +1287,9 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
     // Induction variables for the outer loops
-    for (int i = 0; i < 2; ++i) {
-      if (memRefShape[i] < 0) {
-        outerPack.pushConstantBound(0);
-        outerPack.pushOperandBound(
-            rewriter.create<DimOp>(loc, alloc, i).getResult());
-      } else {
-        outerPack.pushConstantBound(0);
-        outerPack.pushConstantBound(memRefShape[i]);
-      }
-    }
+    for (int i = 0; i < 2; ++i)
+      addDimensionToPack(rewriter, loc, outerPack, alloc, i);
+
     // Reduction loop
     std::vector<Value> reductionLoops, optimizedReductionLoops;
     reductionLoops.reserve(1);
@@ -1378,9 +1337,8 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     // just generated instructions:
 
     // No optimization
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    rewriter.setInsertionPointToEnd(optimizationBlock);
     rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-    rewriter.setInsertionPoint(optimizedLoopsOp);
 
     // Insert instructions inside the outer loop.
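+    // (Sketch of the nest emitted here, for C(MxN) = A(MxK) * B(KxN):
+    //    for i = 0 .. M: for j = 0 .. N:   // outer loops over the result
+    //      for k = 0 .. K:                 // reduction loop
+    //        C[i][j] += A[i][k] * B[k][j];
+    //  alpha/beta and bias handling are omitted from this sketch.)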
     Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
@@ -1544,36 +1502,15 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     int64_t rank = memRefShape.size();
 
     // Define loops.
-    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
     std::vector<Value> originalLoops;
-    originalLoops.reserve(rank);
-
-    for (auto result : loopsOp.getResults()) {
-      originalLoops.push_back(result);
-    }
-
-    // Define loop optimization.
-    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
     std::vector<Value> optimizedLoops;
-    optimizedLoops.reserve(rank);
+    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
+        optimizedLoops, rank);
 
-    for (auto result : optimizedLoopsOp.getResults()) {
-      optimizedLoops.push_back(result);
-    }
-    Block &optimizationBlock = optimizedLoopsOp.region().front();
     KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
     // Iterate over the loop nest using the input shape.
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-    for (int i = 0; i < rank; ++i) {
-      if (inputShape[i] < 0) {
-        pack.pushConstantBound(0);
-        pack.pushOperandBound(
-            rewriter.create<DimOp>(loc, operands[0], i).getResult());
-      } else {
-        pack.pushConstantBound(0);
-        pack.pushConstantBound(inputShape[i]);
-      }
-    }
+    for (int i = 0; i < rank; ++i)
+      addDimensionToPack(rewriter, loc, pack, operands[0], i);
 
     auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
     Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -1582,12 +1519,11 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
     // just generated instructions:
 
     // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
-    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    rewriter.setInsertionPointToEnd(optimizationBlock);
     // Return from KrnlOptimizeLoopsOp body.
     // When no optimizations are present we just return the loops
     // unchanged.
     rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
-    rewriter.setInsertionPoint(optimizedLoopsOp);
 
     // 2. Insert instructions inside the KernelIterateOp body.
     rewriter.setInsertionPointToStart(&iterationBlock);
@@ -1638,6 +1574,255 @@ struct ONNXIdentityOpLowering : public ConversionPattern {
   }
 };
 
+struct ONNXConvNoBiasOpLowering : public ConversionPattern {
+  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
+      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    Value alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
+                                    {operands[0]});
+
+    auto resultShape = memRefType.getShape();
+    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
+    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
+
+    // R = ConvNoBias(D, K)
+    //
+    // The input/output shapes will look like this:
+    //
+    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
+    //
+    // M is a multiple of the number of groups:
+    //   M = group * kernelsPerGroup
+    //
+    // The loop nest will look as follows:
+    //
+    // kernelsPerGroup = M / group;
+    // for n = 0 .. N:
+    //   for g = 0 .. group:
+    //     for m = 0 .. kernelsPerGroup:
+    //       kernel = g * kernelsPerGroup + m;
+    //       for r1 = 0 .. RH:
+    //         for r2 = 0 .. RW:
+    //           R[n][kernel][r1][r2] = 0;
+    //           for c = 0 .. C/group:
+    //             for k1 = 0 .. KH:
+    //               for k2 = 0 .. KW:
+    //                 R[n][kernel][r1][r2] =
+    //                   D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
+    //                   K[kernel][c][k1][k2];
+    //
+    // TODO: handle padding.
+    //
+    // In the general case:
+    //
+    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
+    //   -> R (NxMxR1xR2x...xRdim)
+    //
+    // The above loop nest can be adapted by increasing the number
+    // of r- and k-index loops, i.e. the r1, r2 and k1, k2 loops.
+
+    // Set up outermost loops: n g m r1 r2 ... rdim
+    // Skip g if group is 1.
+
+    // Before we start the iteration we need to compute the number of
+    // unsplit kernels and fetch the number of groups from the attribute
+    // list. Group is always a compilation constant.
+    int64_t group = convOp.group().getSExtValue();
+    // Compute the number of unsplit kernels. The number of kernels
+    // must be a multiple of the number of groups.
+    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
+    auto kernelsPerGroupValue =
+        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
+    auto zero = rewriter.create<ConstantOp>(
+        loc, FloatAttr::get(memRefType.getElementType(), 0));
+    Value subchannels;
+    if (kernelShape[1] < 0) {
+      subchannels =
+          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
+    } else {
+      subchannels = rewriter.create<ConstantIndexOp>(
+          loc, kernelShape[1]);
+    }
+
+    // 1. Define outer loops and emit empty optimization block:
+    int64_t nOuterLoops = (group > 1) ? 3 : 2;
+    std::vector<Value> outerLoops;
+    std::vector<Value> optimizedOuterLoops;
+    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
+        optimizedOuterLoops, nOuterLoops);
+
+    // Prepare iteration arguments over outer loop nest.
+    KrnlIterateOperandPack pack(
+        rewriter, outerLoops, optimizedOuterLoops);
+    // for n = 0 .. N:
+    pack.pushConstantBound(0);
+    if (inputShape[0] < 0)
+      pack.pushOperandBound(
+          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
+    else
+      pack.pushConstantBound(inputShape[0]);
+    // for g = 0 .. group:
+    if (group > 1) {
+      pack.pushConstantBound(0);
+      pack.pushConstantBound(group);
+    }
+    // for m = 0 .. kernelsPerGroup:
+    pack.pushConstantBound(0);
+    pack.pushConstantBound(kernelsPerGroup);
+    // Outer loop iteration.
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    Block &outerIterationBlock = iterateOp.bodyRegion().front();
+    // Emit optimizations for outer loops:
+    rewriter.setInsertionPointToEnd(optimizationBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
+    rewriter.setInsertionPointToStart(&outerIterationBlock);
+    {
+      // 2. Emit the body of the outer loop nest.
+
+      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
+      // If group is not set then the value of the kernel ID is
+      // identical to that of the loop over kernels.
+      Value kernel = outerIterationBlock.getArguments()[1];
+      if (group > 1) {
+        // Middle loop is over groups and third loop is over the
+        // kernel identifiers in the current group.
+        auto kernelsOffset = rewriter.create<MulIOp>(loc,
+            outerIterationBlock.getArguments()[1],
+            kernelsPerGroupValue);
+        kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
+            outerIterationBlock.getArguments()[2]);
+      }
+
+      // 2.2 Define spatial loops
+      int64_t nSpatialLoops = resultShape.size() - 2;
+      std::vector<Value> spatialLoops;
+      std::vector<Value> optimizedSpatialLoops;
+      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
+          optimizedSpatialLoops, nSpatialLoops);
+
+      // 2.3 Prepare iteration arguments for spatial loop nest.
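+      // (With no padding and unit strides, each spatial bound is
+      //  RX = DX - KX + 1; e.g. the 32x64 input and 6x7 kernel used in
+      //  the lowering tests below yield a 27x58 output.)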
+      KrnlIterateOperandPack spatialPack(
+          rewriter, spatialLoops, optimizedSpatialLoops);
+      for (int i = 2; i < resultShape.size(); ++i)
+        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
+
+      // 2.4 Emit loop nest over output spatial dimensions.
+      //   for rX = 0 .. RX
+      auto spatialIterateOp =
+          rewriter.create<KrnlIterateOp>(loc, spatialPack);
+      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
+      // 2.5 Emit optimizations for spatial loops:
+      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
+      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
+      rewriter.setInsertionPointToStart(&spatialIterationBlock);
+      {
+        // 3. Emit the body of the spatial loop nest.
+        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
+        SmallVector<Value, 4> resultIndices;
+        // n
+        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+        // kernel
+        resultIndices.emplace_back(kernel);
+        // rX
+        for (auto arg : spatialIterationBlock.getArguments())
+          resultIndices.emplace_back(arg);
+        // Store initializer value into output location.
+        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
+
+        // 3.2 Define inner loops.
+        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
+        std::vector<Value> innerLoops;
+        std::vector<Value> optimizedInnerLoops;
+        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
+            optimizedInnerLoops, nInnerLoops);
+
+        // 3.3 Prepare iteration arguments for inner loop nest.
+        KrnlIterateOperandPack innerPack(
+            rewriter, innerLoops, optimizedInnerLoops);
+        // for c = 0 .. C/group
+        innerPack.pushConstantBound(0);
+        innerPack.pushConstantBound(kernelShape[1]);
+        // for Kx = 0 .. KX
+        for (int i = 2; i < kernelShape.size(); ++i)
+          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
+
+        // 3.4 Emit inner loop nest.
+        auto innerIterateOp =
+            rewriter.create<KrnlIterateOp>(loc, innerPack);
+        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
+        // 3.5 Emit optimizations for inner loops:
+        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
+        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
+        rewriter.setInsertionPointToStart(&innerIterationBlock);
+        {
+          // 4. Emit inner loop body
+          //    R[n][kernel][r1][r2] =
+          //      D[n][g * (C / group) + c][r1 + k1][r2 + k2] *
+          //      K[kernel][c][k1][k2];
+
+          // 4.1 Prepare indices for accessing the data tensor.
+          SmallVector<Value, 4> dataIndices;
+          // n
+          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+          // g * (C / group) + c
+          Value channelDepth = innerIterationBlock.getArguments()[0];
+          if (group > 1)
+            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
+                rewriter.create<MulIOp>(loc, subchannels,
+                    outerIterationBlock.getArguments()[1]));
+          dataIndices.emplace_back(channelDepth);
+          // rX + kX
+          for (int i = 0; i < kernelShape.size() - 2; ++i)
+            dataIndices.emplace_back(
+                rewriter.create<AddIOp>(loc,
+                    spatialIterationBlock.getArguments()[i],
+                    innerIterationBlock.getArguments()[i+1]));
+
+          // 4.2 Prepare indices for accessing the kernel tensor.
+          SmallVector<Value, 4> kernelIndices;
+          // kernel
+          kernelIndices.emplace_back(kernel);
+          // c
+          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
+          // kX
+          for (int i = 0; i < kernelShape.size() - 2; ++i)
+            kernelIndices.emplace_back(
+                innerIterationBlock.getArguments()[i+1]);
+
+          // 4.3 Compute convolution.
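+          // (Index sketch: in the grouped lowering test below, C = 9 and
+          //  group = 3, so subchannels = 3 and inner channel c of group g
+          //  reads input channel g * 3 + c.)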
+          auto loadData =
+              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
+          auto loadKernel =
+              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
+          auto loadPartialSum =
+              rewriter.create<LoadOp>(loc, alloc, resultIndices);
+          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
+              rewriter.create<MulFOp>(loc, loadData, loadKernel));
+          // 4.4 Store computed value into output location.
+          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
+        }
+      }
+    }
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // EntryPoint Op lowering to Krnl Entry Point.
 //===----------------------------------------------------------------------===//
@@ -1769,7 +1954,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
                   ONNXReshapeOpLowering, ONNXEntryPointLowering,
                   ONNXSoftmaxOpLowering, ONNXGemmOpLowering,
                   ONNXUnsqueezeOpLowering, ONNXTransposeOpLowering,
-                  ONNXIdentityOpLowering>(&getContext());
+                  ONNXIdentityOpLowering, ONNXConvNoBiasOpLowering
+                  >(&getContext());
 
   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
diff --git a/test/backend/test.py b/test/backend/test.py
index 2387e15..d7ae639 100644
--- a/test/backend/test.py
+++ b/test/backend/test.py
@@ -202,6 +202,9 @@ test_to_enable = [
     "test_transpose_all_permutations_4_cpu",
     "test_transpose_all_permutations_5_cpu",
 
+    # Conv
+    "test_basic_conv_without_padding_cpu",
+
     # Sign Op:
     "test_sign_cpu",
 ]
diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir
index c58dcd1..b5a7dd4 100644
--- a/test/mlir/onnx/onnx_lowering.mlir
+++ b/test/mlir/onnx/onnx_lowering.mlir
@@ -568,15 +568,15 @@ func @test_add_with_broadcasting(%arg0 : tensor<?xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK-LABEL: test_add_with_broadcasting
   // CHECK: [[DIM1:%.+]] = dim %arg1, 0 : memref<?x10xf32>
   // CHECK: [[RES:%.+]] = alloc([[DIM1]]) : memref<?x10xf32>
+  // CHECK: [[DIM2:%.+]] = dim %arg0, 0 : memref<?xf32>
+  // CHECK: [[ONE:%.+]] = constant 1 : index
+  // CHECK: [[IS_ONE:%.+]] = cmpi "eq", [[DIM2]], [[ONE]] : index
   // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
   // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
   // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
   // CHECK: } : () -> (!krnl.loop, !krnl.loop)
-  // CHECK: [[DIM2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
-  // CHECK: [[DIM3:%.+]] = dim %arg0, 0 : memref<?xf32>
-  // CHECK: [[ONE:%.+]] = constant 1 : index
-  // CHECK: [[IS_ONE:%.+]] = cmpi "eq", [[DIM3]], [[ONE]] : index
-  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[DIM3:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM3]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
   // CHECK: [[ZERO:%.+]] = constant 0 : index
   // CHECK: %[[SELECT1:.+]] = select [[IS_ONE]], [[ZERO]], %arg3 : index
   // CHECK: [[LOAD1:%.+]] = load %arg0[%[[SELECT1]]] : memref<?xf32>
@@ -788,3 +788,93 @@ func @test_sign_i(%arg0 : tensor<?x?xi32>) -> tensor<*xi32> {
   // CHECK: store [[SIGN_RES]], [[RES]][%arg1, %arg2] : memref<?x?xi32>
   // CHECK: return [[RES]] : memref<?x?xi32>
 }
+
+func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_no_bias_no_pad
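+  // (Result shape: 27 = 32 - 6 + 1 and 58 = 64 - 7 + 1, per
+  //  RX = DX - KX + 1 with no padding and unit strides.)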
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
+  // CHECK: [[CONST0:%.+]] = constant 5 : index
+  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST2:%.+]] = constant 2 : index
+  // CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 5) {
+  // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg4 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg5 = 0 to 58) {
+  // CHECK: store [[CONST1]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32>
+  // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg6 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg7 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg8 = 0 to 7) {
+  // CHECK: [[R1PLUSK1:%.+]] = addi %arg4, %arg7 : index
+  // CHECK: [[R2PLUSK2:%.+]] = addi %arg5, %arg8 : index
+  // CHECK: [[DATA:%.+]] = load %arg0[%arg2, %arg6, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x2x32x64xf32>
+  // CHECK: [[KERNEL:%.+]] = load %arg1[%arg3, %arg6, %arg7, %arg8] : memref<5x2x6x7xf32>
+  // CHECK: [[ACC_RES:%.+]] = load %0[%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32>
+  // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg2, %arg3, %arg4, %arg5] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: }
+
+  // CHECK: return [[RES]] : memref<1x5x27x58xf32>
+}
+
+func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> {
+  %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_conv_no_bias_no_pad_w_group
+  // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32>
+  // CHECK: [[CONST0:%.+]] = constant 1 : index
+  // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32
+  // CHECK: [[CONST2:%.+]] = constant 3 : index
+  // CHECK: [[OUTER_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_OUTER_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1, [[OUTER_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1, [[OPT_OUTER_LOOPS]]#2) with ([[OUTER_LOOPS]]#0 -> %arg2 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg3 = 0 to 3, [[OUTER_LOOPS]]#2 -> %arg4 = 0 to 1) {
+  // CHECK: [[MUL1:%.+]] = muli %arg3, [[CONST0]] : index
+  // CHECK: %[[ADD1:.+]] = addi [[MUL1]], %arg4 : index
+  // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) {
+  // CHECK: store [[CONST1]], [[RES]][%arg2, %[[ADD1]], %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3
+  // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
+
+  // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 3, [[INNER_LOOPS]]#1 -> %arg8 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg9 = 0 to 7) {
+  // CHECK: [[MUL2:%.+]] = muli [[CONST2]], %arg3 : index
+  // CHECK: [[ADD2:%.+]] = addi %arg7, [[MUL2]] : index
+  // CHECK: [[R1PLUSK1:%.+]] = addi %arg5, %arg8 : index
+  // CHECK: [[R2PLUSK2:%.+]] = addi %arg6, %arg9 : index
+  // CHECK: [[DATA:%.+]] = load %arg0[%arg2, [[ADD2]], [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x9x32x64xf32>
+  // CHECK: [[KERNEL:%.+]] = load %arg1[%[[ADD1]], %arg7, %arg8, %arg9] : memref<5x3x6x7xf32>
+  // CHECK: [[ACC_RES:%.+]] = load %0[%arg2, %[[ADD1]], %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg2, %[[ADD1]], %arg5, %arg6] : memref<1x5x27x58xf32>
+  // CHECK: }
+  // CHECK: }
+  // CHECK: }
+
+  // CHECK: return [[RES]] : memref<1x5x27x58xf32>
+}
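+
+// (In test_conv_no_bias_no_pad_w_group above, kernelsPerGroup =
+//  floor(5 / 3) = 1, so [[CONST0]] is 1 and the kernel index reduces to
+//  g * 1 + m, matching [[MUL1]] and [[ADD1]].)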