clean operands using names provided by operandAdaptors (#56)

* clean operands using names provided by operandAdaptors

* reverted changes that were erronerous, from Tung's comment

* clang format issues
This commit is contained in:
Alexandre Eichenberger 2020-03-31 11:55:27 -04:00 committed by GitHub
parent 844dcd8b1f
commit b422116f12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 175 additions and 179 deletions

View File

@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
bool hasBias = !op->getOperand(2).getType().isa<NoneType>(); bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
Value A, B, C; Value A, B, C;
A = operands[0]; ONNXGemmOpOperandAdaptor operandAdaptor(operands);
B = operands[1]; A = operandAdaptor.A();
B = operandAdaptor.B();
if (hasBias) if (hasBias)
C = operands[2]; C = operandAdaptor.C();
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());

View File

@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
Value A = operands[0]; ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
Value B = operands[1]; Value A = operandAdaptor.A();
Value B = operandAdaptor.B();
auto AShape = A.getType().cast<MemRefType>().getShape(); auto AShape = A.getType().cast<MemRefType>().getShape();
auto BShape = B.getType().cast<MemRefType>().getShape(); auto BShape = B.getType().cast<MemRefType>().getShape();

View File

@ -15,9 +15,8 @@ using namespace mlir;
struct ONNXSoftmaxOpLowering : public ConversionPattern { struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx) ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
// softmax(x) = let max_x = max(x) in // softmax(x) = let max_x = max(x) in
// let exp_x = exp(x - max_x) in // let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in // let sum = sum(exp_x) in
@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
assert(axis >= -rank && axis <= rank - 1); assert(axis >= -rank && axis <= rank - 1);
auto loc = op->getLoc(); auto loc = op->getLoc();
ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands);
Value input = operandAdaptor.input();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto elementType = memRefType.getElementType(); auto elementType = memRefType.getElementType();
@ -38,8 +38,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
operands[0]); memRefType, loc, rewriter, insertDealloc, input);
// Shape of the result // Shape of the result
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
@ -49,15 +49,14 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero = emitConstantOp(rewriter, loc, elementType, 0); Value zero = emitConstantOp(rewriter, loc, elementType, 0);
Value negInfinity = rewriter.create<ConstantOp>( Value negInfinity = rewriter.create<ConstantOp>(loc,
loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity())); FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
// Define loops. // Define loops.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock =
optimizedLoops, rank); defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
// Coerce the input into a 2-D tensor. `axis` will be the coercing point. // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
// This coercing follows the softmax definition in ONNX: // This coercing follows the softmax definition in ONNX:
@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
} }
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < axis; ++i) for (int i = 0; i < axis; ++i)
addDimensionToPack(rewriter, loc, outerPack, operands[0], i); addDimensionToPack(rewriter, loc, outerPack, input, i);
// Define an inner loop with respect to axis. // Define an inner loop with respect to axis.
std::vector<Value> innerLoops, optimizedInnerLoops; std::vector<Value> innerLoops, optimizedInnerLoops;
@ -87,7 +86,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
} }
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
for (int i = axis; i < rank; ++i) for (int i = axis; i < rank; ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[0], i); addDimensionToPack(rewriter, loc, innerPack, input, i);
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
SmallVector<Value, 4> outerLoopIVs; SmallVector<Value, 4> outerLoopIVs;
@ -144,7 +143,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// Compute the max value. // Compute the max value.
Value max = rewriter.create<LoadOp>(loc, maxOp); Value max = rewriter.create<LoadOp>(loc, maxOp);
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs); Value nextMax = rewriter.create<LoadOp>(loc, input, maxLoopIVs);
auto maxCond = auto maxCond =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax); rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax); max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// Sum up values. // Sum up values.
Value sum = rewriter.create<LoadOp>(loc, sumOp); Value sum = rewriter.create<LoadOp>(loc, sumOp);
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs); Value next = rewriter.create<LoadOp>(loc, input, sumLoopIVs);
Value sub = rewriter.create<SubFOp>(loc, next, max); Value sub = rewriter.create<SubFOp>(loc, next, max);
Value exp = rewriter.create<ExpOp>(loc, sub); Value exp = rewriter.create<ExpOp>(loc, sub);
sum = rewriter.create<AddFOp>(loc, sum, exp); sum = rewriter.create<AddFOp>(loc, sum, exp);

