onnx-mlir/src/transform/lower_to_llvm.cpp

586 lines
25 KiB
C++

//====- LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -----------===//
//
// Copyright 2019 The IBM Research Authors.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
#include "src/dialect/krnl/krnl_ops.hpp"
#include "src/pass/passes.hpp"
using namespace mlir;
namespace {
/// Return a symbol reference to the function named `funcName`, declaring it
/// at the start of `module` if no LLVM function with that name exists yet.
///
/// \param funcName  Symbol name of the external function.
/// \param module    Module in which the declaration is looked up / inserted.
/// \param funcType  LLVM function type used for a newly created declaration
///                  and checked (in assert-enabled builds) against an existing
///                  one.
/// \param rewriter  Rewriter used to create the declaration; its insertion
///                  point is restored before returning.
static FlatSymbolRefAttr getOrInsertExternFunc(StringRef funcName,
                                               ModuleOp module,
                                               mlir::LLVM::LLVMType funcType,
                                               PatternRewriter &rewriter) {
  auto *context = module.getContext();
  if (auto existingFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(funcName)) {
    // Validate the type of the declaration that is already in the module.
    // The type of the symbol reference attribute itself is unrelated to the
    // function signature, so it must not be used for this check.
    assert(existingFunc.getType() == funcType && "wrong symbol type");
    return SymbolRefAttr::get(funcName, context);
  }
  // Insert the function into the body of the parent module. The guard puts the
  // rewriter back to its previous insertion point on scope exit.
  PatternRewriter::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName, funcType);
  return SymbolRefAttr::get(funcName, context);
}
/// Derive the rank of a tensor from the LLVM struct type of its MemRef.
///
/// A MemRef is normally lowered to a 5-element struct whose 4th and 5th
/// elements are arrays with one entry per dimension. For a scalar those two
/// arrays would have length 0, and the struct degenerates into a 3-element
/// one. For more information, refer to
/// https://github.com/llvm/llvm-project/blob/master/mlir/docs/ConversionToLLVMDialect.md#memref-types.
static size_t getRankFromMemRefType(LLVM::LLVMType memRefTy) {
  auto fieldCount = memRefTy.getStructNumElements();
  assert((fieldCount == 3 || fieldCount == 5) &&
         "Expect MemRef type to contain either 3 or 5 elements.");
  // Ranked case: the length of the 4th struct element is the rank.
  if (fieldCount != 3)
    return memRefTy.getStructElementType(3).getArrayNumElements();
  // Three elements only: the MemRef refers to a scalar.
  return 0;
}
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlMemcpyOpLowering
//===----------------------------------------------------------------------===//
/// Conversion pattern that rewrites a KrnlMemcpyOp into a call to the
/// `llvm.memcpy.p0i8.p0i8.i64` function.
class KrnlMemcpyOpLowering : public ConversionPattern {
public:
  explicit KrnlMemcpyOpLowering(MLIRContext *context)
      : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}

  /// Emit the memcpy call for `op` and erase `op`.
  ///
  /// `operands[0]` and `operands[1]` are the already-converted destination and
  /// source MemRef structs; `operands[2]` is the size, which is sign-extended
  /// to i64 before being passed to the call.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    LLVM::LLVMDialect *dialect =
        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    assert(dialect && "expected llvm dialect to be registered");

    // Obtain the callee first, so that a missing declaration is added to the
    // enclosing module before any of the operations below are created.
    FlatSymbolRefAttr memcpySymbol = getOrInsertMemcpy(
        rewriter, op->getParentOfType<ModuleOp>(), dialect);

    // Pull struct element 1 out of a MemRef struct and bitcast it to i8*.
    LLVM::LLVMType bytePtrTy = LLVM::LLVMType::getInt8PtrTy(dialect);
    auto extractAlignedBytePtr = [&](Value descriptor) -> Value {
      Type alignedPtrTy =
          descriptor.getType().cast<LLVM::LLVMType>().getStructElementType(1);
      Value alignedPtr = rewriter.create<LLVM::ExtractValueOp>(
          loc, alignedPtrTy, descriptor, rewriter.getI64ArrayAttr(1));
      return rewriter.create<LLVM::BitcastOp>(loc, bytePtrTy, alignedPtr);
    };

    // Destination first, then source: same emission order as the call
    // arguments.
    Value dstBytes = extractAlignedBytePtr(operands[0]);
    Value srcBytes = extractAlignedBytePtr(operands[1]);

    // Number of bytes to copy, as i64.
    Value byteCount = rewriter.create<LLVM::SExtOp>(
        loc, LLVM::LLVMType::getInt64Ty(dialect), operands[2]);

    // The `isvolatile` flag of the call is the i1 constant 0.
    Value notVolatile = rewriter.create<LLVM::ConstantOp>(
        loc, LLVM::LLVMType::getInt1Ty(dialect),
        rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0));

    // Emit the call and drop the original operation.
    rewriter.create<CallOp>(
        loc, memcpySymbol, LLVM::LLVMType::getVoidTy(dialect),
        ArrayRef<Value>({dstBytes, srcBytes, byteCount, notVolatile}));
    rewriter.eraseOp(op);
    return matchSuccess();
  }

