clean operands using names provided by operandAdaptors (#56)
* clean operands using names provided by operandAdaptors * reverted changes that were erronerous, from Tung's comment * clang format issues
This commit is contained in:
parent
844dcd8b1f
commit
b422116f12
|
@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
|
||||
|
||||
Value A, B, C;
|
||||
A = operands[0];
|
||||
B = operands[1];
|
||||
ONNXGemmOpOperandAdaptor operandAdaptor(operands);
|
||||
A = operandAdaptor.A();
|
||||
B = operandAdaptor.B();
|
||||
if (hasBias)
|
||||
C = operands[2];
|
||||
C = operandAdaptor.C();
|
||||
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
|
|
|
@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
|
||||
Value A = operands[0];
|
||||
Value B = operands[1];
|
||||
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
|
||||
Value A = operandAdaptor.A();
|
||||
Value B = operandAdaptor.B();
|
||||
auto AShape = A.getType().cast<MemRefType>().getShape();
|
||||
auto BShape = B.getType().cast<MemRefType>().getShape();
|
||||
|
||||
|
|
|
@ -15,8 +15,7 @@ using namespace mlir;
|
|||
struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
||||
ONNXSoftmaxOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// softmax(x) = let max_x = max(x) in
|
||||
// let exp_x = exp(x - max_x) in
|
||||
|
@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
assert(axis >= -rank && axis <= rank - 1);
|
||||
|
||||
auto loc = op->getLoc();
|
||||
|
||||
ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands);
|
||||
Value input = operandAdaptor.input();
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto elementType = memRefType.getElementType();
|
||||
|
||||
|
@ -38,8 +38,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||
operands[0]);
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, input);
|
||||
|
||||
// Shape of the result
|
||||
auto memRefShape = memRefType.getShape();
|
||||
|
@ -49,15 +49,14 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
|
||||
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
|
||||
Value negInfinity = rewriter.create<ConstantOp>(
|
||||
loc,
|
||||
Value negInfinity = rewriter.create<ConstantOp>(loc,
|
||||
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
|
||||
|
||||
// Define loops.
|
||||
std::vector<Value> originalLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
||||
optimizedLoops, rank);
|
||||
Block *optimizationBlock =
|
||||
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
|
||||
|
||||
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
|
||||
// This coercing follows the softmax definition in ONNX:
|
||||
|
@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
}
|
||||
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
|
||||
for (int i = 0; i < axis; ++i)
|
||||
addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
|
||||
addDimensionToPack(rewriter, loc, outerPack, input, i);
|
||||
|
||||
// Define an inner loop with respect to axis.
|
||||
std::vector<Value> innerLoops, optimizedInnerLoops;
|
||||
|
@ -87,7 +86,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
}
|
||||
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
|
||||
for (int i = axis; i < rank; ++i)
|
||||
addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
|
||||
addDimensionToPack(rewriter, loc, innerPack, input, i);
|
||||
|
||||
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
|
||||
SmallVector<Value, 4> outerLoopIVs;
|
||||
|
@ -144,7 +143,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
|
||||
// Compute the max value.
|
||||
Value max = rewriter.create<LoadOp>(loc, maxOp);
|
||||
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
|
||||
Value nextMax = rewriter.create<LoadOp>(loc, input, maxLoopIVs);
|
||||
auto maxCond =
|
||||
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
|
||||
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
|
||||
|
@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
|
||||
// Sum up values.
|
||||
Value sum = rewriter.create<LoadOp>(loc, sumOp);
|
||||
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
|
||||
Value next = rewriter.create<LoadOp>(loc, input, sumLoopIVs);
|
||||
Value sub = rewriter.create<SubFOp>(loc, next, max);
|
||||
Value exp = rewriter.create<ExpOp>(loc, sub);
|
||||
sum = rewriter.create<AddFOp>(loc, sum, exp);
|
||||
|
|
|
@ -26,12 +26,6 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
|||
bool insertDealloc = checkInsertDealloc(op);
|
||||
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {operands[0]});
|
||||
|
||||
auto resultShape = memRefType.getShape();
|
||||
auto inputOperand = operandAdaptor.X();
|
||||
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
||||
|
@ -40,6 +34,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
|
|||
auto biasOperand = operandAdaptor.B();
|
||||
bool hasBias = !biasOperand.getType().isa<NoneType>();
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {inputOperand});
|
||||
|
||||
// R = Conv(D, K)
|
||||
//
|
||||
// The input/output shapes will look like this:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===----------- Normalization.cpp - Lowering Normalization Ops ------------===//
|
||||
//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
|
||||
//
|
||||
// Copyright 2019 The IBM Research Authors.
|
||||
//
|
||||
|
@ -21,21 +21,21 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
|||
ConversionPatternRewriter &rewriter) const final {
|
||||
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
|
||||
// scale * (x - mean) / sqrt(variance + epsilon) + bias
|
||||
ONNXBatchNormalizationTestModeOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto epsilonAttr =
|
||||
FloatAttr::get(memRefType.getElementType(),
|
||||
auto epsilonAttr = FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
|
||||
.epsilon()
|
||||
.convertToFloat());
|
||||
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
|
||||
|
||||
auto operand = operands[0];
|
||||
auto scale = operands[1];
|
||||
auto bias = operands[2];
|
||||
auto mean = operands[3];
|
||||
auto variance = operands[4];
|
||||
auto operand = operandAdaptor.X();
|
||||
auto scale = operandAdaptor.scale();
|
||||
auto bias = operandAdaptor.B();
|
||||
auto mean = operandAdaptor.mean();
|
||||
auto variance = operandAdaptor.var();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
Value alloc;
|
||||
|
@ -44,8 +44,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
|||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||
{operand});
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {operand});
|
||||
|
||||
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
|
||||
// In case of N, C is assumed to be 1.
|
||||
|
@ -67,8 +67,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
|
|||
|
||||
SmallVector<Value, 1> loopCIVs;
|
||||
if (rank > 1) {
|
||||
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
|
||||
optimizedLoops[1]);
|
||||
KrnlIterateOperandPack cPack(
|
||||
rewriter, originalLoops[1], optimizedLoops[1]);
|
||||
addDimensionToPack(rewriter, loc, cPack, operand, 1);
|
||||
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
|
||||
Block &cIterationBlock = cIterateOp.bodyRegion().front();
|
||||
|
|
|
@ -38,6 +38,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Match
|
||||
|
@ -71,7 +72,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
|
|||
dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());
|
||||
|
||||
// Type information about the input and result of this operation.
|
||||
auto &inputOperand = operands[0];
|
||||
auto inputOperand = operandAdaptor.X();
|
||||
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto resultShape = memRefType.getShape();
|
||||
|
|
|
@ -16,10 +16,10 @@ struct ONNXIdentityOpLowering : public ConversionPattern {
|
|||
ONNXIdentityOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
rewriter.replaceOp(op, operands[0]);
|
||||
ONNXIdentityOpOperandAdaptor operandAdaptor(operands);
|
||||
rewriter.replaceOp(op, operandAdaptor.input());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -14,13 +14,13 @@ using namespace mlir;
|
|||
|
||||
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
||||
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(),
|
||||
1, ctx) {}
|
||||
: ConversionPattern(
|
||||
mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin());
|
||||
ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Only constant padding is supported now.
|
||||
|
@ -55,7 +55,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
|||
BuildKrnlLoop valueLoops(rewriter, loc, rank);
|
||||
valueLoops.createDefineAndOptimizeOp();
|
||||
for (int i = 0; i < rank; ++i)
|
||||
valueLoops.pushBounds(0, operands[0], i);
|
||||
valueLoops.pushBounds(0, operandAdaptor.data(), i);
|
||||
valueLoops.createIterateOp();
|
||||
|
||||
// Copy the input data into the output.
|
||||
|
@ -77,14 +77,14 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
|
|||
if (pad_begin[i] == 0) {
|
||||
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
|
||||
} else {
|
||||
auto outIV = rewriter.create<AddIOp>(
|
||||
loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
|
||||
auto outIV = rewriter.create<AddIOp>(loc,
|
||||
rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
|
||||
valueLoops.getInductionVar(i));
|
||||
outLoopIVs.emplace_back(outIV);
|
||||
}
|
||||
}
|
||||
|
||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||
auto inVal = rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
|
||||
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());
|
||||
|
||||
|
|
|
@ -16,11 +16,12 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
ONNXReshapeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
||||
ONNXReshapeOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
Value data = operandAdaptor.data();
|
||||
auto inputShape = data.getType().cast<MemRefType>().getShape();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
@ -33,7 +34,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
for (int i = 0; i < inputShape.size(); ++i) {
|
||||
Value dimVal;
|
||||
if (inputShape[i] < 0) {
|
||||
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
||||
Value dim = rewriter.create<DimOp>(loc, data, i);
|
||||
dimVal =
|
||||
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
|
||||
} else {
|
||||
|
@ -61,8 +62,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
|
||||
SmallVector<Value, 4> DimInfo;
|
||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||
Value index =
|
||||
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
|
||||
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
|
||||
// Load index from array of indices.
|
||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||
// If a dimension is zero, the actual dimension value is taken from the
|
||||
|
@ -75,7 +75,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
Value dimVal;
|
||||
auto loadedValType = loadedVal.getType().cast<IntegerType>();
|
||||
if (inputShape[i] < 0) {
|
||||
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
||||
Value dim = rewriter.create<DimOp>(loc, data, i);
|
||||
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
|
||||
} else {
|
||||
dimVal =
|
||||
|
@ -136,7 +136,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
alloc = allocateMemref;
|
||||
}
|
||||
|
||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
||||
return matchSuccess();
|
||||
|
|
|
@ -16,20 +16,21 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
ONNXTransposeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ONNXTransposeOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
Value data = operandAdaptor.data();
|
||||
|
||||
if (hasAllConstantDimensions(memRefType))
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||
else
|
||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
|
||||
{operands[0]});
|
||||
alloc = insertAllocAndDealloc(
|
||||
memRefType, loc, rewriter, insertDealloc, {data});
|
||||
|
||||
// Number of loops
|
||||
auto memRefShape = memRefType.getShape();
|
||||
|
@ -38,13 +39,13 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
// Define loops.
|
||||
std::vector<Value> originalLoops;
|
||||
std::vector<Value> optimizedLoops;
|
||||
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
|
||||
optimizedLoops, rank);
|
||||
Block *optimizationBlock =
|
||||
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
|
||||
|
||||
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
|
||||
// Iterate over the loop nest using the input shape.
|
||||
for (int i = 0; i < rank; ++i)
|
||||
addDimensionToPack(rewriter, loc, pack, operands[0], i);
|
||||
addDimensionToPack(rewriter, loc, pack, data, i);
|
||||
|
||||
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
|
||||
Block &iterationBlock = iterateOp.bodyRegion().front();
|
||||
|
@ -87,7 +88,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
|
||||
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
|
||||
|
||||
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
|
||||
auto inVal = rewriter.create<LoadOp>(loc, data, inLoopIVs);
|
||||
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
|
||||
|
||||
rewriter.replaceOp(op, alloc);
|
||||
|
|
|
@ -16,12 +16,13 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
|
||||
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
ONNXUnsqueezeOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
int outRank = memRefType.getRank();
|
||||
Value data = operandAdaptor.data();
|
||||
|
||||
// Assume that `axes` has been validated by shape inference.
|
||||
// So, here we just get it.
|
||||
|
@ -55,7 +56,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
|
||||
Value dimVal = nullptr;
|
||||
if (memRefShape[outIdx] < 0) {
|
||||
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
|
||||
Value index = rewriter.create<DimOp>(loc, data, inIdx);
|
||||
dimVal = rewriter.create<IndexCastOp>(
|
||||
loc, index, rewriter.getIntegerType(64));
|
||||
allocOperands.emplace_back(index);
|
||||
|
@ -74,7 +75,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||
}
|
||||
}
|
||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
||||
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
|
||||
rewriter.replaceOp(op, alloc);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
|
|
@ -28,9 +28,7 @@ using namespace mlir;
|
|||
namespace {
|
||||
|
||||
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
|
||||
ModuleOp module,
|
||||
mlir::LLVM::LLVMType funcType,
|
||||
PatternRewriter &rewriter) {
|
||||
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
|
||||
auto symbolRef = SymbolRefAttr::get(funcName, context);
|
||||
|
@ -71,10 +69,10 @@ public:
|
|||
explicit KrnlMemcpyOpLowering(MLIRContext *context)
|
||||
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto *context = op->getContext();
|
||||
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
|
||||
auto loc = op->getLoc();
|
||||
auto *llvmDialect =
|
||||
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
|
@ -85,33 +83,37 @@ public:
|
|||
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
|
||||
|
||||
// First operand.
|
||||
Type dstType =
|
||||
operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Type dstType = operandAdaptor.dest()
|
||||
.getType()
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getStructElementType(1);
|
||||
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
||||
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
|
||||
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||
|
||||
// Second operand.
|
||||
Type srcType =
|
||||
operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||
Type srcType = operandAdaptor.src()
|
||||
.getType()
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getStructElementType(1);
|
||||
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
||||
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
|
||||
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||
|
||||
// Size.
|
||||
Value int64Size = rewriter.create<LLVM::SExtOp>(
|
||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
|
||||
|
||||
// Is volatile (set to false).
|
||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
|
||||
LLVM::LLVMType::getInt1Ty(llvmDialect),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
|
||||
|
||||
// Memcpy call
|
||||
rewriter.create<CallOp>(
|
||||
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
rewriter.create<CallOp>(loc, memcpyRef,
|
||||
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
|
||||
int64Size, isVolatile}));
|
||||
|
||||
|
@ -123,8 +125,7 @@ private:
|
|||
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||
/// module if necessary.
|
||||
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
|
||||
ModuleOp module,
|
||||
LLVM::LLVMDialect *llvmDialect) {
|
||||
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
|
||||
auto *context = module.getContext();
|
||||
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
|
@ -134,8 +135,7 @@ private:
|
|||
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
|
||||
llvmVoidTy,
|
||||
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||
ArrayRef<mlir::LLVM::LLVMType>(
|
||||
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
|
||||
false);
|
||||
|
@ -143,8 +143,8 @@ private:
|
|||
// Insert the memcpy function into the body of the parent module.
|
||||
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
|
||||
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||
rewriter.create<LLVM::LLVMFuncOp>(
|
||||
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||
}
|
||||
};
|
||||
|
@ -186,8 +186,8 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
PatternMatchResult matchAndRewrite(
|
||||
KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
|
||||
|
||||
auto *llvmDialect =
|
||||
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
|
@ -261,8 +261,8 @@ public:
|
|||
|
||||
// Fill in the memref underlying ptrToMemRef with information extracted
|
||||
// from dynMemRef.
|
||||
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
|
||||
apiRegistry, llvmDialect);
|
||||
fillPtrToMemRefWithDynMemRef(
|
||||
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
|
||||
// ptrToMemRef will be an input to main computation graph function.
|
||||
staticInputs.emplace_back(ptrToMemRef);
|
||||
|
@ -273,14 +273,14 @@ public:
|
|||
assert(numOutputs == 1 && "only support 1 output tensor now.");
|
||||
|
||||
// Call static entry point with the memref ptrs created, and get output.
|
||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
|
||||
loc, staticEntryPointTy.getFunctionResultType(),
|
||||
auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
|
||||
staticEntryPointTy.getFunctionResultType(),
|
||||
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
|
||||
staticInputs);
|
||||
|
||||
// Create wrapped output.
|
||||
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
||||
auto wrappedOutput = callApi(
|
||||
rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
|
||||
|
||||
// Get the first memref returned, convert to a dynamic memref and store
|
||||
// it in the wrapped Output.
|
||||
|
@ -291,16 +291,16 @@ public:
|
|||
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
|
||||
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
|
||||
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
|
||||
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
|
||||
llvmDialect);
|
||||
fillDynMemRefWithMemRef(
|
||||
outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32Ty, rewriter.getI32IntegerAttr(0));
|
||||
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
|
||||
{wrappedOutput, zero, outDynMemRef});
|
||||
|
||||
// Return wrapped output.
|
||||
rewriter.create<LLVM::ReturnOp>(loc,
|
||||
SmallVector<Value, 1>({wrappedOutput}));
|
||||
rewriter.create<LLVM::ReturnOp>(
|
||||
loc, SmallVector<Value, 1>({wrappedOutput}));
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
|
@ -335,8 +335,8 @@ private:
|
|||
// identities to a symbol reference to the API function.
|
||||
ApiRegistry registry;
|
||||
for (auto &apiSpec : apiSpecs) {
|
||||
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
|
||||
apiSpec.funcTy(), rewriter);
|
||||
apiSpec.symbolRef = getOrInsertExternFunc(
|
||||
apiSpec.name, module, apiSpec.funcTy(), rewriter);
|
||||
registry.emplace(apiSpec.id, apiSpec);
|
||||
}
|
||||
|
||||
|
@ -347,9 +347,9 @@ private:
|
|||
// returned, otherwise return nullptr.
|
||||
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
|
||||
API apiId, ArrayRef<Value> params) const {
|
||||
auto returnVals = rewriter.create<LLVM::CallOp>(
|
||||
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
|
||||
ArrayRef<Value>(params));
|
||||
auto returnVals =
|
||||
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
|
||||
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
|
||||
if (returnVals.getNumResults() == 1)
|
||||
return returnVals.getResult(0);
|
||||
return nullptr;
|
||||
|
@ -370,8 +370,7 @@ private:
|
|||
}
|
||||
|
||||
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
|
||||
PatternRewriter &rewriter,
|
||||
const Location &loc,
|
||||
PatternRewriter &rewriter, const Location &loc,
|
||||
const std::map<API, ApiSpec> &apiRegistry,
|
||||
LLVM::LLVMDialect *llvmDialect) const {
|
||||
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
|
||||
|
@ -385,18 +384,15 @@ private:
|
|||
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
|
||||
dataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, memRefTy.getStructElementType(0), dataPtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dataPtr,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dataPtr,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
|
||||
|
||||
// Use zero offset now.
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int64Ty, rewriter.getI64IntegerAttr(0));
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, zero,
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef, zero,
|
||||
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
|
||||
|
||||
// Get rank, sizes array ptr and strides array ptr.
|
||||
|
@ -411,24 +407,22 @@ private:
|
|||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||
|
||||
// Insert size of the dimension.
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
|
||||
dimSizePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dimSize,
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||
auto dimSize = rewriter.create<LLVM::LoadOp>(
|
||||
loc, int64Ty.getPointerTo(), dimSizePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||
dimSize,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||
|
||||
// Insert stride of the dimension.
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||
auto dimStride = rewriter.create<LLVM::LoadOp>(
|
||||
loc, int64Ty.getPointerTo(), dimStridePtr);
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, memRefTy, memRef, dimStride,
|
||||
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
|
||||
dimStride,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
}
|
||||
|
@ -444,8 +438,8 @@ private:
|
|||
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||
|
||||
// Extract the data pointer, and record it in dynamic mem ref created.
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, outMemRefTy.getStructElementType(0), outMemRef,
|
||||
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
|
||||
outMemRefTy.getStructElementType(0), outMemRef,
|
||||
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
|
||||
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
|
||||
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
|
||||
|
@ -463,23 +457,21 @@ private:
|
|||
loc, int64Ty, rewriter.getI64IntegerAttr(i));
|
||||
|
||||
// Transfer size of dimension from memref to dynamic memref.
|
||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, outMemRef,
|
||||
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
|
||||
outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), sizesArrayPtr,
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
|
||||
|
||||
// Transfer stride of dimension from memref to dynamic memref.
|
||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, int64Ty, outMemRef,
|
||||
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
|
||||
outMemRef,
|
||||
rewriter.getArrayAttr(
|
||||
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, int64Ty.getPointerTo(), stridesArrayPtr,
|
||||
ArrayRef<Value>({dimIdx}));
|
||||
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
|
||||
int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
|
||||
}
|
||||
}
|
||||
|
@ -530,5 +522,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
|
|||
return std::make_unique<KrnlToLLVMLoweringPass>();
|
||||
}
|
||||
|
||||
static PassRegistration<KrnlToLLVMLoweringPass>
|
||||
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
||||
static PassRegistration<KrnlToLLVMLoweringPass> pass(
|
||||
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
||||
|
|
Loading…
Reference in New Issue