clean operands using names provided by operandAdaptors (#56)

* clean operands using names provided by operandAdaptors

* reverted changes that were erroneous, per Tung's comment

* fixed clang-format issues
Alexandre Eichenberger 2020-03-31 11:55:27 -04:00 committed by GitHub
parent 844dcd8b1f
commit b422116f12
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 175 additions and 179 deletions
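
In short, the change swaps positional operand indexing for the named accessors on the tablegen-generated operand adaptors. A minimal sketch of the before/after pattern, lifted from the Gemm hunk below (the surrounding matchAndRewrite boilerplate is omitted):

```cpp
// Before: operands addressed by position, which is fragile when an op
// has optional inputs (e.g. Gemm's bias C) or its definition changes.
Value A = operands[0];
Value B = operands[1];

// After: the generated adaptor names each operand as declared in the
// ONNX op definition, so the lowering reads like the spec.
ONNXGemmOpOperandAdaptor operandAdaptor(operands);
Value A = operandAdaptor.A();
Value B = operandAdaptor.B();
Value C = operandAdaptor.C(); // optional bias; NoneType when absent
```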

View File

@@ -24,10 +24,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
Value A, B, C;
A = operands[0];
B = operands[1];
ONNXGemmOpOperandAdaptor operandAdaptor(operands);
A = operandAdaptor.A();
B = operandAdaptor.B();
if (hasBias)
C = operands[2];
C = operandAdaptor.C();
auto memRefType = convertToMemRefType(*op->result_type_begin());

View File

@@ -21,8 +21,9 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
Value A = operands[0];
Value B = operands[1];
ONNXMatMulOpOperandAdaptor operandAdaptor(operands);
Value A = operandAdaptor.A();
Value B = operandAdaptor.B();
auto AShape = A.getType().cast<MemRefType>().getShape();
auto BShape = B.getType().cast<MemRefType>().getShape();

View File

@@ -15,8 +15,7 @@ using namespace mlir;
struct ONNXSoftmaxOpLowering : public ConversionPattern {
ONNXSoftmaxOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// softmax(x) = let max_x = max(x) in
// let exp_x = exp(x - max_x) in
@@ -29,7 +28,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
assert(axis >= -rank && axis <= rank - 1);
auto loc = op->getLoc();
ONNXSoftmaxOpOperandAdaptor operandAdaptor(operands);
Value input = operandAdaptor.input();
// Insert an allocation and deallocation for the result of this operation.
auto elementType = memRefType.getElementType();
@@ -38,8 +38,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
operands[0]);
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, input);
// Shape of the result
auto memRefShape = memRefType.getShape();
@@ -49,15 +49,14 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
Value zero = emitConstantOp(rewriter, loc, elementType, 0);
Value negInfinity = rewriter.create<ConstantOp>(
loc,
Value negInfinity = rewriter.create<ConstantOp>(loc,
FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));
// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, rank);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
// Coerce the input into a 2-D tensor. `axis` will be the coercing point.
// This coercing follows the softmax definition in ONNX:
@@ -75,7 +74,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
}
KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
for (int i = 0; i < axis; ++i)
addDimensionToPack(rewriter, loc, outerPack, operands[0], i);
addDimensionToPack(rewriter, loc, outerPack, input, i);
// Define an inner loop with respect to axis.
std::vector<Value> innerLoops, optimizedInnerLoops;
@@ -87,7 +86,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
}
KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
for (int i = axis; i < rank; ++i)
addDimensionToPack(rewriter, loc, innerPack, operands[0], i);
addDimensionToPack(rewriter, loc, innerPack, input, i);
KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
SmallVector<Value, 4> outerLoopIVs;
@@ -144,7 +143,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// Compute the max value.
Value max = rewriter.create<LoadOp>(loc, maxOp);
Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
Value nextMax = rewriter.create<LoadOp>(loc, input, maxLoopIVs);
auto maxCond =
rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
@@ -167,7 +166,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// Sum up values.
Value sum = rewriter.create<LoadOp>(loc, sumOp);
Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
Value next = rewriter.create<LoadOp>(loc, input, sumLoopIVs);
Value sub = rewriter.create<SubFOp>(loc, next, max);
Value exp = rewriter.create<ExpOp>(loc, sub);
sum = rewriter.create<AddFOp>(loc, sum, exp);
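
For reference, the numerically stable softmax this file implements (see the let-binding comment above), as a plain-C++ sketch rather than the Krnl-loop lowering:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))), computed in the
// same phases as the krnl loops above: max, sum of exps, then divide.
std::vector<float> softmax(const std::vector<float> &x) {
  float maxX = *std::max_element(x.begin(), x.end());
  std::vector<float> out(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = std::exp(x[i] - maxX); // subtracting the max avoids overflow
    sum += out[i];
  }
  for (float &v : out)
    v /= sum;
  return out;
}
```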

View File

@@ -26,12 +26,6 @@ struct ONNXConvOpLowering : public ConversionPattern {
bool insertDealloc = checkInsertDealloc(op);
ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op);
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operands[0]});
auto resultShape = memRefType.getShape();
auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
@@ -40,6 +34,12 @@ struct ONNXConvOpLowering : public ConversionPattern {
auto biasOperand = operandAdaptor.B();
bool hasBias = !biasOperand.getType().isa<NoneType>();
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {inputOperand});
// R = Conv(D, K)
//
// The input/output shapes will look like this:

View File

@@ -1,4 +1,4 @@
//===----------- Normalization.cpp - Lowering Normalization Ops ------------===//
//===----------- Normalization.cpp - Lowering Normalization Ops -----------===//
//
// Copyright 2019 The IBM Research Authors.
//
@@ -18,24 +18,24 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
mlir::ONNXBatchNormalizationTestModeOp::getOperationName(), 1,
ctx) {}
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter & rewriter) const final {
ConversionPatternRewriter &rewriter) const final {
// batchnorm{epsilon}(x, scale, bias, mean, variance) =
// scale * (x - mean) / sqrt(variance + epsilon) + bias
ONNXBatchNormalizationTestModeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto epsilonAttr =
FloatAttr::get(memRefType.getElementType(),
auto epsilonAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXBatchNormalizationTestModeOp>(op)
.epsilon()
.convertToFloat());
auto epsilon = rewriter.create<ConstantOp>(loc, epsilonAttr);
auto operand = operands[0];
auto scale = operands[1];
auto bias = operands[2];
auto mean = operands[3];
auto variance = operands[4];
auto operand = operandAdaptor.X();
auto scale = operandAdaptor.scale();
auto bias = operandAdaptor.B();
auto mean = operandAdaptor.mean();
auto variance = operandAdaptor.var();
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
@@ -44,8 +44,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operand});
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {operand});
// Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
// In case of N, C is assumed to be 1.
@@ -67,8 +67,8 @@ struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
SmallVector<Value, 1> loopCIVs;
if (rank > 1) {
KrnlIterateOperandPack cPack(rewriter, originalLoops[1],
optimizedLoops[1]);
KrnlIterateOperandPack cPack(
rewriter, originalLoops[1], optimizedLoops[1]);
addDimensionToPack(rewriter, loc, cPack, operand, 1);
auto cIterateOp = rewriter.create<KrnlIterateOp>(loc, cPack);
Block &cIterationBlock = cIterateOp.bodyRegion().front();
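
As a sanity reference for the formula in the comment above, a scalar version in plain C++ (the function name is illustrative, not from this code):

```cpp
#include <cmath>

// batchnorm{epsilon}(x, scale, bias, mean, variance)
//   = scale * (x - mean) / sqrt(variance + epsilon) + bias
float batchnormScalar(float x, float scale, float bias, float mean,
                      float variance, float epsilon) {
  return scale * (x - mean) / std::sqrt(variance + epsilon) + bias;
}
```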

View File

@@ -38,6 +38,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXMaxPoolSingleOutOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
// Match
@@ -71,7 +72,7 @@ struct ONNXMaxPoolSingleOutOpLowering : public ConversionPattern {
dilations.emplace_back(dilation.cast<IntegerAttr>().getInt());
// Type information about the input and result of this operation.
auto &inputOperand = operands[0];
auto inputOperand = operandAdaptor.X();
auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto resultShape = memRefType.getShape();

View File

@@ -16,10 +16,10 @@ struct ONNXIdentityOpLowering : public ConversionPattern {
ONNXIdentityOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
rewriter.replaceOp(op, operands[0]);
ONNXIdentityOpOperandAdaptor operandAdaptor(operands);
rewriter.replaceOp(op, operandAdaptor.input());
return matchSuccess();
}
};

View File

@@ -14,13 +14,13 @@ using namespace mlir;
struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
ONNXPadConstantValuePadOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXPadConstantValuePadOp::getOperationName(),
1, ctx) {}
: ConversionPattern(
mlir::ONNXPadConstantValuePadOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin());
ONNXPadConstantValuePadOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
// Only constant padding is supported now.
@@ -55,7 +55,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
BuildKrnlLoop valueLoops(rewriter, loc, rank);
valueLoops.createDefineAndOptimizeOp();
for (int i = 0; i < rank; ++i)
valueLoops.pushBounds(0, operands[0], i);
valueLoops.pushBounds(0, operandAdaptor.data(), i);
valueLoops.createIterateOp();
// Copy the input data into the output.
@@ -67,7 +67,7 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
auto pads = llvm::dyn_cast<ONNXPadConstantValuePadOp>(op).pads();
SmallVector<int64_t, 4> pad_begin;
for (int i = 0; i < pads.size()/2; ++i) {
for (int i = 0; i < pads.size() / 2; ++i) {
pad_begin.emplace_back(pads.getValue()[i].cast<IntegerAttr>().getInt());
}
@@ -77,14 +77,14 @@ struct ONNXPadConstantValuePadOpLowering : public ConversionPattern {
if (pad_begin[i] == 0) {
outLoopIVs.emplace_back(valueLoops.getInductionVar(i));
} else {
auto outIV = rewriter.create<AddIOp>(
loc, rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
auto outIV = rewriter.create<AddIOp>(loc,
rewriter.create<ConstantIndexOp>(loc, pad_begin[i]),
valueLoops.getInductionVar(i));
outLoopIVs.emplace_back(outIV);
}
}
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
auto inVal = rewriter.create<LoadOp>(loc, operandAdaptor.data(), inLoopIVs);
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
rewriter.setInsertionPointToStart(padLoops.getIterateBlock());

View File

@@ -16,11 +16,12 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
ONNXReshapeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
ONNXReshapeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
Value data = operandAdaptor.data();
auto inputShape = data.getType().cast<MemRefType>().getShape();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());
@@ -33,7 +34,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
for (int i = 0; i < inputShape.size(); ++i) {
Value dimVal;
if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
Value dim = rewriter.create<DimOp>(loc, data, i);
dimVal =
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
} else {
@@ -61,8 +62,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
SmallVector<Value, 4> DimInfo;
for (int i = 0; i < memRefShape.size(); ++i) {
Value index =
emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
// Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the
@@ -75,7 +75,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
Value dimVal;
auto loadedValType = loadedVal.getType().cast<IntegerType>();
if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
Value dim = rewriter.create<DimOp>(loc, data, i);
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
} else {
dimVal =
@@ -136,7 +136,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
alloc = allocateMemref;
}
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
rewriter.replaceOp(op, alloc);
return matchSuccess();

View File

@@ -16,20 +16,21 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
ONNXTransposeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXTransposeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
Value data = operandAdaptor.data();
if (hasAllConstantDimensions(memRefType))
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
else
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
{operands[0]});
alloc = insertAllocAndDealloc(
memRefType, loc, rewriter, insertDealloc, {data});
// Number of loops
auto memRefShape = memRefType.getShape();
@@ -38,13 +39,13 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
// Define loops.
std::vector<Value> originalLoops;
std::vector<Value> optimizedLoops;
Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
optimizedLoops, rank);
Block *optimizationBlock =
defineLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
// Iterate over the loop nest using the input shape.
for (int i = 0; i < rank; ++i)
addDimensionToPack(rewriter, loc, pack, operands[0], i);
addDimensionToPack(rewriter, loc, pack, data, i);
auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
Block &iterationBlock = iterateOp.bodyRegion().front();
@@ -75,7 +76,7 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
// the default case). This means that perm was added by shape
// inference or another pass to contain the values corresponding
// to the default behavior of Transpose.
for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--)
for (int i = iterationBlock.getArguments().size() - 1; i >= 0; i--)
perm.emplace_back(i);
}
@@ -84,10 +85,10 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
inLoopIVs.emplace_back(arg);
SmallVector<Value, 4> outLoopIVs;
for (int i=0; i<iterationBlock.getArguments().size(); ++i)
for (int i = 0; i < iterationBlock.getArguments().size(); ++i)
outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
auto inVal = rewriter.create<LoadOp>(loc, data, inLoopIVs);
rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
rewriter.replaceOp(op, alloc);

View File

@@ -16,12 +16,13 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
ONNXUnsqueezeOpLowering(MLIRContext *ctx)
: ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
ONNXUnsqueezeOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
auto memRefType = convertToMemRefType(*op->result_type_begin());
int outRank = memRefType.getRank();
Value data = operandAdaptor.data();
// Assume that `axes` has been validated by shape inference.
// So, here we just get it.
@@ -55,7 +56,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
Value dimVal = nullptr;
if (memRefShape[outIdx] < 0) {
Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
Value index = rewriter.create<DimOp>(loc, data, inIdx);
dimVal = rewriter.create<IndexCastOp>(
loc, index, rewriter.getIntegerType(64));
allocOperands.emplace_back(index);
@@ -74,7 +75,7 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
dealloc.getOperation()->moveBefore(&parentBlock->back());
}
}
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
rewriter.create<KrnlMemcpyOp>(loc, alloc, data, tensorSize);
rewriter.replaceOp(op, alloc);
return matchSuccess();
}

View File

@@ -28,9 +28,7 @@ using namespace mlir;
namespace {
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
ModuleOp module,
mlir::LLVM::LLVMType funcType,
PatternRewriter &rewriter) {
ModuleOp module, mlir::LLVM::LLVMType funcType, PatternRewriter &rewriter) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
auto symbolRef = SymbolRefAttr::get(funcName, context);
@@ -71,10 +69,10 @@ public:
explicit KrnlMemcpyOpLowering(MLIRContext *context)
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = op->getContext();
KrnlMemcpyOpOperandAdaptor operandAdaptor(operands);
auto loc = op->getLoc();
auto *llvmDialect =
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
@@ -85,33 +83,37 @@ public:
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
// First operand.
Type dstType =
operands[0].getType().cast<LLVM::LLVMType>().getStructElementType(1);
Type dstType = operandAdaptor.dest()
.getType()
.cast<LLVM::LLVMType>()
.getStructElementType(1);
Value alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
loc, dstType, operandAdaptor.dest(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
// Second operand.
Type srcType =
operands[1].getType().cast<LLVM::LLVMType>().getStructElementType(1);
Type srcType = operandAdaptor.src()
.getType()
.cast<LLVM::LLVMType>()
.getStructElementType(1);
Value alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
loc, srcType, operandAdaptor.src(), rewriter.getI64ArrayAttr(1));
Value alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
// Size.
Value int64Size = rewriter.create<LLVM::SExtOp>(
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operandAdaptor.size());
// Is volatile (set to false).
Value isVolatile = rewriter.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt1Ty(llvmDialect),
Value isVolatile = rewriter.create<LLVM::ConstantOp>(loc,
LLVM::LLVMType::getInt1Ty(llvmDialect),
rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));
// Memcpy call
rewriter.create<CallOp>(
loc, memcpyRef, LLVM::LLVMType::getVoidTy(llvmDialect),
rewriter.create<CallOp>(loc, memcpyRef,
LLVM::LLVMType::getVoidTy(llvmDialect),
ArrayRef<Value>({alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory,
int64Size, isVolatile}));
@@ -123,8 +125,7 @@ private:
/// Return a symbol reference to the memcpy function, inserting it into the
/// module if necessary.
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
ModuleOp module,
LLVM::LLVMDialect *llvmDialect) {
ModuleOp module, LLVM::LLVMDialect *llvmDialect) {
auto *context = module.getContext();
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
@@ -134,8 +135,7 @@ private:
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto llvmI1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
auto llvmFnType = LLVM::LLVMType::getFunctionTy(
llvmVoidTy,
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
ArrayRef<mlir::LLVM::LLVMType>(
{llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty, llvmI1Ty}),
false);
@@ -143,8 +143,8 @@ private:
// Insert the memcpy function into the body of the parent module.
PatternRewriter::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(),
"llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
rewriter.create<LLVM::LLVMFuncOp>(
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
}
};
@@ -186,8 +186,8 @@ public:
}
};
PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
PatternRewriter &rewriter) const override {
PatternMatchResult matchAndRewrite(
KrnlEntryPointOp op, PatternRewriter &rewriter) const override {
auto *llvmDialect =
op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
@@ -261,8 +261,8 @@ public:
// Fill in the memref underlying ptrToMemRef with information extracted
// from dynMemRef.
fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
apiRegistry, llvmDialect);
fillPtrToMemRefWithDynMemRef(
dynMemRef, ptrToMemRef, rewriter, loc, apiRegistry, llvmDialect);
// ptrToMemRef will be an input to main computation graph function.
staticInputs.emplace_back(ptrToMemRef);
@@ -273,14 +273,14 @@ public:
assert(numOutputs == 1 && "only support 1 output tensor now.");
// Call static entry point with the memref ptrs created, and get output.
auto outputMemRefs = rewriter.create<LLVM::CallOp>(
loc, staticEntryPointTy.getFunctionResultType(),
auto outputMemRefs = rewriter.create<LLVM::CallOp>(loc,
staticEntryPointTy.getFunctionResultType(),
rewriter.getSymbolRefAttr(wrappedStaticEntryPointFuncName),
staticInputs);
// Create wrapped output.
auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
auto wrappedOutput = callApi(
rewriter, loc, apiRegistry, API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});
// Get the first memref returned, convert to a dynamic memref and store
// it in the wrapped Output.
@@ -291,16 +291,16 @@ public:
loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
llvmDialect);
fillDynMemRefWithMemRef(
outMemRef, outDynMemRef, rewriter, loc, apiRegistry, llvmDialect);
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int32Ty, rewriter.getI32IntegerAttr(0));
callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
{wrappedOutput, zero, outDynMemRef});
// Return wrapped output.
rewriter.create<LLVM::ReturnOp>(loc,
SmallVector<Value, 1>({wrappedOutput}));
rewriter.create<LLVM::ReturnOp>(
loc, SmallVector<Value, 1>({wrappedOutput}));
return matchSuccess();
}
@@ -335,8 +335,8 @@ private:
// identities to a symbol reference to the API function.
ApiRegistry registry;
for (auto &apiSpec : apiSpecs) {
apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
apiSpec.funcTy(), rewriter);
apiSpec.symbolRef = getOrInsertExternFunc(
apiSpec.name, module, apiSpec.funcTy(), rewriter);
registry.emplace(apiSpec.id, apiSpec);
}
@@ -347,9 +347,9 @@ private:
// returned, otherwise return nullptr.
Value callApi(PatternRewriter &rewriter, Location loc, ApiRegistry registry,
API apiId, ArrayRef<Value> params) const {
auto returnVals = rewriter.create<LLVM::CallOp>(
loc, registry.at(apiId).outputTy, registry.at(apiId).symbolRef,
ArrayRef<Value>(params));
auto returnVals =
rewriter.create<LLVM::CallOp>(loc, registry.at(apiId).outputTy,
registry.at(apiId).symbolRef, ArrayRef<Value>(params));
if (returnVals.getNumResults() == 1)
return returnVals.getResult(0);
return nullptr;
@@ -370,8 +370,7 @@ private:
}
void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
PatternRewriter &rewriter,
const Location &loc,
PatternRewriter &rewriter, const Location &loc,
const std::map<API, ApiSpec> &apiRegistry,
LLVM::LLVMDialect *llvmDialect) const {
auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
@@ -385,18 +384,15 @@ private:
callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
dataPtr = rewriter.create<LLVM::BitcastOp>(
loc, memRefTy.getStructElementType(0), dataPtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dataPtr,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
dataPtr, rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));
// Use zero offset now.
auto zero = rewriter.create<LLVM::ConstantOp>(
loc, int64Ty, rewriter.getI64IntegerAttr(0));
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, zero,
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef, zero,
rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));
// Get rank, sizes array ptr and strides array ptr.
@@ -411,24 +407,22 @@ private:
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Insert size of the dimension.
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value>({dimIdx}));
auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty.getPointerTo(),
dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimSize,
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
auto dimSize = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimSizePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
dimSize,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
// Insert stride of the dimension.
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value>({dimIdx}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
auto dimStride = rewriter.create<LLVM::LoadOp>(
loc, int64Ty.getPointerTo(), dimStridePtr);
memRef = rewriter.create<LLVM::InsertValueOp>(
loc, memRefTy, memRef, dimStride,
memRef = rewriter.create<LLVM::InsertValueOp>(loc, memRefTy, memRef,
dimStride,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
}
@@ -444,8 +438,8 @@ private:
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
// Extract the data pointer, and record it in dynamic mem ref created.
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
loc, outMemRefTy.getStructElementType(0), outMemRef,
Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(loc,
outMemRefTy.getStructElementType(0), outMemRef,
rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
@@ -463,23 +457,21 @@ private:
loc, int64Ty, rewriter.getI64IntegerAttr(i));
// Transfer size of dimension from memref to dynamic memref.
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, outMemRef,
auto dimSize = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), sizesArrayPtr,
ArrayRef<Value>({dimIdx}));
auto dimSizePtr = rewriter.create<LLVM::GEPOp>(loc,
int64Ty.getPointerTo(), sizesArrayPtr, ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);
// Transfer stride of dimension from memref to dynamic memref.
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
loc, int64Ty, outMemRef,
auto dimStride = rewriter.create<LLVM::ExtractValueOp>(loc, int64Ty,
outMemRef,
rewriter.getArrayAttr(
{rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
loc, int64Ty.getPointerTo(), stridesArrayPtr,
ArrayRef<Value>({dimIdx}));
auto dimStridePtr = rewriter.create<LLVM::GEPOp>(loc,
int64Ty.getPointerTo(), stridesArrayPtr, ArrayRef<Value>({dimIdx}));
rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
}
}
@@ -530,5 +522,5 @@ std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>();
}
static PassRegistration<KrnlToLLVMLoweringPass>
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
static PassRegistration<KrnlToLLVMLoweringPass> pass(
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");

View File

@@ -22,7 +22,7 @@ namespace {
bool hasNonZeroInArrayAttr(ArrayAttr attrs) {
bool allZeros = true;
if (attrs) {
for (auto attr: attrs.getValue()) {
for (auto attr : attrs.getValue()) {
if (attr.cast<IntegerAttr>().getInt() > 0) {
allZeros = false;
break;
@@ -54,7 +54,7 @@ ArrayAttr createArrayAttrOfZeros(
// This function is used for padding attribute in MaxPoolSingleOut.
ArrayAttr insertZerosForNonPaddedDims(
PatternRewriter &rewriter, ArrayAttr origAttrs, int extensionLength) {
int nDims = (int) origAttrs.getValue().size() / 2;
int nDims = (int)origAttrs.getValue().size() / 2;
int nElements = (nDims + extensionLength) * 2;
SmallVector<int64_t, 4> pads(nElements, 0);
for (int i = 0; i < nDims; ++i) {