private:
  /// Look up the memcpy function in `module` and return a symbol reference to
  /// it; when it is absent, first declare it at the start of the module with
  /// the signature `void (i8*, i8*, i64, i1)`.
  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter &rewriter,
                                             ModuleOp module,
                                             LLVM::LLVMDialect *llvmDialect) {
    StringRef memcpyName = "llvm.memcpy.p0i8.p0i8.i64";
    MLIRContext *ctx = module.getContext();

    if (!module.lookupSymbol<LLVM::LLVMFuncOp>(memcpyName)) {
      // Assemble the function type from its result and parameter types.
      auto voidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
      auto i8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
      auto i64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
      auto i1Ty = LLVM::LLVMType::getInt1Ty(llvmDialect);
      auto memcpyFnTy = LLVM::LLVMType::getFunctionTy(
          voidTy,
          ArrayRef<mlir::LLVM::LLVMType>({i8PtrTy, i8PtrTy, i64Ty, i1Ty}),
          false);

      // Declare it at the top of the module; the guard restores the
      // rewriter's insertion point afterwards.
      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), memcpyName,
                                        memcpyFnTy);
    }
    return SymbolRefAttr::get(memcpyName, ctx);
  }
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlEntryPointOp
//===----------------------------------------------------------------------===//
/// Rewrite pattern that replaces a KrnlEntryPointOp with an LLVM function
/// named `_dyn_entry_point_<static entry point>`.
///
/// The generated function takes one opaque pointer and returns one opaque
/// pointer. It unpacks the wrapped inputs into MemRef structs through a set
/// of external API functions, calls the static entry point, and wraps the
/// (single) result again.
class KrnlEntryPointOpLowering : public OpRewritePattern<KrnlEntryPointOp> {
public:
  using OpRewritePattern<KrnlEntryPointOp>::OpRewritePattern;

  /// Identifiers of the external API functions called by the generated
  /// dynamic entry point.
  enum class API {
    CREATE_ORDERED_DYN_MEM_REF_DICT,
    CREATE_DYN_MEM_REF,
    GET_DYN_MEM_REF,
    SET_DYN_MEM_REF,
    GET_DATA,
    SET_DATA,
    GET_SIZES,
    GET_STRIDES,
  };

  /// Description of one API function: its identifier, symbol name, symbol
  /// reference (set once the function is declared in the module), and its
  /// result and parameter types.
  struct ApiSpec {
    API id;
    std::string name;
    FlatSymbolRefAttr symbolRef;
    LLVM::LLVMType outputTy;
    SmallVector<LLVM::LLVMType, 4> inputTys;

    ApiSpec(API id, const std::string &name, LLVM::LLVMType outputTy,
            ArrayRef<LLVM::LLVMType> inputTys)
        : id(id), name(name), outputTy(outputTy),
          inputTys(inputTys.begin(), inputTys.end()) {}