View File

@ -26,12 +26,6 @@ struct ONNXConvOpLowering : public ConversionPattern {
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op); ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();
auto inputOperand = operandAdaptor.X(); auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
@ -40,6 +34,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
auto biasOperand = operandAdaptor.B(); auto biasOperand = operandAdaptor.B();
bool hasBias = !biasOperand.getType().isa<NoneType>(); bool hasBias = !biasOperand.getType().isa<NoneType>();
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {inputOperand});
// R = Conv(D, K) // R = Conv(D, K)
// //
// The input/output shapes will look like this: // The input/output shapes will look like this:

View File

@ -1,4 +1,4 @@
//===----------- Normalization.cpp - Lowering Normalization Ops ------------===// //===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
// //
// Copyright 2019 The IBM Research Authors. // Copyright 2019 The IBM Research Authors.
// //
@ -18,24 +18,24 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1, mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
ctx) {} ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter & rewriter) const final { ConversionPatternRewriter &rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) = // batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias // scale * (x - mean) / sqrt(variance + epsilon) + bias
ONNXBatchNormalizationTestModeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto epsilonAttr = auto epsilonAttr = FloatAttr::get(memRefType.getElementType(),
FloatAttr::get(memRefType.getElementType(), llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op) .epsilon()
.epsilon() .convertToFloat());
.convertToFloat());
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr); auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
auto operand = operands[0]; auto operand = operandAdaptor.X();
auto scale = operands[1]; auto scale = operandAdaptor.scale();
auto bias = operands[2]; auto bias = operandAdaptor.B();
auto mean = operands[3]; auto mean = operandAdaptor.mean();
auto variance = operands[4]; auto variance = operandAdaptor.var();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
Value alloc; Value alloc;
@ -44,8 +44,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
{operand}); memRefType, loc, rewriter, insertDealloc, {operand});
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N. // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1. // In case of N, C is assumed to be 1.
@ -67,8 +67,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
SmallVector<Value, 1> loopCIVs; SmallVector<Value, 1> loopCIVs;
if (rank > 1) { if (rank > 1) {
KrnlIterateOperandPack cPack(rewriter, originalLoops[1], KrnlIterateOperandPack cPack(
optimizedLoops[1]); rewriter, originalLoops[1], optimizedLoops[1]);
addDimensionToPack(rewriter, loc, cPack, operand, 1); addDimensionToPack(rewriter, loc, cPack, operand, 1);
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack); auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
Block &cIterationBlock = cIterateOp.bodyRegion().front(); Block &cIterationBlock = cIterateOp.bodyRegion().front();
@ -76,7 +76,7 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
for (auto arg : cIterationBlock.getArguments()) for (auto arg : cIterationBlock.getArguments())
loopCIVs.emplace_back(arg); loopCIVs.emplace_back(arg);
} else { } else {
loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0)); loopCIVs.emplace_back(rewriter.create<ConstantIndexOp>(loc, 0));
} }
auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs); auto scaleVal = rewriter.create<LoadOp>(loc, scale, loopCIVs);

View File

@ -38,6 +38,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
// Match // Match
@ -71,7 +72,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
dilations.emplace_back(dilation.cast<IntegerAttr>().getInt()); dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());
// Type information about the input and result of this operation. // Type information about the input and result of this operation.
auto &inputOperand = operands[0]; auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto resultShape = memRefType.getShape(); auto resultShape = memRefType.getShape();

View File

@ -16,10 +16,10 @@ struct ONNXIdentityOpLowering : public ConversionPattern {
ONNXIdentityOpLowering(MLIRContext *ctx) ONNXIdentityOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final { ONNXIdentityOpOperandAdaptor operandAdaptor(operands);
rewriter.replaceOp(op, operands[0]); rewriter.replaceOp(op, operandAdaptor.input());
return matchSuccess(); return matchSuccess();
} }
}; };

