clean operands using names provided by operandAdaptors (#56)
* clean operands using names provided by operandAdaptors * reverted changes that were erronerous, from Tung's comment * clang format issues
This commit is contained in:
parent
844dcd8b1f
commit
b422116f12
|
@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
||||||
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
|
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
|
||||||
|
|
||||||
Value A, B, C;
|
Value A, B, C;
|
||||||
A = operands[0];
|
ONNXGemmOpOperandAdaptor operandAdaptor(operands);
|
||||||
B = operands[1];
|
A = operandAdaptor.A();
|
||||||
|
B = operandAdaptor.B();
|
||||||
if (hasBias)
|
if (hasBias)
|
||||||
C = operands[2];
|
C = operandAdaptor.C();
|
||||||
|
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
Value A = operands[0];
|
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
|
||||||
Value B = operands[1];
|
Value A = operandAdaptor.A();
|
||||||
|
Value B = operandAdaptor.B();
|
||||||
auto AShape = A.getType().cast<MemRefType>().getShape();
|
auto AShape = A.getType().cast<MemRefType>().getShape();
|
||||||
auto BShape = B.getType().cast<MemRefType>().getShape();
|
auto BShape = B.getType().cast<MemRefType>().getShape();
|
||||||
|
|
||||||
|
|
|
@ -15,8 +15,7 @@ using namespace mlir;
|
||||||
struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
ONNXSoftmaxOpLowering(MLIRContext *ctx)
|
ONNXSoftmaxOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// softmax(x) = let max_x = max(x) in
|
// softmax(x) = let max_x = max(x) in
|
||||||
// let exp_x = exp(x - max_x) in
|
// let exp_x = exp(x - max_x) in
|
||||||
|
@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
assert(axis >= -rank && axis <= rank - 1);
|
assert(axis >= -rank && axis <= rank - 1);
|
||||||
|
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands);
|
||||||
|
Value input = operandAdaptor.input();
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto elementType = memRefType.getElementType();
|
auto elementType = memRefType.getElementType();
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
operands[0]);
|
memRefType, loc, rewriter, insertDealloc, input);
|
||||||
|
|
||||||
// Shape of the result
|
// Shape of the result
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
@ -49,15 +49,14 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||||
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||||
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
|
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||||
Value negInfinity = rewriter.create<ConstantOp>(
|
Value negInfinity = rewriter.create<ConstantOp>(loc,
|
||||||
loc,
|
|
||||||
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
|
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
|
||||||
|
|
||||||
// Define loops.
|
// Define loops.
|
||||||
std::vector<Value> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
std::vector<Value> optimizedLoops;
|
std::vector<Value> optimizedLoops;
|
||||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
Block *optimizationBlock =
|
||||||
optimizedLoops, rank);
|
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
|
||||||
|
|
||||||
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
|
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
|
||||||
// This coercing follows the softmax definition in ONNX:
|
// This coercing follows the softmax definition in ONNX:
|
||||||
|
@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
|
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
|
||||||
for (int i = 0; i < axis; ++i)
|
for (int i = 0; i < axis; ++i)
|
||||||
addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
|
addDimensionToPack(rewriter, loc, outerPack, input, i);
|
||||||
|
|
||||||
// Define an inner loop with respect to axis.
|
// Define an inner loop with respect to axis.
|
||||||
std::vector<Value> innerLoops, optimizedInnerLoops;
|
std::vector<Value> innerLoops, optimizedInnerLoops;
|
||||||
|
@ -87,7 +86,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
|
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
|
||||||
for (int i = axis; i < rank; ++i)
|
for (int i = axis; i < rank; ++i)
|
||||||
addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
|
addDimensionToPack(rewriter, loc, innerPack, input, i);
|
||||||
|
|
||||||
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
|
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
|
||||||
SmallVector<Value, 4> outerLoopIVs;
|
SmallVector<Value, 4> outerLoopIVs;
|
||||||
|
@ -144,7 +143,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Compute the max value.
|
// Compute the max value.
|
||||||
Value max = rewriter.create<LoadOp>(loc, maxOp);
|
Value max = rewriter.create<LoadOp>(loc, maxOp);
|
||||||
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
|
Value nextMax = rewriter.create<LoadOp>(loc, input, maxLoopIVs);
|
||||||
auto maxCond =
|
auto maxCond =
|
||||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
|
||||||
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
|
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
|
||||||
|
@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
// Sum up values.
|
// Sum up values.
|
||||||
Value sum = rewriter.create<LoadOp>(loc, sumOp);
|
Value sum = rewriter.create<LoadOp>(loc, sumOp);
|
||||||
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
|
Value next = rewriter.create<LoadOp>(loc, input, sumLoopIVs);
|
||||||
Value sub = rewriter.create<SubFOp>(loc, next, max);
|
Value sub = rewriter.create<SubFOp>(loc, next, max);
|
||||||
Value exp = rewriter.create<ExpOp>(loc, sub);
|
Value exp = rewriter.create<ExpOp>(loc, sub);
|
||||||
sum = rewriter.create<AddFOp>(loc, sum, exp);
|
sum = rewriter.create<AddFOp>(loc, sum, exp);
|
||||||
|
|
|
@ -26,12 +26,6 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
|
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
|
||||||
else
|
|
||||||
alloc = insertAllocAndDealloc(
|
|
||||||
memRefType, loc, rewriter, insertDealloc, {operands[0]});
|
|
||||||
|
|
||||||
auto resultShape = memRefType.getShape();
|
auto resultShape = memRefType.getShape();
|
||||||
auto inputOperand = operandAdaptor.X();
|
auto inputOperand = operandAdaptor.X();
|
||||||
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
||||||
|
@ -40,6 +34,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
||||||
auto biasOperand = operandAdaptor.B();
|
auto biasOperand = operandAdaptor.B();
|
||||||
bool hasBias = !biasOperand.getType().isa<NoneType>();
|
bool hasBias = !biasOperand.getType().isa<NoneType>();
|
||||||
|
|
||||||
|
if (hasAllConstantDimensions(memRefType))
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
else
|
||||||
|
alloc = insertAllocAndDealloc(
|
||||||
|
memRefType, loc, rewriter, insertDealloc, {inputOperand});
|
||||||
|
|
||||||
// R = Conv(D, K)
|
// R = Conv(D, K)
|
||||||
//
|
//
|
||||||
// The input/output shapes will look like this:
|
// The input/output shapes will look like this:
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
//===----------- Normalization.cpp - Lowering Normalization Ops ------------===//
|
//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
|
||||||
//
|
//
|
||||||
// Copyright 2019 The IBM Research Authors.
|
// Copyright 2019 The IBM Research Authors.
|
||||||
//
|
//
|
||||||
|
@ -21,21 +21,21 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
|
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
|
||||||
// scale * (x - mean) / sqrt(variance + epsilon) + bias
|
// scale * (x - mean) / sqrt(variance + epsilon) + bias
|
||||||
|
ONNXBatchNormalizationTestModeOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
auto epsilonAttr =
|
auto epsilonAttr = FloatAttr::get(memRefType.getElementType(),
|
||||||
FloatAttr::get(memRefType.getElementType(),
|
|
||||||
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
|
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
|
||||||
.epsilon()
|
.epsilon()
|
||||||
.convertToFloat());
|
.convertToFloat());
|
||||||
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
|
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
|
||||||
|
|
||||||
auto operand = operands[0];
|
auto operand = operandAdaptor.X();
|
||||||
auto scale = operands[1];
|
auto scale = operandAdaptor.scale();
|
||||||
auto bias = operands[2];
|
auto bias = operandAdaptor.B();
|
||||||
auto mean = operands[3];
|
auto mean = operandAdaptor.mean();
|
||||||
auto variance = operands[4];
|
auto variance = operandAdaptor.var();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
Value alloc;
|
Value alloc;
|
||||||
|
@ -44,8 +44,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
{operand});
|
memRefType, loc, rewriter, insertDealloc, {operand});
|
||||||
|
|
||||||
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
||||||
// In case of N, C is assumed to be 1.
|
// In case of N, C is assumed to be 1.
|
||||||
|
@ -67,8 +67,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
SmallVector<Value, 1> loopCIVs;
|
SmallVector<Value, 1> loopCIVs;
|
||||||
if (rank > 1) {
|
if (rank > 1) {
|
||||||
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
|
KrnlIterateOperandPack cPack(
|
||||||
optimizedLoops[1]);
|
rewriter, originalLoops[1], optimizedLoops[1]);
|
||||||
addDimensionToPack(rewriter, loc, cPack, operand, 1);
|
addDimensionToPack(rewriter, loc, cPack, operand, 1);
|
||||||
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
|
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
|
||||||
Block &cIterationBlock = cIterateOp.bodyRegion().front();
|
Block &cIterationBlock = cIterateOp.bodyRegion().front();
|
||||||
|
|
|
@ -38,6 +38,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Match
|
// Match
|
||||||
|
@ -71,7 +72,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
||||||
dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());
|
dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());
|
||||||
|
|
||||||
// Type information about the input and result of this operation.
|
// Type information about the input and result of this operation.
|
||||||
auto &inputOperand = operands[0];
|
auto inputOperand = operandAdaptor.X();
|
||||||
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
auto resultShape = memRefType.getShape();
|
auto resultShape = memRefType.getShape();
|
||||||
|
|
|
@ -16,10 +16,10 @@ struct ONNXIdentityOpLowering : public ConversionPattern {
|
||||||
ONNXIdentityOpLowering(MLIRContext *ctx)
|
ONNXIdentityOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
rewriter.replaceOp(op, operands[0]);
|
ONNXIdentityOpOperandAdaptor operandAdaptor(operands);
|
||||||
|
rewriter.replaceOp(op, operandAdaptor.input());
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -14,13 +14,13 @@ using namespace mlir;
|
||||||
|
|
||||||
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
||||||
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
|
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(),
|
: ConversionPattern(
|
||||||
1, ctx) {}
|
mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto tensorType = (*op->result_type_begin());
|
auto tensorType = (*op->result_type_begin());
|
||||||
|
ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
// Only constant padding is supported now.
|
// Only constant padding is supported now.
|
||||||
|
@ -55,7 +55,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
||||||
BuildKrnlLoop valueLoops(rewriter, loc, rank);
|
BuildKrnlLoop valueLoops(rewriter, loc, rank);
|
||||||
valueLoops.createDefineAndOptimizeOp();
|
valueLoops.createDefineAndOptimizeOp();
|
||||||
for (int i = 0; i < rank; ++i)
|
for (int i = 0; i < rank; ++i)
|
||||||
valueLoops.pushBounds(0, operands[0], i);
|
valueLoops.pushBounds(0, operandAdaptor.data(), i);
|
||||||
valueLoops.createIterateOp();
|
valueLoops.createIterateOp();
|
||||||
|
|
||||||
// Copy the input data into the output.
|
// Copy the input data into the output.
|
||||||
|
@ -77,14 +77,14 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
||||||
if (pad_begin[i] == 0) {
|
if (pad_begin[i] == 0) {
|
||||||
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||||
} else {
|
} else {
|
||||||
auto outIV = rewriter.create<AddIOp>(
|
auto outIV = rewriter.create<AddIOp>(loc,
|
||||||
loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
|
rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
|
||||||
valueLoops.getInductionVar(i));
|
valueLoops.getInductionVar(i));
|
||||||
outLoopIVs.emplace_back(outIV);
|
outLoopIVs.emplace_back(outIV);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
auto inVal = rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
|
||||||
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||||
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
|
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
|
||||||
|
|
||||||
|
|
|
@ -16,11 +16,12 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
ONNXReshapeOpLowering(MLIRContext *ctx)
|
ONNXReshapeOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
ONNXReshapeOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
|
Value data = operandAdaptor.data();
|
||||||
|
auto inputShape = data.getType().cast<MemRefType>().getShape();
|
||||||
|
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
|
@ -33,7 +34,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
for (int i = 0; i < inputShape.size(); ++i) {
|
for (int i = 0; i < inputShape.size(); ++i) {
|
||||||
Value dimVal;
|
Value dimVal;
|
||||||
if (inputShape[i] < 0) {
|
if (inputShape[i] < 0) {
|
||||||
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
Value dim = rewriter.create<DimOp>(loc, data, i);
|
||||||
dimVal =
|
dimVal =
|
||||||
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
|
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
|
||||||
} else {
|
} else {
|
||||||
|
@ -61,8 +62,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
||||||
SmallVector<Value, 4> DimInfo;
|
SmallVector<Value, 4> DimInfo;
|
||||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||||
Value index =
|
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
|
||||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
|
|
||||||
// Load index from array of indices.
|
// Load index from array of indices.
|
||||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
// If a dimension is zero, the actual dimension value is taken from the
|
// If a dimension is zero, the actual dimension value is taken from the
|
||||||
|
@ -75,7 +75,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
Value dimVal;
|
Value dimVal;
|
||||||
auto loadedValType = loadedVal.getType().cast<IntegerType>();
|
auto loadedValType = loadedVal.getType().cast<IntegerType>();
|
||||||
if (inputShape[i] < 0) {
|
if (inputShape[i] < 0) {
|
||||||
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
Value dim = rewriter.create<DimOp>(loc, data, i);
|
||||||
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
|
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
|
||||||
} else {
|
} else {
|
||||||
dimVal =
|
dimVal =
|
||||||
|
@ -136,7 +136,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
alloc = allocateMemref;
|
alloc = allocateMemref;
|
||||||
}
|
}
|
||||||
|
|
||||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
|
||||||
rewriter.replaceOp(op, alloc);
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
|
|
|
@ -16,20 +16,21 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
ONNXTransposeOpLowering(MLIRContext *ctx)
|
ONNXTransposeOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ONNXTransposeOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
// Insert an allocation and deallocation for the result of this operation.
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
Value alloc;
|
Value alloc;
|
||||||
bool insertDealloc = checkInsertDealloc(op);
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
Value data = operandAdaptor.data();
|
||||||
|
|
||||||
if (hasAllConstantDimensions(memRefType))
|
if (hasAllConstantDimensions(memRefType))
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
else
|
else
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
alloc = insertAllocAndDealloc(
|
||||||
{operands[0]});
|
memRefType, loc, rewriter, insertDealloc, {data});
|
||||||
|
|
||||||
// Number of loops
|
// Number of loops
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
@ -38,13 +39,13 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
// Define loops.
|
// Define loops.
|
||||||
std::vector<Value> originalLoops;
|
std::vector<Value> originalLoops;
|
||||||
std::vector<Value> optimizedLoops;
|
std::vector<Value> optimizedLoops;
|
||||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
Block *optimizationBlock =
|
||||||
optimizedLoops, rank);
|
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
|
||||||
|
|
||||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||||
// Iterate over the loop nest using the input shape.
|
// Iterate over the loop nest using the input shape.
|
||||||
for (int i = 0; i < rank; ++i)
|
for (int i = 0; i < rank; ++i)
|
||||||
addDimensionToPack(rewriter, loc, pack, operands[0], i);
|
addDimensionToPack(rewriter, loc, pack, data, i);
|
||||||
|
|
||||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||||
|
@ -87,7 +88,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
||||||
for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
|
for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
|
||||||
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
|
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
|
||||||
|
|
||||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
auto inVal = rewriter.create<LoadOp>(loc, data, inLoopIVs);
|
||||||
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||||
|
|
||||||
rewriter.replaceOp(op, alloc);
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
|
@ -16,12 +16,13 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
||||||
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
|
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
|
||||||
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
|
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const final {
|
ConversionPatternRewriter &rewriter) const final {
|
||||||
|
ONNXUnsqueezeOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||||
int outRank = memRefType.getRank();
|
int outRank = memRefType.getRank();
|
||||||
|
Value data = operandAdaptor.data();
|
||||||
|
|
||||||
// Assume that `axes` has been validated by shape inference.
|
// Assume that `axes` has been validated by shape inference.
|
||||||
// So, here we just get it.
|
// So, here we just get it.
|
||||||
|
@ -55,7 +56,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
||||||
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
|
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
|
||||||
Value dimVal = nullptr;
|
Value dimVal = nullptr;
|
||||||
if (memRefShape[outIdx] < 0) {
|
if (memRefShape[outIdx] < 0) {
|
||||||
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
|
Value index = rewriter.create<DimOp>(loc, data, inIdx);
|
||||||
dimVal = rewriter.create<IndexCastOp>(
|
dimVal = rewriter.create<IndexCastOp>(
|
||||||
loc, index, rewriter.getIntegerType(64));
|
loc, index, rewriter.getIntegerType(64));
|
||||||
allocOperands.emplace_back(index);
|
allocOperands.emplace_back(index);
|
||||||
|
@ -74,7 +75,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
||||||
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
|
||||||
rewriter.replaceOp(op, alloc);
|
rewriter.replaceOp(op, alloc);
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,9 +28,7 @@ using namespace mlir;
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||||
ModuleOp module,
|
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
||||||
mlir::LLVM::LLVMType funcType,
|
|
||||||
PatternRewriter &rewriter) {
|
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
||||||
auto symbolRef = SymbolRefAttr::get(funcName, context);
|
auto symbolRef = SymbolRefAttr::get(funcName, context);
|
||||||
|
@ -71,10 +69,10 @@ public:
|
||||||
explicit KrnlMemcpyOpLowering(MLIRContext *context)
|
explicit KrnlMemcpyOpLowering(MLIRContext *context)
|
||||||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||||
|
|
||||||
PatternMatchResult
|
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto *context = op->getContext();
|
auto *context = op->getContext();
|
||||||
|
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto *llvmDialect =
|
auto *llvmDialect =
|
||||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
@ -85,33 +83,37 @@ public:
|
||||||
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
|
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
|
||||||
|
|
||||||
// First operand.
|
// First operand.
|
||||||
Type dstType =
|
Type dstType = operandAdaptor.dest()
|
||||||
operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
.getType()
|
||||||
|
.cast<LLVM::LLVMType>()
|
||||||
|
.getStructElementType(1);
|
||||||
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
|
||||||
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||||
|
|
||||||
// Second operand.
|
// Second operand.
|
||||||
Type srcType =
|
Type srcType = operandAdaptor.src()
|
||||||
operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
.getType()
|
||||||
|
.cast<LLVM::LLVMType>()
|
||||||
|
.getStructElementType(1);
|
||||||
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
|
||||||
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||||
|
|
||||||
// Size.
|
// Size.
|
||||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
|
||||||
|
|
||||||
// Is volatile (set to false).
|
// Is volatile (set to false).
|
||||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||||
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
|
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||||
|
|
||||||
// Memcpy call
|
// Memcpy call
|
||||||
rewriter.create<CallOp>(
|
rewriter.create<CallOp>(loc, memcpyRef,
|
||||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||||
int64Size, isVolatile}));
|
int64Size, isVolatile}));
|
||||||
|
|
||||||
|
@ -123,8 +125,7 @@ private:
|
||||||
/// Return a symbol reference to the memcpy function, inserting it into the
|
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||||
/// module if necessary.
|
/// module if necessary.
|
||||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||||
ModuleOp module,
|
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
|
||||||
LLVM::LLVMDialect *llvmDialect) {
|
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
|
@ -134,8 +135,7 @@ private:
|
||||||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||||
llvmVoidTy,
|
|
||||||
ArrayRef<mlir::LLVM::LLVMType>(
|
ArrayRef<mlir::LLVM::LLVMType>(
|
||||||
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||||
false);
|
false);
|
||||||
|
@ -143,8 +143,8 @@ private:
|
||||||
// Insert the memcpy function into the body of the parent module.
|
// Insert the memcpy function into the body of the parent module.
|
||||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||||
rewriter.setInsertionPointToStart(module.getBody());
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
|
rewriter.create<LLVM::LLVMFuncOp>(
|
||||||
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -186,8 +186,8 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
|
PatternMatchResult matchAndRewrite(
|
||||||
PatternRewriter &rewriter) const override {
|
KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
auto *llvmDialect =
|
auto *llvmDialect =
|
||||||
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
@ -261,8 +261,8 @@ public:
|
||||||
|
|
||||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||||
// from dynMemRef.
|
// from dynMemRef.
|
||||||
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
|
fillPtrToMemRefWithDynMemRef(
|
||||||
apiRegistry, llvmDialect);
|
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||||
|
|
||||||
// ptrToMemRef will be an input to main computation graph function.
|
// ptrToMemRef will be an input to main computation graph function.
|
||||||
staticInputs.emplace_back(ptrToMemRef);
|
staticInputs.emplace_back(ptrToMemRef);
|
||||||
|
@ -273,14 +273,14 @@ public:
|
||||||
assert(numOutputs == 1 && "only support 1 output tensor now.");
|
assert(numOutputs == 1 && "only support 1 output tensor now.");
|
||||||
|
|
||||||
// Call static entry point with the memref ptrs created, and get output.
|
// Call static entry point with the memref ptrs created, and get output.
|
||||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
|
||||||
loc, staticEntryPointTy.getFunctionResultType(),
|
staticEntryPointTy.getFunctionResultType(),
|
||||||
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
|
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
|
||||||
staticInputs);
|
staticInputs);
|
||||||
|
|
||||||
// Create wrapped output.
|
// Create wrapped output.
|
||||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
auto wrappedOutput = callApi(
|
||||||
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
||||||
|
|
||||||
// Get the first memref returned, convert to a dynamic memref and store
|
// Get the first memref returned, convert to a dynamic memref and store
|
||||||
// it in the wrapped Output.
|
// it in the wrapped Output.
|
||||||
|
@ -291,16 +291,16 @@ public:
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||||
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
|
fillDynMemRefWithMemRef(
|
||||||
llvmDialect);
|
outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
||||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||||
{wrappedOutput, zero, outDynMemRef});
|
{wrappedOutput, zero, outDynMemRef});
|
||||||
|
|
||||||
// Return wrapped output.
|
// Return wrapped output.
|
||||||
rewriter.create<LLVM::ReturnOp>(loc,
|
rewriter.create<LLVM::ReturnOp>(
|
||||||
SmallVector<Value, 1>({wrappedOutput}));
|
loc, SmallVector<Value, 1>({wrappedOutput}));
|
||||||
return matchSuccess();
|
return matchSuccess();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -335,8 +335,8 @@ private:
|
||||||
// identities to a symbol reference to the API function.
|
// identities to a symbol reference to the API function.
|
||||||
ApiRegistry registry;
|
ApiRegistry registry;
|
||||||
for (auto &apiSpec : apiSpecs) {
|
for (auto &apiSpec : apiSpecs) {
|
||||||
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
|
apiSpec.symbolRef = getOrInsertExternFunc(
|
||||||
apiSpec.funcTy(), rewriter);
|
apiSpec.name, module, apiSpec.funcTy(), rewriter);
|
||||||
registry.emplace(apiSpec.id, apiSpec);
|
registry.emplace(apiSpec.id, apiSpec);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -347,9 +347,9 @@ private:
|
||||||
// returned, otherwise return nullptr.
|
// returned, otherwise return nullptr.
|
||||||
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||||
API apiId, ArrayRef<Value> params) const {
|
API apiId, ArrayRef<Value> params) const {
|
||||||
auto returnVals = rewriter.create<LLVM::CallOp>(
|
auto returnVals =
|
||||||
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
|
||||||
ArrayRef<Value>(params));
|
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
|
||||||
if (returnVals.getNumResults() == 1)
|
if (returnVals.getNumResults() == 1)
|
||||||
return returnVals.getResult(0);
|
return returnVals.getResult(0);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -370,8 +370,7 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter, const Location &loc,
|
||||||
const Location &loc,
|
|
||||||
const std::map<API, ApiSpec> &apiRegistry,
|
const std::map<API, ApiSpec> &apiRegistry,
|
||||||
LLVM::LLVMDialect *llvmDialect) const {
|
LLVM::LLVMDialect *llvmDialect) const {
|
||||||
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||||
|
@ -385,18 +384,15 @@ private:
|
||||||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
||||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||||
loc, memRefTy, memRef, dataPtr,
|
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
|
||||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
|
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
|
||||||
loc, memRefTy, memRef, dataPtr,
|
|
||||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
|
|
||||||
|
|
||||||
// Use zero offset now.
|
// Use zero offset now.
|
||||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||||
loc, int64Ty, rewriter.getI64IntegerAttr(0));
|
loc, int64Ty, rewriter.getI64IntegerAttr(0));
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef, zero,
|
||||||
loc, memRefTy, memRef, zero,
|
|
||||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
||||||
|
|
||||||
// Get rank, sizes array ptr and strides array ptr.
|
// Get rank, sizes array ptr and strides array ptr.
|
||||||
|
@ -411,24 +407,22 @@ private:
|
||||||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||||
|
|
||||||
// Insert size of the dimension.
|
// Insert size of the dimension.
|
||||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||||
ArrayRef<Value>({dimIdx}));
|
auto dimSize = rewriter.create<LLVM::LoadOp>(
|
||||||
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
loc, int64Ty.getPointerTo(), dimSizePtr);
|
||||||
dimSizePtr);
|
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
dimSize,
|
||||||
loc, memRefTy, memRef, dimSize,
|
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||||
|
|
||||||
// Insert stride of the dimension.
|
// Insert stride of the dimension.
|
||||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||||
ArrayRef<Value>({dimIdx}));
|
|
||||||
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
||||||
loc, int64Ty.getPointerTo(), dimStridePtr);
|
loc, int64Ty.getPointerTo(), dimStridePtr);
|
||||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||||
loc, memRefTy, memRef, dimStride,
|
dimStride,
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||||
}
|
}
|
||||||
|
@ -444,8 +438,8 @@ private:
|
||||||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
|
||||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||||
loc, outMemRefTy.getStructElementType(0), outMemRef,
|
outMemRefTy.getStructElementType(0), outMemRef,
|
||||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||||
|
@ -463,23 +457,21 @@ private:
|
||||||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||||
|
|
||||||
// Transfer size of dimension from memref to dynamic memref.
|
// Transfer size of dimension from memref to dynamic memref.
|
||||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
|
||||||
loc, int64Ty, outMemRef,
|
outMemRef,
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||||
ArrayRef<Value>({dimIdx}));
|
|
||||||
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
||||||
|
|
||||||
// Transfer stride of dimension from memref to dynamic memref.
|
// Transfer stride of dimension from memref to dynamic memref.
|
||||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
|
||||||
loc, int64Ty, outMemRef,
|
outMemRef,
|
||||||
rewriter.getArrayAttr(
|
rewriter.getArrayAttr(
|
||||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||||
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||||
ArrayRef<Value>({dimIdx}));
|
|
||||||
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -530,5 +522,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
|
||||||
return std::make_unique<KrnlToLLVMLoweringPass>();
|
return std::make_unique<KrnlToLLVMLoweringPass>();
|
||||||
}
|
}
|
||||||
|
|
||||||
static PassRegistration<KrnlToLLVMLoweringPass>
|
static PassRegistration<KrnlToLLVMLoweringPass> pass(
|
||||||
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
||||||
|
|
Loading…
Reference in New Issue