    /// Non-variadic LLVM function type built from outputTy and inputTys.
    LLVM::LLVMType funcTy() const {
      return LLVM::LLVMType::getFunctionTy(outputTy, inputTys,
                                           /*isVarArg=*/false);
    }
  };

  /// Erase `op` and create the dynamic entry point function in its place.
  PatternMatchResult matchAndRewrite(KrnlEntryPointOp op,
                                     PatternRewriter &rewriter) const override {
    auto *llvmDialect =
        op.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    assert(llvmDialect && "expected llvm dialect to be registered");

    auto module = op.getParentOfType<ModuleOp>();
    auto apiRegistry = RegisterAllApis(module, rewriter, llvmDialect);
    auto loc = op.getLoc();
    auto numOutputs =
        op.getAttrOfType<IntegerAttr>(KrnlEntryPointOp::getNumOutputsAttrName())
            .getInt();
    (void)numOutputs; // Only read by the assert further below.

    using LLVMType = LLVM::LLVMType;
    auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
    auto int32Ty = LLVMType::getInt32Ty(llvmDialect);

    // Rewrite Krnl Entry Point Operation to an LLVM function with a dynamic
    // signature. The signature is dynamic because it remains the same no matter
    // what the model input/output schema look like. Such dynamic signature
    // takes an opaque ptr as input, representing a ptr to a data structure
    // containing a set of dynamic memrefs wrapped in a vector; similarly the
    // output is also an opaque ptr to a data structure with output memrefs
    // wrapped within it.
    auto staticEntryPointFuncName =
        op.getAttrOfType<SymbolRefAttr>(
              KrnlEntryPointOp::getEntryPointFuncAttrName())
            .getLeafReference();
    // Materialize the concatenated name right away instead of keeping a Twine
    // in a local variable.
    std::string dynEntryPointName =
        ("_dyn_entry_point_" + staticEntryPointFuncName).str();
    assert(module.lookupSymbol(dynEntryPointName) == nullptr &&
           "dynamic entry point name is not unique");
    rewriter.eraseOp(op);

    auto dynEntryPointFuncTy =
        LLVMType::getFunctionTy(opaquePtrTy, {opaquePtrTy}, false);
    auto dynamicEntryPointFunc = rewriter.create<LLVM::LLVMFuncOp>(
        loc, dynEntryPointName, dynEntryPointFuncTy);
    auto &entryPointEntryBlock =
        createEntryBlock(dynEntryPointFuncTy, dynamicEntryPointFunc);
    rewriter.setInsertionPointToStart(&entryPointEntryBlock);

    // Based on the static entry point type signature, unpack dynamic memory
    // refs to corresponding static memory refs.
    auto *staticEntryPointFunc = module.lookupSymbol(staticEntryPointFuncName);
    assert(staticEntryPointFunc &&
           isa<LLVM::LLVMFuncOp>(staticEntryPointFunc) &&
           "entry point func must exist and be an llvm func op");
    auto staticEntryPointTy = cast<LLVM::LLVMFuncOp>(staticEntryPointFunc)
                                  .getType()
                                  .cast<LLVMType>();

    // Retrieve dynamic mem refs from wrapped input, and convert every one of
    // them to static mem refs.
    SmallVector<Value, 4> staticInputs;
    auto wrappedInput = entryPointEntryBlock.getArgument(0);
    for (size_t i = 0; i < staticEntryPointTy.getFunctionNumParams(); i++) {
      // Call API function to retrieve the i-th dynamic memref.
      auto idxVal = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(i));
      auto dynMemRef = callApi(rewriter, loc, apiRegistry, API::GET_DYN_MEM_REF,
                               {wrappedInput, idxVal});

      // Allocate one (static) memref struct on the stack; the i-th parameter
      // of the static entry point is a pointer to such a struct.
      auto memRefPtrTy = staticEntryPointTy.getFunctionParamType(i);
      auto one = rewriter.create<LLVM::ConstantOp>(
          loc, int32Ty, rewriter.getI32IntegerAttr(1));
      Value ptrToMemRef = rewriter.create<LLVM::AllocaOp>(loc, memRefPtrTy, one,
                                                          /*alignment=*/0);

      // Fill in the memref underlying ptrToMemRef with information extracted
      // from dynMemRef.
      fillPtrToMemRefWithDynMemRef(dynMemRef, ptrToMemRef, rewriter, loc,
                                   apiRegistry, llvmDialect);

      // ptrToMemRef will be an input to main computation graph function.
      staticInputs.emplace_back(ptrToMemRef);
    }