View File

@ -14,13 +14,13 @@ using namespace mlir;
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern { struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx) ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(), : ConversionPattern(
1, ctx) {} mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()); auto tensorType = (*op->result_type_begin());
ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
// Only constant padding is supported now. // Only constant padding is supported now.
@ -55,7 +55,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
BuildKrnlLoop valueLoops(rewriter, loc, rank); BuildKrnlLoop valueLoops(rewriter, loc, rank);
valueLoops.createDefineAndOptimizeOp(); valueLoops.createDefineAndOptimizeOp();
for (int i = 0; i < rank; ++i) for (int i = 0; i < rank; ++i)
valueLoops.pushBounds(0, operands[0], i); valueLoops.pushBounds(0, operandAdaptor.data(), i);
valueLoops.createIterateOp(); valueLoops.createIterateOp();
// Copy the input data into the output. // Copy the input data into the output.
@ -67,7 +67,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
auto pads = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).pads(); auto pads = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).pads();
SmallVector<int64_t, 4> pad_begin; SmallVector<int64_t, 4> pad_begin;
for (int i = 0; i < pads.size()/2; ++i) { for (int i = 0; i < pads.size() / 2; ++i) {
pad_begin.emplace_back(pads.getValue()[i].cast<IntegerAttr>().getInt()); pad_begin.emplace_back(pads.getValue()[i].cast<IntegerAttr>().getInt());
} }
@ -77,14 +77,14 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
if (pad_begin[i] == 0) { if (pad_begin[i] == 0) {
outLoopIVs.emplace_back(valueLoops.getInductionVar(i)); outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
} else { } else {
auto outIV = rewriter.create<AddIOp>( auto outIV = rewriter.create<AddIOp>(loc,
loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]), rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
valueLoops.getInductionVar(i)); valueLoops.getInductionVar(i));
outLoopIVs.emplace_back(outIV); outLoopIVs.emplace_back(outIV);
} }
} }
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs); auto inVal = rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs); rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
rewriter.setInsertionPointToStart(padLoops.getIterateBlock()); rewriter.setInsertionPointToStart(padLoops.getIterateBlock());

View File

@ -16,11 +16,12 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx) ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final { ONNXReshapeOpOperandAdaptor operandAdaptor(operands);
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto loc = op->getLoc(); auto loc = op->getLoc();
Value data = operandAdaptor.data();
auto inputShape = data.getType().cast<MemRefType>().getShape();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
@ -33,7 +34,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
for (int i = 0; i < inputShape.size(); ++i) { for (int i = 0; i < inputShape.size(); ++i) {
Value dimVal; Value dimVal;
if (inputShape[i] < 0) { if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i); Value dim = rewriter.create<DimOp>(loc, data, i);
dimVal = dimVal =
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64)); rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
} else { } else {
@ -61,8 +62,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)); rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
SmallVector<Value, 4> DimInfo; SmallVector<Value, 4> DimInfo;
for (int i = 0; i < memRefShape.size(); ++i) { for (int i = 0; i < memRefShape.size(); ++i) {
Value index = Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
// Load index from array of indices. // Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the // If a dimension is zero, the actual dimension value is taken from the
@ -75,7 +75,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
Value dimVal; Value dimVal;
auto loadedValType = loadedVal.getType().cast<IntegerType>(); auto loadedValType = loadedVal.getType().cast<IntegerType>();
if (inputShape[i] < 0) { if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i); Value dim = rewriter.create<DimOp>(loc, data, i);
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType); dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
} else { } else {
dimVal = dimVal =
@ -136,7 +136,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
alloc = allocateMemref; alloc = allocateMemref;
} }
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize); rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);
return matchSuccess(); return matchSuccess();

View File

