From b422116f1203e064ed7125083b2c820e6006d576 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Tue, 31 Mar 2020 11:55:27 -0400 Subject: [PATCH] clean operands using names provided by operandAdaptors (#56) * clean operands using names provided by operandAdaptors * reverted changes that were erroneous, from Tung's comment * clang format issues --- src/Conversion/ONNXToKrnl/Math/Gemm.cpp | 7 +- src/Conversion/ONNXToKrnl/Math/MatMul.cpp | 5 +- src/Conversion/ONNXToKrnl/Math/Softmax.cpp | 27 ++- src/Conversion/ONNXToKrnl/NN/Conv.cpp | 12 +- .../ONNXToKrnl/NN/Normalization.cpp | 34 ++-- src/Conversion/ONNXToKrnl/NN/Pooling.cpp | 3 +- src/Conversion/ONNXToKrnl/Tensor/Identity.cpp | 8 +- .../ONNXToKrnl/Tensor/PadConstantValuePad.cpp | 20 +- src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp | 18 +- .../ONNXToKrnl/Tensor/Transpose.cpp | 25 +-- .../ONNXToKrnl/Tensor/Unsqueeze.cpp | 11 +- src/Transform/LowerToLLVM.cpp | 180 +++++++++--------- src/Transform/ONNX/ONNXRewrite.cpp | 4 +- 13 files changed, 175 insertions(+), 179 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp index 6442334..c95a468 100644 --- a/src/Conversion/ONNXToKrnl/Math/Gemm.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Gemm.cpp @@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern { bool hasBias = !op->getOperand(2).getType().isa(); Value A, B, C; - A = operands[0]; - B = operands[1]; + ONNXGemmOpOperandAdaptor operandAdaptor(operands); + A = operandAdaptor.A(); + B = operandAdaptor.B(); if (hasBias) - C = operands[2]; + C = operandAdaptor.C(); auto memRefType = convertToMemRefType(*op->result_type_begin()); diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp index e06deb4..c9a86af 100644 --- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp +++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp @@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern { ConversionPatternRewriter 
&rewriter) const final { auto loc = op->getLoc(); - Value A = operands[0]; - Value B = operands[1]; + ONNXMatMulOpOperandAdaptor operandAdaptor(operands); + Value A = operandAdaptor.A(); + Value B = operandAdaptor.B(); auto AShape = A.getType().cast().getShape(); auto BShape = B.getType().cast().getShape(); diff --git a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp index 1acf990..8fc6d0a 100644 --- a/src/Conversion/ONNXToKrnl/Math/Softmax.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Softmax.cpp @@ -15,9 +15,8 @@ using namespace mlir; struct ONNXSoftmaxOpLowering : public ConversionPattern { ONNXSoftmaxOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { // softmax(x) = let max_x = max(x) in // let exp_x = exp(x - max_x) in // let sum = sum(exp_x) in @@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { assert(axis >= -rank && axis <= rank - 1); auto loc = op->getLoc(); - + ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands); + Value input = operandAdaptor.input(); // Insert an allocation and deallocation for the result of this operation. 
auto elementType = memRefType.getElementType(); @@ -38,8 +38,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - operands[0]); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, input); // Shape of the result auto memRefShape = memRefType.getShape(); @@ -49,15 +49,14 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value zero = emitConstantOp(rewriter, loc, elementType, 0); - Value negInfinity = rewriter.create( - loc, + Value negInfinity = rewriter.create(loc, FloatAttr::get(elementType, -std::numeric_limits::infinity())); // Define loops. std::vector originalLoops; std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, rank); + Block *optimizationBlock = + defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank); // Coerce the input into a 2-D tensor. `axis` will be the coercing point. // This coercing follows the softmax definition in ONNX: @@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { } KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); for (int i = 0; i < axis; ++i) - addDimensionToPack(rewriter, loc, outerPack, operands[0], i); + addDimensionToPack(rewriter, loc, outerPack, input, i); // Define an inner loop with respect to axis. 
std::vector innerLoops, optimizedInnerLoops; @@ -87,7 +86,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { } KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); for (int i = axis; i < rank; ++i) - addDimensionToPack(rewriter, loc, innerPack, operands[0], i); + addDimensionToPack(rewriter, loc, innerPack, input, i); KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; SmallVector outerLoopIVs; @@ -144,7 +143,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // Compute the max value. Value max = rewriter.create(loc, maxOp); - Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); + Value nextMax = rewriter.create(loc, input, maxLoopIVs); auto maxCond = rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); max = rewriter.create(loc, maxCond, max, nextMax); @@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // Sum up values. Value sum = rewriter.create(loc, sumOp); - Value next = rewriter.create(loc, operands[0], sumLoopIVs); + Value next = rewriter.create(loc, input, sumLoopIVs); Value sub = rewriter.create(loc, next, max); Value exp = rewriter.create(loc, sub); sum = rewriter.create(loc, sum, exp); diff --git a/src/Conversion/ONNXToKrnl/NN/Conv.cpp b/src/Conversion/ONNXToKrnl/NN/Conv.cpp index db29107..17f6c68 100644 --- a/src/Conversion/ONNXToKrnl/NN/Conv.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Conv.cpp @@ -26,12 +26,6 @@ struct ONNXConvOpLowering : public ConversionPattern { bool insertDealloc = checkInsertDealloc(op); ONNXConvOp convOp = llvm::dyn_cast(op); - if (hasAllConstantDimensions(memRefType)) - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); - else - alloc = insertAllocAndDealloc( - memRefType, loc, rewriter, insertDealloc, {operands[0]}); - auto resultShape = memRefType.getShape(); auto inputOperand = operandAdaptor.X(); auto inputShape = inputOperand.getType().cast().getShape(); @@ -40,6 +34,12 @@ struct 
ONNXConvOpLowering : public ConversionPattern { auto biasOperand = operandAdaptor.B(); bool hasBias = !biasOperand.getType().isa(); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {inputOperand}); + // R = Conv(D, K) // // The input/output shapes will look like this: diff --git a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp index 16f6463..f5297de 100644 --- a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp @@ -1,4 +1,4 @@ -//===----------- Normalization.cpp - Lowering Normalization Ops ------------===// +//===----------- Normalization.cpp - Lowering Normalization Ops -----------===// // // Copyright 2019 The IBM Research Authors. // @@ -18,24 +18,24 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter & rewriter) const final { + ConversionPatternRewriter &rewriter) const final { // batchnorm{epsilon}(x, scale, bias, mean, variance) = // scale * (x - mean) / sqrt(variance + epsilon) + bias + ONNXBatchNormalizationTestModeOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); auto memRefType = convertToMemRefType(*op->result_type_begin()); - auto epsilonAttr = - FloatAttr::get(memRefType.getElementType(), - llvm::dyn_cast(op) - .epsilon() - .convertToFloat()); + auto epsilonAttr = FloatAttr::get(memRefType.getElementType(), + llvm::dyn_cast(op) + .epsilon() + .convertToFloat()); auto epsilon = rewriter.create(loc, epsilonAttr); - auto operand = operands[0]; - auto scale = operands[1]; - auto bias = operands[2]; - auto mean = operands[3]; - auto variance = operands[4]; + auto operand = operandAdaptor.X(); + 
auto scale = operandAdaptor.scale(); + auto bias = operandAdaptor.B(); + auto mean = operandAdaptor.mean(); + auto variance = operandAdaptor.var(); // Insert an allocation and deallocation for the result of this operation. Value alloc; @@ -44,8 +44,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operand}); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {operand}); // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. // In case of N, C is assumed to be 1. @@ -67,8 +67,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { SmallVector loopCIVs; if (rank > 1) { - KrnlIterateOperandPack cPack(rewriter, originalLoops[1], - optimizedLoops[1]); + KrnlIterateOperandPack cPack( + rewriter, originalLoops[1], optimizedLoops[1]); addDimensionToPack(rewriter, loc, cPack, operand, 1); auto cIterateOp = rewriter.create(loc, cPack); Block &cIterationBlock = cIterateOp.bodyRegion().front(); @@ -76,7 +76,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern { for (auto arg : cIterationBlock.getArguments()) loopCIVs.emplace_back(arg); } else { - loopCIVs.emplace_back(rewriter.create(loc, 0)); + loopCIVs.emplace_back(rewriter.create(loc, 0)); } auto scaleVal = rewriter.create(loc, scale, loopCIVs); diff --git a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp index 4c0be1e..82c1397 100644 --- a/src/Conversion/ONNXToKrnl/NN/Pooling.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Pooling.cpp @@ -38,6 +38,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { + ONNXMaxPoolSingleOutOpOperandAdaptor 
operandAdaptor(operands); auto loc = op->getLoc(); // Match @@ -71,7 +72,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern { dilations.emplace_back(dilation.cast().getInt()); // Type information about the input and result of this operation. - auto &inputOperand = operands[0]; + auto inputOperand = operandAdaptor.X(); auto inputShape = inputOperand.getType().cast().getShape(); auto memRefType = convertToMemRefType(*op->result_type_begin()); auto resultShape = memRefType.getShape(); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Identity.cpp b/src/Conversion/ONNXToKrnl/Tensor/Identity.cpp index 9c3adc0..c4612b0 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Identity.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Identity.cpp @@ -16,10 +16,10 @@ struct ONNXIdentityOpLowering : public ConversionPattern { ONNXIdentityOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOp(op, operands[0]); + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ONNXIdentityOpOperandAdaptor operandAdaptor(operands); + rewriter.replaceOp(op, operandAdaptor.input()); return matchSuccess(); } }; diff --git a/src/Conversion/ONNXToKrnl/Tensor/PadConstantValuePad.cpp b/src/Conversion/ONNXToKrnl/Tensor/PadConstantValuePad.cpp index 0841a71..417a6f9 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/PadConstantValuePad.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/PadConstantValuePad.cpp @@ -14,13 +14,13 @@ using namespace mlir; struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { ONNXPadConstantValuePadOpLowering(MLIRContext *ctx) - : ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(), - 1, ctx) {} + : ConversionPattern( + mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {} - 
PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { auto tensorType = (*op->result_type_begin()); + ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); // Only constant padding is supported now. @@ -55,7 +55,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { BuildKrnlLoop valueLoops(rewriter, loc, rank); valueLoops.createDefineAndOptimizeOp(); for (int i = 0; i < rank; ++i) - valueLoops.pushBounds(0, operands[0], i); + valueLoops.pushBounds(0, operandAdaptor.data(), i); valueLoops.createIterateOp(); // Copy the input data into the output. @@ -67,7 +67,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { auto pads = llvm::dyn_cast(op).pads(); SmallVector pad_begin; - for (int i = 0; i < pads.size()/2; ++i) { + for (int i = 0; i < pads.size() / 2; ++i) { pad_begin.emplace_back(pads.getValue()[i].cast().getInt()); } @@ -77,14 +77,14 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { if (pad_begin[i] == 0) { outLoopIVs.emplace_back(valueLoops.getInductionVar(i)); } else { - auto outIV = rewriter.create( - loc, rewriter.create(loc, pad_begin[i]), + auto outIV = rewriter.create(loc, + rewriter.create(loc, pad_begin[i]), valueLoops.getInductionVar(i)); outLoopIVs.emplace_back(outIV); } } - auto inVal = rewriter.create(loc, operands[0], inLoopIVs); + auto inVal = rewriter.create(loc, operandAdaptor.data(), inLoopIVs); rewriter.create(loc, inVal, alloc, outLoopIVs); rewriter.setInsertionPointToStart(padLoops.getIterateBlock()); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp index 51ac1e5..7eb73ac 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Reshape.cpp @@ -16,11 +16,12 
@@ struct ONNXReshapeOpLowering : public ConversionPattern { ONNXReshapeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - auto inputShape = operands[0].getType().cast().getShape(); + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ONNXReshapeOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); + Value data = operandAdaptor.data(); + auto inputShape = data.getType().cast().getShape(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertToMemRefType(*op->result_type_begin()); @@ -33,7 +34,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern { for (int i = 0; i < inputShape.size(); ++i) { Value dimVal; if (inputShape[i] < 0) { - Value dim = rewriter.create(loc, operands[0], i); + Value dim = rewriter.create(loc, data, i); dimVal = rewriter.create(loc, dim, rewriter.getIntegerType(64)); } else { @@ -61,8 +62,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern { rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); SmallVector DimInfo; for (int i = 0; i < memRefShape.size(); ++i) { - Value index = - emitConstantOp(rewriter, loc, rewriter.getIndexType(), i); + Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i); // Load index from array of indices. 
Value loadedVal = rewriter.create(loc, operands[1], index); // If a dimension is zero, the actual dimension value is taken from the @@ -75,7 +75,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern { Value dimVal; auto loadedValType = loadedVal.getType().cast(); if (inputShape[i] < 0) { - Value dim = rewriter.create(loc, operands[0], i); + Value dim = rewriter.create(loc, data, i); dimVal = rewriter.create(loc, dim, loadedValType); } else { dimVal = @@ -136,7 +136,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern { alloc = allocateMemref; } - rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.create(loc, alloc, data, tensorSize); rewriter.replaceOp(op, alloc); return matchSuccess(); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp index a2ef7f3..11cf9d8 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Transpose.cpp @@ -16,20 +16,21 @@ struct ONNXTransposeOpLowering : public ConversionPattern { ONNXTransposeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ONNXTransposeOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. 
auto memRefType = convertToMemRefType(*op->result_type_begin()); Value alloc; bool insertDealloc = checkInsertDealloc(op); + Value data = operandAdaptor.data(); if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); else - alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, - {operands[0]}); + alloc = insertAllocAndDealloc( + memRefType, loc, rewriter, insertDealloc, {data}); // Number of loops auto memRefShape = memRefType.getShape(); @@ -38,13 +39,13 @@ struct ONNXTransposeOpLowering : public ConversionPattern { // Define loops. std::vector originalLoops; std::vector optimizedLoops; - Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, - optimizedLoops, rank); + Block *optimizationBlock = + defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); // Iterate over the loop nest using the input shape. for (int i = 0; i < rank; ++i) - addDimensionToPack(rewriter, loc, pack, operands[0], i); + addDimensionToPack(rewriter, loc, pack, data, i); auto iterateOp = rewriter.create(loc, pack); Block &iterationBlock = iterateOp.bodyRegion().front(); @@ -74,8 +75,8 @@ struct ONNXTransposeOpLowering : public ConversionPattern { // TODO: Remove when perm is guaranteed to be present (even for // the default case). This means that perm was added by shape // inference or another pass to contain the values corresponding - // to the default behavior of Transpose. - for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--) + // to the default behavior of Transpose. 
+ for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--) perm.emplace_back(i); } @@ -84,10 +85,10 @@ struct ONNXTransposeOpLowering : public ConversionPattern { inLoopIVs.emplace_back(arg); SmallVector outLoopIVs; - for (int i=0; i(loc, operands[0], inLoopIVs); + auto inVal = rewriter.create(loc, data, inLoopIVs); rewriter.create(loc, inVal, alloc, outLoopIVs); rewriter.replaceOp(op, alloc); diff --git a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp index 7e8ef69..502b831 100644 --- a/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp +++ b/src/Conversion/ONNXToKrnl/Tensor/Unsqueeze.cpp @@ -16,12 +16,13 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { ONNXUnsqueezeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + ONNXUnsqueezeOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); auto memRefType = convertToMemRefType(*op->result_type_begin()); int outRank = memRefType.getRank(); + Value data = operandAdaptor.data(); // Assume that `axes` has been validated by shape inference. // So, here we just get it. 
@@ -55,7 +56,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { Value dimVal = nullptr; if (memRefShape[outIdx] < 0) { - Value index = rewriter.create(loc, operands[0], inIdx); + Value index = rewriter.create(loc, data, inIdx); dimVal = rewriter.create( loc, index, rewriter.getIntegerType(64)); allocOperands.emplace_back(index); @@ -74,7 +75,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern { dealloc.getOperation()->moveBefore(&parentBlock->back()); } } - rewriter.create(loc, alloc, operands[0], tensorSize); + rewriter.create(loc, alloc, data, tensorSize); rewriter.replaceOp(op, alloc); return matchSuccess(); } diff --git a/src/Transform/LowerToLLVM.cpp b/src/Transform/LowerToLLVM.cpp index a4bde78..9abe539 100644 --- a/src/Transform/LowerToLLVM.cpp +++ b/src/Transform/LowerToLLVM.cpp @@ -28,9 +28,7 @@ using namespace mlir; namespace { static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, - ModuleOp module, - mlir::LLVM::LLVMType funcType, - PatternRewriter &rewriter) { + ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) { auto *context = module.getContext(); if (module.lookupSymbol(funcName)) { auto symbolRef = SymbolRefAttr::get(funcName, context); @@ -71,10 +69,10 @@ public: explicit KrnlMemcpyOpLowering(MLIRContext *context) : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto *context = op->getContext(); + KrnlMemcpyOpOperandAdaptor operandAdaptor(operands); auto loc = op->getLoc(); auto *llvmDialect = op->getContext()->getRegisteredDialect(); @@ -85,35 +83,39 @@ public: auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect); // First 
operand. - Type dstType = - operands[0].getType().cast().getStructElementType(1); + Type dstType = operandAdaptor.dest() + .getType() + .cast() + .getStructElementType(1); Value alignedDstMemory = rewriter.create( - loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); + loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrDstMemory = rewriter.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); // Second operand. - Type srcType = - operands[1].getType().cast().getStructElementType(1); + Type srcType = operandAdaptor.src() + .getType() + .cast() + .getStructElementType(1); Value alignedSrcMemory = rewriter.create( - loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); + loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1)); Value alignedInt8PtrSrcMemory = rewriter.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); // Size. Value int64Size = rewriter.create( - loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); + loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size()); // Is volatile (set to false). - Value isVolatile = rewriter.create( - loc, LLVM::LLVMType::getInt1Ty(llvmDialect), + Value isVolatile = rewriter.create(loc, + LLVM::LLVMType::getInt1Ty(llvmDialect), rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); // Memcpy call - rewriter.create( - loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), + rewriter.create(loc, memcpyRef, + LLVM::LLVMType::getVoidTy(llvmDialect), ArrayRef({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, - int64Size, isVolatile})); + int64Size, isVolatile})); rewriter.eraseOp(op); return matchSuccess(); @@ -123,8 +125,7 @@ private: /// Return a symbol reference to the memcpy function, inserting it into the /// module if necessary. 
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, - ModuleOp module, - LLVM::LLVMDialect *llvmDialect) { + ModuleOp module, LLVM::LLVMDialect *llvmDialect) { auto *context = module.getContext(); if (module.lookupSymbol("llvm.memcpy.p0i8.p0i8.i64")) return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); @@ -134,8 +135,7 @@ private: auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); - auto llvmFnType = LLVM::LLVMType::getFunctionTy( - llvmVoidTy, + auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy, ArrayRef( {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), false); @@ -143,8 +143,8 @@ private: // Insert the memcpy function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), - "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); + rewriter.create( + module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); } }; @@ -176,18 +176,18 @@ public: SmallVector inputTys; ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy, - ArrayRef inputTys) + ArrayRef inputTys) : id(id), name(name), outputTy(outputTy), inputTys(inputTys.begin(), inputTys.end()) {} LLVM::LLVMType funcTy() { return LLVM::LLVMType::getFunctionTy(outputTy, inputTys, - /*isVarArg=*/false); + /*isVarArg=*/false); } }; - PatternMatchResult matchAndRewrite(KrnlEntryPointOp op, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite( + KrnlEntryPointOp op, PatternRewriter &rewriter) const override { auto *llvmDialect = op.getContext()->getRegisteredDialect(); @@ -248,7 +248,7 @@ public: auto idxVal = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(i)); auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF, - 
{wrappedInput, idxVal}); + {wrappedInput, idxVal}); // Create a (static) memref type corresponding to the i-th memref input to // the inference function on stack, and load it to memRef. @@ -257,12 +257,12 @@ public: auto one = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(1)); Value ptrToMemRef = rewriter.create(loc, memRefPtrTy, one, - /*alignment=*/0); + /*alignment=*/0); // Fill in the memref underlying ptrToMemRef with information extracted // from dynMemRef. - fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc, - apiRegistry, llvmDialect); + fillPtrToMemRefWithDynMemRef( + dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect); // ptrToMemRef will be an input to main computation graph function. staticInputs.emplace_back(ptrToMemRef); @@ -273,14 +273,14 @@ public: assert(numOutputs == 1 && "only support 1 output tensor now."); // Call static entry point with the memref ptrs created, and get output. - auto outputMemRefs = rewriter.create( - loc, staticEntryPointTy.getFunctionResultType(), + auto outputMemRefs = rewriter.create(loc, + staticEntryPointTy.getFunctionResultType(), rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName), staticInputs); // Create wrapped output. - auto wrappedOutput = callApi(rewriter, loc, apiRegistry, - API::CREATE_ORDERED_DYN_MEM_REF_DICT, {}); + auto wrappedOutput = callApi( + rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {}); // Get the first memref returned, convert to a dynamic memref and store // it in the wrapped Output. 
@@ -290,17 +290,17 @@ public: auto outMemRefRankVal = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); auto outDynMemRef = callApi(rewriter, loc, apiRegistry, - API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); - fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry, - llvmDialect); + API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); + fillDynMemRefWithMemRef( + outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect); auto zero = rewriter.create( loc, int32Ty, rewriter.getI32IntegerAttr(0)); callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, - {wrappedOutput, zero, outDynMemRef}); + {wrappedOutput, zero, outDynMemRef}); // Return wrapped output. - rewriter.create(loc, - SmallVector({wrappedOutput})); + rewriter.create( + loc, SmallVector({wrappedOutput})); return matchSuccess(); } @@ -308,7 +308,7 @@ private: using ApiRegistry = std::map; ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter, - LLVM::LLVMDialect *llvmDialect) const { + LLVM::LLVMDialect *llvmDialect) const { using LLVMType = LLVM::LLVMType; auto voidTy = LLVMType::getVoidTy(llvmDialect); auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); @@ -335,8 +335,8 @@ private: // identities to a symbol reference to the API function. ApiRegistry registry; for (auto &apiSpec : apiSpecs) { - apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module, - apiSpec.funcTy(), rewriter); + apiSpec.symbolRef = getOrInsertExternFunc( + apiSpec.name, module, apiSpec.funcTy(), rewriter); registry.emplace(apiSpec.id, apiSpec); } @@ -346,10 +346,10 @@ private: // Call a registered API, return the return SSA values if only one result is // returned, otherwise return nullptr. 
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, - API apiId, ArrayRef params) const { - auto returnVals = rewriter.create( - loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, - ArrayRef(params)); + API apiId, ArrayRef params) const { + auto returnVals = + rewriter.create(loc, registry.at(apiId).outputTy, + registry.at(apiId).symbolRef, ArrayRef(params)); if (returnVals.getNumResults() == 1) return returnVals.getResult(0); return nullptr; @@ -358,7 +358,7 @@ private: // Helper function to insert an entry block to LLVM function. // (TODO): upstream this to MLIR. Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType, - LLVM::LLVMFuncOp &dynamicEntryPointFunc) const { + LLVM::LLVMFuncOp &dynamicEntryPointFunc) const { // Add entry block: auto *entryPointEntryBlock = new Block(); dynamicEntryPointFunc.push_back(entryPointEntryBlock); @@ -370,10 +370,9 @@ private: } void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef, - PatternRewriter &rewriter, - const Location &loc, - const std::map &apiRegistry, - LLVM::LLVMDialect *llvmDialect) const { + PatternRewriter &rewriter, const Location &loc, + const std::map &apiRegistry, + LLVM::LLVMDialect *llvmDialect) const { auto memRefPtrTy = ptrToMemRef.getType().dyn_cast(); auto memRefTy = memRefPtrTy.getPointerElementTy(); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); @@ -385,18 +384,15 @@ private: callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef}); dataPtr = rewriter.create( loc, memRefTy.getStructElementType(0), dataPtr); - memRef = rewriter.create( - loc, memRefTy, memRef, dataPtr, - rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)})); - memRef = rewriter.create( - loc, memRefTy, memRef, dataPtr, - rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)})); + memRef = rewriter.create(loc, memRefTy, memRef, + dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)})); + memRef = rewriter.create(loc, memRefTy, memRef, + dataPtr, 
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)})); // Use zero offset now. auto zero = rewriter.create( loc, int64Ty, rewriter.getI64IntegerAttr(0)); - memRef = rewriter.create( - loc, memRefTy, memRef, zero, + memRef = rewriter.create(loc, memRefTy, memRef, zero, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)})); // Get rank, sizes array ptr and strides array ptr. @@ -411,24 +407,22 @@ private: loc, int64Ty, rewriter.getI64IntegerAttr(i)); // Insert size of the dimension. - auto dimSizePtr = rewriter.create( - loc, int64Ty.getPointerTo(), sizesArrayPtr, - ArrayRef({dimIdx})); - auto dimSize = rewriter.create(loc, int64Ty.getPointerTo(), - dimSizePtr); - memRef = rewriter.create( - loc, memRefTy, memRef, dimSize, + auto dimSizePtr = rewriter.create(loc, + int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef({dimIdx})); + auto dimSize = rewriter.create( + loc, int64Ty.getPointerTo(), dimSizePtr); + memRef = rewriter.create(loc, memRefTy, memRef, + dimSize, rewriter.getArrayAttr( {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); // Insert stride of the dimension. 
- auto dimStridePtr = rewriter.create( - loc, int64Ty.getPointerTo(), sizesArrayPtr, - ArrayRef({dimIdx})); + auto dimStridePtr = rewriter.create(loc, + int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef({dimIdx})); auto dimStride = rewriter.create( loc, int64Ty.getPointerTo(), dimStridePtr); - memRef = rewriter.create( - loc, memRefTy, memRef, dimStride, + memRef = rewriter.create(loc, memRefTy, memRef, + dimStride, rewriter.getArrayAttr( {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); } @@ -437,20 +431,20 @@ private: } void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, - PatternRewriter &rewriter, const Location &loc, - const std::map &apiRegistry, - LLVM::LLVMDialect *llvmDialect) const { + PatternRewriter &rewriter, const Location &loc, + const std::map &apiRegistry, + LLVM::LLVMDialect *llvmDialect) const { auto outMemRefTy = outMemRef.getType().dyn_cast(); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); // Extract the data pointer, and record it in dynamic mem ref created. - Value outMemRefDataPtr = rewriter.create( - loc, outMemRefTy.getStructElementType(0), outMemRef, + Value outMemRefDataPtr = rewriter.create(loc, + outMemRefTy.getStructElementType(0), outMemRef, rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); outMemRefDataPtr = rewriter.create( loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); callApi(rewriter, loc, apiRegistry, API::SET_DATA, - {outDynMemRef, outMemRefDataPtr}); + {outDynMemRef, outMemRefDataPtr}); auto rank = getRankFromMemRefType(outMemRefTy); auto sizesArrayPtr = @@ -463,23 +457,21 @@ private: loc, int64Ty, rewriter.getI64IntegerAttr(i)); // Transfer size of dimension from memref to dynamic memref. 
- auto dimSize = rewriter.create( - loc, int64Ty, outMemRef, + auto dimSize = rewriter.create(loc, int64Ty, + outMemRef, rewriter.getArrayAttr( {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); - auto dimSizePtr = rewriter.create( - loc, int64Ty.getPointerTo(), sizesArrayPtr, - ArrayRef({dimIdx})); + auto dimSizePtr = rewriter.create(loc, + int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef({dimIdx})); rewriter.create(loc, dimSize, dimSizePtr); // Transfer stride of dimension from memref to dynamic memref. - auto dimStride = rewriter.create( - loc, int64Ty, outMemRef, + auto dimStride = rewriter.create(loc, int64Ty, + outMemRef, rewriter.getArrayAttr( {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); - auto dimStridePtr = rewriter.create( - loc, int64Ty.getPointerTo(), stridesArrayPtr, - ArrayRef({dimIdx})); + auto dimStridePtr = rewriter.create(loc, + int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef({dimIdx})); rewriter.create(loc, dimStride, dimStridePtr); } } @@ -511,8 +503,8 @@ void KrnlToLLVMLoweringPass::runOnModule() { populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(typeConverter, patterns, - /*useAlloca=*/false, - /*emitCWrapper=*/true); + /*useAlloca=*/false, + /*emitCWrapper=*/true); // Lower from the `krnl` dialect i.e. the Reshape operation. 
patterns.insert( @@ -530,5 +522,5 @@ std::unique_ptr mlir::createKrnlLowerToLLVMPass() { return std::make_unique(); } -static PassRegistration - pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); +static PassRegistration pass( + "lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/ONNXRewrite.cpp index afb4030..cde2cb3 100644 --- a/src/Transform/ONNX/ONNXRewrite.cpp +++ b/src/Transform/ONNX/ONNXRewrite.cpp @@ -22,7 +22,7 @@ namespace { bool hasNonZeroInArrayAttr(ArrayAttr attrs) { bool allZeros = true; if (attrs) { - for (auto attr: attrs.getValue()) { + for (auto attr : attrs.getValue()) { if (attr.cast().getInt() > 0) { allZeros = false; break; @@ -54,7 +54,7 @@ ArrayAttr createArrayAttrOfZeros( // This function is used for padding attribute in MaxPoolSingleOut. ArrayAttr insertZerosForNonPaddedDims( PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) { - int nDims = (int) origAttrs.getValue().size() / 2; + int nDims = (int)origAttrs.getValue().size() / 2; int nElements = (nDims + extensionLength) * 2; SmallVector pads(nElements, 0); for (int i = 0; i < nDims; ++i) {