    // If more than one output exists, the struct becomes a nested struct,
    // the unpacking logic can be more involved, so no support for now.
    assert(numOutputs == 1 && "only support 1 output tensor now.");

    // Call static entry point with the memref ptrs created, and get output.
    auto outputMemRefs = rewriter.create<LLVM::CallOp>(
        loc, staticEntryPointTy.getFunctionResultType(),
        rewriter.getSymbolRefAttr(staticEntryPointFuncName), staticInputs);

    // Create wrapped output.
    auto wrappedOutput = callApi(rewriter, loc, apiRegistry,
                                 API::CREATE_ORDERED_DYN_MEM_REF_DICT, {});

    // Get the first memref returned, convert to a dynamic memref and store
    // it in the wrapped Output.
    auto outMemRef = outputMemRefs.getResult(0);
    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVMType>();
    auto outMemRefRank = getRankFromMemRefType(outMemRefTy);
    auto outMemRefRankVal = rewriter.create<LLVM::ConstantOp>(
        loc, int32Ty, rewriter.getI32IntegerAttr(outMemRefRank));
    auto outDynMemRef = callApi(rewriter, loc, apiRegistry,
                                API::CREATE_DYN_MEM_REF, {outMemRefRankVal});
    fillDynMemRefWithMemRef(outMemRef, outDynMemRef, rewriter, loc, apiRegistry,
                            llvmDialect);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, int32Ty, rewriter.getI32IntegerAttr(0));
    callApi(rewriter, loc, apiRegistry, API::SET_DYN_MEM_REF,
            {wrappedOutput, zero, outDynMemRef});

    // Return wrapped output.
    rewriter.create<LLVM::ReturnOp>(loc,
                                    SmallVector<Value, 1>({wrappedOutput}));
    return matchSuccess();
  }