@ -16,20 +16,21 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
ONNXTransposeOpLowering(MLIRContext *ctx) ONNXTransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final { ONNXTransposeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
Value data = operandAdaptor.data();
if (hasAllConstantDimensions(memRefType)) if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, alloc = insertAllocAndDealloc(
{operands[0]}); memRefType, loc, rewriter, insertDealloc, {data});
// Number of loops // Number of loops
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
@ -38,13 +39,13 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
// Define loops. // Define loops.
std::vector<Value> originalLoops; std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops; std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops, Block *optimizationBlock =
optimizedLoops, rank); defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops); KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest using the input shape. // Iterate over the loop nest using the input shape.
for (int i = 0; i < rank; ++i) for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operands[0], i); addDimensionToPack(rewriter, loc, pack, data, i);
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack); auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front(); Block &iterationBlock = iterateOp.bodyRegion().front();
@ -75,7 +76,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
// the default case). This means that perm was added by shape // the default case). This means that perm was added by shape
// inference or another pass to contain the values corresponding // inference or another pass to contain the values corresponding
// to the default behavior of Transpose. // to the default behavior of Transpose.
for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--) for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
perm.emplace_back(i); perm.emplace_back(i);
} }
@ -84,10 +85,10 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
inLoopIVs.emplace_back(arg); inLoopIVs.emplace_back(arg);
SmallVector<Value, 4> outLoopIVs; SmallVector<Value, 4> outLoopIVs;
for (int i=0; i<iterationBlock.getArguments().size(); ++i) for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]); outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs); auto inVal = rewriter.create<LoadOp>(loc, data, inLoopIVs);
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs); rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);

View File

@ -16,12 +16,13 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx) ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {} : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const final {
ConversionPatternRewriter &rewriter) const final { ONNXUnsqueezeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin()); auto memRefType = convertToMemRefType(*op->result_type_begin());
int outRank = memRefType.getRank(); int outRank = memRefType.getRank();
Value data = operandAdaptor.data();
// Assume that `axes` has been validated by shape inference. // Assume that `axes` has been validated by shape inference.
// So, here we just get it. // So, here we just get it.
@ -55,7 +56,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) { for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
Value dimVal = nullptr; Value dimVal = nullptr;
if (memRefShape[outIdx] < 0) { if (memRefShape[outIdx] < 0) {
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx); Value index = rewriter.create<DimOp>(loc, data, inIdx);
dimVal = rewriter.create<IndexCastOp>( dimVal = rewriter.create<IndexCastOp>(
loc, index, rewriter.getIntegerType(64)); loc, index, rewriter.getIntegerType(64));
allocOperands.emplace_back(index); allocOperands.emplace_back(index);
@ -74,7 +75,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
dealloc.getOperation()->moveBefore(&parentBlock->back()); dealloc.getOperation()->moveBefore(&parentBlock->back());
} }
} }
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize); rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
rewriter.replaceOp(op, alloc); rewriter.replaceOp(op, alloc);
return matchSuccess(); return matchSuccess();
} }

View File