private:
  using ApiRegistry = std::map<API, ApiSpec>;

  /// Declare every API function in `module` (reusing declarations that are
  /// already present) and return a registry mapping each API identifier to
  /// its spec, including the symbol reference of the declaration.
  ApiRegistry RegisterAllApis(ModuleOp &module, PatternRewriter &rewriter,
                              LLVM::LLVMDialect *llvmDialect) const {
    using LLVMType = LLVM::LLVMType;
    auto voidTy = LLVMType::getVoidTy(llvmDialect);
    auto opaquePtrTy = LLVMType::getInt8PtrTy(llvmDialect);
    auto int32Ty = LLVMType::getInt32Ty(llvmDialect);
    auto int64Ty = LLVMType::getInt64Ty(llvmDialect);
    auto int64PtrTy = int64Ty.getPointerTo();

    // Declare API type as an enum value, its string name and an LLVM Type
    // specifying its signature.
    // clang-format off
    std::vector<ApiSpec> apiSpecs = {
        ApiSpec(API::CREATE_ORDERED_DYN_MEM_REF_DICT, "createOrderedDynMemRefDict", opaquePtrTy, {}),
        ApiSpec(API::CREATE_DYN_MEM_REF, "createDynMemRef", opaquePtrTy, {int32Ty}),
        ApiSpec(API::GET_DATA, "getData", opaquePtrTy, {opaquePtrTy}),
        ApiSpec(API::SET_DATA, "setData", voidTy, {opaquePtrTy, opaquePtrTy}),
        ApiSpec(API::GET_DYN_MEM_REF, "getDynMemRef", opaquePtrTy, {opaquePtrTy, int32Ty}),
        ApiSpec(API::SET_DYN_MEM_REF, "setDynMemRef", voidTy, {opaquePtrTy, int32Ty, opaquePtrTy}),
        ApiSpec(API::GET_SIZES, "getSizes", int64PtrTy, {opaquePtrTy}),
        ApiSpec(API::GET_STRIDES, "getStrides", int64PtrTy, {opaquePtrTy})
    };
    // clang-format on

    // Declare APIs in the current module and build an API registry mapping api
    // identities to a symbol reference to the API function.
    ApiRegistry registry;
    for (auto &apiSpec : apiSpecs) {
      apiSpec.symbolRef = getOrInsertExternFunc(apiSpec.name, module,
                                                apiSpec.funcTy(), rewriter);
      registry.emplace(apiSpec.id, apiSpec);
    }
    return registry;
  }

  /// Emit a call to the registered API `apiId` with arguments `params`.
  /// Returns the result of the call operation when it has exactly one result,
  /// nullptr otherwise. The registry is taken by reference so that it is not
  /// copied for every emitted call.
  Value callApi(PatternRewriter &rewriter, Location loc,
                const ApiRegistry &registry, API apiId,
                ArrayRef<Value> params) const {
    const ApiSpec &spec = registry.at(apiId);
    auto returnVals = rewriter.create<LLVM::CallOp>(
        loc, spec.outputTy, spec.symbolRef, ArrayRef<Value>(params));
    if (returnVals.getNumResults() == 1)
      return returnVals.getResult(0);
    return nullptr;
  }

  /// Helper function to insert an entry block to LLVM function. The block
  /// gets one argument per parameter of `dynEntryPointFuncType`.
  // (TODO): upstream this to MLIR.
  Block &createEntryBlock(LLVM::LLVMType &dynEntryPointFuncType,
                          LLVM::LLVMFuncOp &dynamicEntryPointFunc) const {
    // Add entry block; the function takes ownership of it via push_back.
    auto *entryPointEntryBlock = new Block();
    dynamicEntryPointFunc.push_back(entryPointEntryBlock);

    llvm::SmallVector<Type, 4> argTypes;
    for (size_t i = 0; i < dynEntryPointFuncType.getFunctionNumParams(); i++)
      argTypes.emplace_back(dynEntryPointFuncType.getFunctionParamType(i));
    entryPointEntryBlock->addArguments(argTypes);
    return *entryPointEntryBlock;
  }

  /// Populate the memref struct that `ptrToMemRef` points to with the data
  /// pointer, sizes and strides obtained from `dynMemRef` through the API
  /// functions, and store the result back through `ptrToMemRef`.
  void fillPtrToMemRefWithDynMemRef(Value &dynMemRef, Value &ptrToMemRef,
                                    PatternRewriter &rewriter,
                                    const Location &loc,
                                    const std::map<API, ApiSpec> &apiRegistry,
                                    LLVM::LLVMDialect *llvmDialect) const {
    auto memRefPtrTy = ptrToMemRef.getType().dyn_cast<LLVM::LLVMType>();
    auto memRefTy = memRefPtrTy.getPointerElementTy();
    auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);

    // A load yields the pointee, so its result type is the memref struct type,
    // not the pointer type.
    Value memRef = rewriter.create<LLVM::LoadOp>(loc, memRefTy, ptrToMemRef);

    // Set dataPtr and alignedDataPtr; both struct fields receive the same
    // pointer.
    auto dataPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_DATA, {dynMemRef});
    dataPtr = rewriter.create<LLVM::BitcastOp>(
        loc, memRefTy.getStructElementType(0), dataPtr);
    memRef = rewriter.create<LLVM::InsertValueOp>(
        loc, memRefTy, memRef, dataPtr,
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(0)}));
    memRef = rewriter.create<LLVM::InsertValueOp>(
        loc, memRefTy, memRef, dataPtr,
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(1)}));

    // Use zero offset now.
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, int64Ty, rewriter.getI64IntegerAttr(0));
    memRef = rewriter.create<LLVM::InsertValueOp>(
        loc, memRefTy, memRef, zero,
        rewriter.getArrayAttr({rewriter.getI32IntegerAttr(2)}));

    // Get rank, sizes array ptr and strides array ptr.
    auto rank = getRankFromMemRefType(memRefTy);
    auto sizesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {dynMemRef});
    auto stridesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {dynMemRef});

    for (decltype(rank) i = 0; i < rank; i++) {
      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
          loc, int64Ty, rewriter.getI64IntegerAttr(i));

      // Insert size of the dimension.
      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
          loc, int64Ty.getPointerTo(), sizesArrayPtr,
          ArrayRef<Value>({dimIdx}));
      auto dimSize = rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimSizePtr);
      memRef = rewriter.create<LLVM::InsertValueOp>(
          loc, memRefTy, memRef, dimSize,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));

      // Insert stride of the dimension. The stride has to be read from the
      // strides array, not from the sizes array.
      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
          loc, int64Ty.getPointerTo(), stridesArrayPtr,
          ArrayRef<Value>({dimIdx}));
      auto dimStride =
          rewriter.create<LLVM::LoadOp>(loc, int64Ty, dimStridePtr);
      memRef = rewriter.create<LLVM::InsertValueOp>(
          loc, memRefTy, memRef, dimStride,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
    }

    rewriter.create<LLVM::StoreOp>(loc, memRef, ptrToMemRef);
  }

  /// Copy the data pointer, sizes and strides of the memref struct value
  /// `outMemRef` into the dynamic memref `outDynMemRef` via the API functions.
  void fillDynMemRefWithMemRef(Value &outMemRef, Value &outDynMemRef,
                               PatternRewriter &rewriter, const Location &loc,
                               const std::map<API, ApiSpec> &apiRegistry,
                               LLVM::LLVMDialect *llvmDialect) const {
    auto outMemRefTy = outMemRef.getType().dyn_cast<LLVM::LLVMType>();
    auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);

    // Extract the data pointer, and record it in dynamic mem ref created.
    Value outMemRefDataPtr = rewriter.create<LLVM::ExtractValueOp>(
        loc, outMemRefTy.getStructElementType(0), outMemRef,
        rewriter.getArrayAttr({rewriter.getI64IntegerAttr(0)}));
    outMemRefDataPtr = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), outMemRefDataPtr);
    callApi(rewriter, loc, apiRegistry, API::SET_DATA,
            {outDynMemRef, outMemRefDataPtr});

    auto rank = getRankFromMemRefType(outMemRefTy);
    auto sizesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_SIZES, {outDynMemRef});
    auto stridesArrayPtr =
        callApi(rewriter, loc, apiRegistry, API::GET_STRIDES, {outDynMemRef});

    for (decltype(rank) i = 0; i < rank; i++) {
      auto dimIdx = rewriter.create<LLVM::ConstantOp>(
          loc, int64Ty, rewriter.getI64IntegerAttr(i));

      // Transfer size of dimension from memref to dynamic memref.
      auto dimSize = rewriter.create<LLVM::ExtractValueOp>(
          loc, int64Ty, outMemRef,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(3), rewriter.getI64IntegerAttr(i)}));
      auto dimSizePtr = rewriter.create<LLVM::GEPOp>(
          loc, int64Ty.getPointerTo(), sizesArrayPtr,
          ArrayRef<Value>({dimIdx}));
      rewriter.create<LLVM::StoreOp>(loc, dimSize, dimSizePtr);

      // Transfer stride of dimension from memref to dynamic memref.
      auto dimStride = rewriter.create<LLVM::ExtractValueOp>(
          loc, int64Ty, outMemRef,
          rewriter.getArrayAttr(
              {rewriter.getI64IntegerAttr(4), rewriter.getI64IntegerAttr(i)}));
      auto dimStridePtr = rewriter.create<LLVM::GEPOp>(
          loc, int64Ty.getPointerTo(), stridesArrayPtr,
          ArrayRef<Value>({dimIdx}));
      rewriter.create<LLVM::StoreOp>(loc, dimStride, dimStridePtr);
    }
  }
};
//===----------------------------------------------------------------------===//
// KRNL to LLVM: KrnlSqrtOpLowering
//===----------------------------------------------------------------------===//
/// Conversion pattern that rewrites a KrnlSqrtOp into a call to the
/// `llvm.sqrt.f32` or `llvm.sqrt.f64` function, selected by the operand type.
class KrnlSqrtOpLowering : public ConversionPattern {
public:
  explicit KrnlSqrtOpLowering(MLIRContext *context)
      : ConversionPattern(KrnlSqrtOp::getOperationName(), 1, context) {}

  /// Replace `op` with a call to the matching sqrt function.
  ///
  /// Returns matchFailure() when the converted operand is not an LLVM type or
  /// is neither float nor double, so that no call to an unnamed function is
  /// ever emitted, also in builds where asserts are compiled out.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    OperandAdaptor<KrnlSqrtOp> adaptor(operands);
    LLVM::LLVMType operandType =
        adaptor.operand().getType().dyn_cast_or_null<LLVM::LLVMType>();
    if (!operandType)
      return matchFailure();

    std::string functionName;
    if (operandType.isFloatTy())
      functionName = "llvm.sqrt.f32";
    else if (operandType.isDoubleTy())
      functionName = "llvm.sqrt.f64";
    else
      return matchFailure(); // Unsupported operand type.

    // Get a symbol reference to the sqrt function, inserting it if necessary.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto sqrtRef =
        getOrInsertSqrt(rewriter, parentModule, functionName, operandType);

    // Sqrt call
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, operandType, sqrtRef,
                                              adaptor.operand());
    return matchSuccess();
  }

private:
  /// Return a symbol reference to the sqrt function `fnName`, inserting its
  /// declaration at the start of `module` if necessary. Both the parameter and
  /// the result of the declared function have type `operandType`.
  static FlatSymbolRefAttr getOrInsertSqrt(PatternRewriter &rewriter,
                                           ModuleOp module, std::string fnName,
                                           LLVM::LLVMType operandType) {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>(fnName))
      return SymbolRefAttr::get(fnName, context);

    // Create a function declaration for sqrt, the signature is:
    //   * `T (T)` with T = operandType
    auto llvmFnType =
        LLVM::LLVMType::getFunctionTy(operandType, operandType, false);

    // Insert the sqrt function into the body of the parent module. The guard
    // restores the rewriter's insertion point on scope exit.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), fnName, llvmFnType);
    return SymbolRefAttr::get(fnName, context);
  }
};
} // end namespace
//===----------------------------------------------------------------------===//
// KRNL + Standard + Affine dialects lowering to LLVM.
//===----------------------------------------------------------------------===//
namespace {
/// Module pass that lowers a module to the LLVM dialect; the conversion
/// itself is implemented in runOnModule().
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
// Pass entry point; defined out of line below.
void runOnModule() final;
};
} // end anonymous namespace
void KrnlToLLVMLoweringPass::runOnModule() {
// Define the target for this lowering i.e. the LLVM dialect.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
// Lower the MemRef types to a representation in LLVM.
LLVMTypeConverter typeConverter(&getContext());
// We have a combination of `krnl`, `affine`, and `std` operations. We
// lower in stages until all the code is in the LLVM dialect.
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Lower from the `krnl` dialect i.e. the Reshape operation.
patterns.insert<KrnlMemcpyOpLowering, KrnlEntryPointOpLowering,
KrnlSqrtOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
if (failed(
applyFullConversion(getModule(), target, patterns, &typeConverter)))
signalPassFailure();
}
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
/// Returns a newly allocated KrnlToLLVMLoweringPass owned by the caller.
std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
return std::make_unique<KrnlToLLVMLoweringPass>();
}
// Static registration object for KrnlToLLVMLoweringPass, using the argument
// string "lower-all-llvm" and the description given below.
static PassRegistration<KrnlToLLVMLoweringPass>
pass("lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");