@ -28,9 +28,7 @@ using namespace mlir;
namespace { namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName, static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module, ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
mlir::LLVM::LLVMType funcType,
PatternRewriter &rewriter) {
auto *context = module.getContext(); auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) { if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context); auto symbolRef = SymbolRefAttr::get(funcName, context);
@ -71,10 +69,10 @@ public:
explicit KrnlMemcpyOpLowering(MLIRContext *context) explicit KrnlMemcpyOpLowering(MLIRContext *context)
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {} : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands, ConversionPatternRewriter &rewriter) const override {
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext(); auto *context = op->getContext();
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc(); auto loc = op->getLoc();
auto *llvmDialect = auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
@ -85,35 +83,39 @@ public:
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect); auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
// First operand. // First operand.
Type dstType = Type dstType = operandAdaptor.dest()
operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1); .getType()
.cast<LLVM::LLVMType>()
.getStructElementType(1);
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1)); loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand. // Second operand.
Type srcType = Type srcType = operandAdaptor.src()
operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1); .getType()
.cast<LLVM::LLVMType>()
.getStructElementType(1);
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>( Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1)); loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>( Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size. // Size.
Value int64Size = rewriter.create<LLVM::SExtOp>( Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]); loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
// Is volatile (set to false). // Is volatile (set to false).
Value isVolatile = rewriter.create<LLVM::ConstantOp>( Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
loc, LLVM::LLVMType::getInt1Ty(llvmDialect), LLVM::LLVMType::getInt1Ty(llvmDialect),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// Memcpy call // Memcpy call
rewriter.create<CallOp>( rewriter.create<CallOp>(loc, memcpyRef,
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect), LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
int64Size, isVolatile})); int64Size, isVolatile}));
rewriter.eraseOp(op); rewriter.eraseOp(op);
return matchSuccess(); return matchSuccess();
@ -123,8 +125,7 @@ private:
/// Return a symbol reference to the memcpy function, inserting it into the /// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary. /// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter, static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
ModuleOp module, ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext(); auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64")) if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
@ -134,8 +135,7 @@ private:
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect); auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy( auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>( ArrayRef<mlir::LLVM::LLVMType>(
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}), {llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
false); false);
@ -143,8 +143,8 @@ private:
// Insert the memcpy function into the body of the parent module. // Insert the memcpy function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter); PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody()); rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), rewriter.create<LLVM::LLVMFuncOp>(
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType); module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context); return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
} }
}; };
@ -176,18 +176,18 @@ public:
SmallVector<LLVM::LLVMType, 4> inputTys; SmallVector<LLVM::LLVMType, 4> inputTys;
ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy, ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
ArrayRef<LLVM::LLVMType> inputTys) ArrayRef<LLVM::LLVMType> inputTys)
: id(id), name(name), outputTy(outputTy), : id(id), name(name), outputTy(outputTy),
inputTys(inputTys.begin(), inputTys.end()) {} inputTys(inputTys.begin(), inputTys.end()) {}
LLVM::LLVMType funcTy() { LLVM::LLVMType funcTy() {
return LLVM::LLVMType::getFunctionTy(outputTy, inputTys, return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
/*isVarArg=*/false); /*isVarArg=*/false);
} }
}; };
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op, PatternMatchResult matchAndRewrite(
PatternRewriter &rewriter) const override { KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
auto *llvmDialect = auto *llvmDialect =
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>(); op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
@ -248,7 +248,7 @@ public:
auto idxVal = rewriter.create<LLVM::ConstantOp>( auto idxVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(i)); loc, int32Ty, rewriter.getI32IntegerAttr(i));
auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF, auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
{wrappedInput, idxVal}); {wrappedInput, idxVal});
// Create a (static) memref type corresponding to the i-th memref input to // Create a (static) memref type corresponding to the i-th memref input to
// the inference function on stack, and load it to memRef. // the inference function on stack, and load it to memRef.
@ -257,12 +257,12 @@ public:
auto one = rewriter.create<LLVM::ConstantOp>( auto one = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(1)); loc, int32Ty, rewriter.getI32IntegerAttr(1));
Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one, Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
/*alignment=*/0); /*alignment=*/0);
// Fill in the memref underlying ptrToMemRef with information extracted // Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef. // from dynMemRef.
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc, fillPtrToMemRefWithDynMemRef(
apiRegistry, llvmDialect); dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function. // ptrToMemRef will be an input to main computation graph function.
staticInputs.emplace_back(ptrToMemRef); staticInputs.emplace_back(ptrToMemRef);
@ -273,14 +273,14 @@ public:
assert(numOutputs == 1 && "only support 1 output tensor now."); assert(numOutputs == 1 && "only support 1 output tensor now.");
// Call static entry point with the memref ptrs created, and get output. // Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>( auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
loc, staticEntryPointTy.getFunctionResultType(), staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName), rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
staticInputs); staticInputs);
// Create wrapped output. // Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry, auto wrappedOutput = callApi(
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {}); rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
// Get the first memref returned, convert to a dynamic memref and store // Get the first memref returned, convert to a dynamic memref and store
// it in the wrapped Output. // it in the wrapped Output.
@ -290,17 +290,17 @@ public:
auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>( auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank)); loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry, auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal}); API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry, fillDynMemRefWithMemRef(
llvmDialect); outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>( auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0)); loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF, callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
{wrappedOutput, zero, outDynMemRef}); {wrappedOutput, zero, outDynMemRef});
// Return wrapped output. // Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc, rewriter.create<LLVM::ReturnOp>(
SmallVector<Value, 1>({wrappedOutput})); loc, SmallVector<Value, 1>({wrappedOutput}));
return matchSuccess(); return matchSuccess();
} }
@ -308,7 +308,7 @@ private:
using ApiRegistry = std::map<API, ApiSpec>; using ApiRegistry = std::map<API, ApiSpec>;
ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter, ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
LLVM::LLVMDialect *llvmDialect) const { LLVM::LLVMDialect *llvmDialect) const {
using LLVMType = LLVM::LLVMType; using LLVMType = LLVM::LLVMType;
auto voidTy = LLVMType::getVoidTy(llvmDialect); auto voidTy = LLVMType::getVoidTy(llvmDialect);
auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect); auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
@ -335,8 +335,8 @@ private:
// identities to a symbol reference to the API function. // identities to a symbol reference to the API function.
ApiRegistry registry; ApiRegistry registry;
for (auto &apiSpec : apiSpecs) { for (auto &apiSpec : apiSpecs) {
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module, apiSpec.symbolRef = getOrInsertExternFunc(
apiSpec.funcTy(), rewriter); apiSpec.name, module, apiSpec.funcTy(), rewriter);
registry.emplace(apiSpec.id, apiSpec); registry.emplace(apiSpec.id, apiSpec);
} }
@ -346,10 +346,10 @@ private:
// Call a registered API, return the return SSA values if only one result is // Call a registered API, return the return SSA values if only one result is
// returned, otherwise return nullptr. // returned, otherwise return nullptr.
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry, Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value> params) const { API apiId, ArrayRef<Value> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>( auto returnVals =
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef, rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
ArrayRef<Value>(params)); registry.at(apiId).symbolRef, ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1) if (returnVals.getNumResults() == 1)
return returnVals.getResult(0); return returnVals.getResult(0);
return nullptr; return nullptr;
@ -358,7 +358,7 @@ private:
// Helper function to insert an entry block to LLVM function. // Helper function to insert an entry block to LLVM function.
// (TODO): upstream this to MLIR. // (TODO): upstream this to MLIR.
Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType, Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
LLVM::LLVMFuncOp &dynamicEntryPointFunc) const { LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
// Add entry block: // Add entry block:
auto *entryPointEntryBlock = new Block(); auto *entryPointEntryBlock = new Block();
dynamicEntryPointFunc.push_back(entryPointEntryBlock); dynamicEntryPointFunc.push_back(entryPointEntryBlock);
@ -370,10 +370,9 @@ private:
} }
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef, void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter, PatternRewriter &rewriter, const Location &loc,
const Location &loc, const std::map<API, ApiSpec> &apiRegistry,
const std::map<API, ApiSpec> &apiRegistry, LLVM::LLVMDialect *llvmDialect) const {
LLVM::LLVMDialect *llvmDialect) const {
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>(); auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto memRefTy = memRefPtrTy.getPointerElementTy(); auto memRefTy = memRefPtrTy.getPointerElementTy();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
@ -385,18 +384,15 @@ private:
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef}); callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>( dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr); loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
loc, memRefTy, memRef, dataPtr, dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)})); memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
memRef = rewriter.create<LLVM::InsertValueOp>( dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
// Use zero offset now. // Use zero offset now.
auto zero = rewriter.create<LLVM::ConstantOp>( auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(0)); loc, int64Ty, rewriter.getI64IntegerAttr(0));
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef, zero,
loc, memRefTy, memRef, zero,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)})); rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
// Get rank, sizes array ptr and strides array ptr. // Get rank, sizes array ptr and strides array ptr.
@ -411,24 +407,22 @@ private:
loc, int64Ty, rewriter.getI64IntegerAttr(i)); loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Insert size of the dimension. // Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>( auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
loc, int64Ty.getPointerTo(), sizesArrayPtr, int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
ArrayRef<Value>({dimIdx})); auto dimSize = rewriter.create<LLVM::LoadOp>(
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(), loc, int64Ty.getPointerTo(), dimSizePtr);
dimSizePtr); memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
memRef = rewriter.create<LLVM::InsertValueOp>( dimSize,
loc, memRefTy, memRef, dimSize,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
// Insert stride of the dimension. // Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>( auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
loc, int64Ty.getPointerTo(), sizesArrayPtr, int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
ArrayRef<Value>({dimIdx}));
auto dimStride = rewriter.create<LLVM::LoadOp>( auto dimStride = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimStridePtr); loc, int64Ty.getPointerTo(), dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>( memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
loc, memRefTy, memRef, dimStride, dimStride,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
} }
@ -437,20 +431,20 @@ private:
} }
void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef, void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
PatternRewriter &rewriter, const Location &loc, PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry, const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const { LLVM::LLVMDialect *llvmDialect) const {
auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>(); auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect); auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created. // Extract the data pointer, and record it in dynamic mem ref created.
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>( Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
loc, outMemRefTy.getStructElementType(0), outMemRef, outMemRefTy.getStructElementType(0), outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)})); rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>( outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr); loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
callApi(rewriter, loc, apiRegistry, API::SET_DATA, callApi(rewriter, loc, apiRegistry, API::SET_DATA,
{outDynMemRef, outMemRefDataPtr}); {outDynMemRef, outMemRefDataPtr});
auto rank = getRankFromMemRefType(outMemRefTy); auto rank = getRankFromMemRefType(outMemRefTy);
auto sizesArrayPtr = auto sizesArrayPtr =
@ -463,23 +457,21 @@ private:
loc, int64Ty, rewriter.getI64IntegerAttr(i)); loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Transfer size of dimension from memref to dynamic memref. // Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>( auto dimSize = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
loc, int64Ty, outMemRef, outMemRef,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>( auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
loc, int64Ty.getPointerTo(), sizesArrayPtr, int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr); rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref. // Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>( auto dimStride = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
loc, int64Ty, outMemRef, outMemRef,
rewriter.getArrayAttr( rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)})); {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>( auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
loc, int64Ty.getPointerTo(), stridesArrayPtr, int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr); rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
} }
} }
@ -511,8 +503,8 @@ void KrnlToLLVMLoweringPass::runOnModule() {
populateAffineToStdConversionPatterns(patterns, &getContext()); populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns, populateStdToLLVMConversionPatterns(typeConverter, patterns,
/*useAlloca=*/false, /*useAlloca=*/false,
/*emitCWrapper=*/true); /*emitCWrapper=*/true);
// Lower from the `krnl` dialect i.e. the Reshape operation. // Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>( patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering>(
@ -530,5 +522,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>(); return std::make_unique<KrnlToLLVMLoweringPass>();
} }
static PassRegistration<KrnlToLLVMLoweringPass> static PassRegistration<KrnlToLLVMLoweringPass> pass(
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM."); "lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");

View File

@ -22,7 +22,7 @@ namespace {
bool hasNonZeroInArrayAttr(ArrayAttr attrs) { bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
bool allZeros = true; bool allZeros = true;
if (attrs) { if (attrs) {
for (auto attr: attrs.getValue()) { for (auto attr : attrs.getValue()) {
if (attr.cast<IntegerAttr>().getInt() > 0) { if (attr.cast<IntegerAttr>().getInt() > 0) {
allZeros = false; allZeros = false;
break; break;
@ -54,7 +54,7 @@ ArrayAttr createArrayAttrOfZeros(
// This function is used for padding attribute in MaxPoolSingleOut. // This function is used for padding attribute in MaxPoolSingleOut.
ArrayAttr insertZerosForNonPaddedDims( ArrayAttr insertZerosForNonPaddedDims(
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) { PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
int nDims = (int) origAttrs.getValue().size() / 2; int nDims = (int)origAttrs.getValue().size() / 2;
int nElements = (nDims + extensionLength) * 2; int nElements = (nDims + extensionLength) * 2;
SmallVector<int64_t, 4> pads(nElements, 0); SmallVector<int64_t, 4> pads(nElements, 0);
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {