[MLIR] Add support for reshape (#390)
* Add reshape op handling. * Lower reshape to KRNL dialect. * Add comments. * Propagate reshape to KRNL IR. * Lower KRNL reshape to affine and standard ops level dialects. * Add lowering of reshape operation to Krnl and LLVM Dialects. * Add test for LLVM IR dialect output for reshape. * Fix rebase. * Fix test variable. * Emit errors during reshape shape inference. Address other reviewer comments.
This commit is contained in:
parent
5ed79083d5
commit
e81a7654f9
|
@ -181,3 +181,16 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
|
||||||
// Fully specified by traits.
|
// Fully specified by traits.
|
||||||
let verifier = ?;
|
let verifier = ?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
  let summary = "Krnl memcpy operation";
  let description = [{
    Copies `size` bytes from the `src` memref into the `dest` memref. The
    operation produces no results.

    The frontend lowering of the ONNX Reshape operation emits this operation
    to copy the input buffer into the newly allocated output buffer; it is
    later lowered to a call to the memcpy function in the LLVM dialect.
  }];

  let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);

  let parser = ?;
  let printer = ?;
}
|
||||||
|
|
|
@ -266,7 +266,7 @@ def gen_schema(schema) :
|
||||||
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
|
||||||
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
|
||||||
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
|
||||||
'Elu', 'Selu', 'HardSigmoid']
|
'Elu', 'Selu', 'HardSigmoid', 'Reshape']
|
||||||
CanonicalList=['Add', 'Identity']
|
CanonicalList=['Add', 'Identity']
|
||||||
line_indent = ' '
|
line_indent = ' '
|
||||||
|
|
||||||
|
|
|
@ -42,9 +42,7 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
|
||||||
// Exp
|
// Exp
|
||||||
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXExpOp::inferShapes() {
|
void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
getResult()->setType(getOperand()->getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tanh
|
// Tanh
|
||||||
|
@ -90,9 +88,7 @@ void ONNXSigmoidOp::inferShapes() {
|
||||||
// Elu
|
// Elu
|
||||||
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXEluOp::inferShapes() {
|
void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
|
||||||
getResult()->setType(getOperand()->getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Relu
|
// Relu
|
||||||
|
@ -162,9 +158,7 @@ void ONNXAndOp::inferShapes() {
|
||||||
// Or
|
// Or
|
||||||
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXOrOp::inferShapes() {
|
void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }
|
||||||
getResult()->setType(getOperand(0)->getType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Xor
|
// Xor
|
||||||
|
@ -257,6 +251,36 @@ void ONNXFullGemmOp::inferShapes() {
|
||||||
// Verify that matrix sizes are valid for multiplication and addition.
|
// Verify that matrix sizes are valid for multiplication and addition.
|
||||||
// Take into account the dimensionality of the matrix.
|
// Take into account the dimensionality of the matrix.
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Reshape
|
||||||
|
|
||||||
|
void ONNXReshapeOp::inferShapes() {
|
||||||
|
// Cannot infer shape if no shape tensor is specified.
|
||||||
|
if (!getOperand(1)->getType().isa<RankedTensorType>())
|
||||||
|
emitError("Shape tensor not ranked.");
|
||||||
|
|
||||||
|
auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
|
||||||
|
auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
// Only rank 1 shape tensors are supported.
|
||||||
|
if (shapeTensorTy.getShape().size() != 1)
|
||||||
|
emitError("Shape tensor must have rank one.");
|
||||||
|
|
||||||
|
int64_t outputRank = shapeTensorTy.getShape()[0];
|
||||||
|
|
||||||
|
// Shape tensor must have constant shape.
|
||||||
|
if (outputRank < 0)
|
||||||
|
emitError("Shape tensor must have constant shape.");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> dims;
|
||||||
|
for (int i = 0; i < outputRank; ++i)
|
||||||
|
dims.emplace_back(-1);
|
||||||
|
|
||||||
|
getResult()->setType(
|
||||||
|
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TableGen'd op method definitions
|
// TableGen'd op method definitions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -2197,7 +2197,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX Reshape operation";
|
let summary = "ONNX Reshape operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Reshape the input tensor similar to numpy.reshape."
|
"Reshape the input tensor similar to numpy.reshape."
|
||||||
|
|
|
@ -86,7 +86,7 @@ static bool checkInsertDealloc(Operation* currentOp) {
|
||||||
// If there is at least one result to investigate.
|
// If there is at least one result to investigate.
|
||||||
if (currentOp->getNumResults() > 0) {
|
if (currentOp->getNumResults() > 0) {
|
||||||
auto result = currentOp->getResult(0);
|
auto result = currentOp->getResult(0);
|
||||||
for (auto operand : op.getOperands())
|
for (const auto& operand : op.getOperands())
|
||||||
if (operand == result)
|
if (operand == result)
|
||||||
insertDealloc = false;
|
insertDealloc = false;
|
||||||
}
|
}
|
||||||
|
@ -95,6 +95,20 @@ static bool checkInsertDealloc(Operation* currentOp) {
|
||||||
return insertDealloc;
|
return insertDealloc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the number of bytes occupied by a single element of `memRefType`,
/// rounding the bit width up to a whole byte. The element type must be an
/// integer, a float, or a vector type.
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto eltType = memRefType.getElementType();

  // Scalar integer / floating point element: use its bit width directly.
  if (eltType.isIntOrFloat()) {
    unsigned scalarBits = eltType.getIntOrFloatBitWidth();
    return llvm::divideCeil(scalarBits, 8);
  }

  // Otherwise the element is expected to be a vector: total width is the
  // per-lane width times the number of lanes.
  auto vecType = eltType.cast<VectorType>();
  unsigned vectorBits =
      vecType.getElementTypeBitWidth() * vecType.getNumElements();
  return llvm::divideCeil(vectorBits, 8);
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename ElementwiseNaryOp>
|
template <typename ElementwiseNaryOp>
|
||||||
|
@ -655,6 +669,62 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
|
ONNXReshapeOpLowering(MLIRContext* ctx)
|
||||||
|
: ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
|
||||||
|
// Insert an allocation and deallocation for the result of this operation.
|
||||||
|
auto memRefType = convertTensorToMemRef(tensorType);
|
||||||
|
Value* alloc;
|
||||||
|
|
||||||
|
// Compute size in bytes.
|
||||||
|
Value* tensorSize = rewriter.create<ConstantOp>(loc,
|
||||||
|
rewriter.getIntegerAttr(
|
||||||
|
rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
|
||||||
|
bool insertDealloc = checkInsertDealloc(op);
|
||||||
|
if (hasAllConstantDimensions(memRefType)) {
|
||||||
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
|
} else {
|
||||||
|
auto memRefShape = memRefType.getShape();
|
||||||
|
SmallVector<Value*, 4> allocOperands;
|
||||||
|
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||||
|
// The shape array can always be used to construct shape information of
|
||||||
|
// the result.
|
||||||
|
Value* index = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||||
|
Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
|
Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
|
||||||
|
loc, loadedVal, rewriter.getIntegerType(64));
|
||||||
|
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
|
||||||
|
allocOperands.push_back(rewriter.create<IndexCastOp>(
|
||||||
|
loc, loadedVal, rewriter.getIndexType()));
|
||||||
|
}
|
||||||
|
AllocOp allocateMemref =
|
||||||
|
rewriter.create<AllocOp>(loc, memRefType, allocOperands);
|
||||||
|
|
||||||
|
// Make sure to allocate at the beginning of the block if
|
||||||
|
// all dimensions are known.
|
||||||
|
auto* parentBlock = allocateMemref.getOperation()->getBlock();
|
||||||
|
if (insertDealloc) {
|
||||||
|
auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
|
||||||
|
dealloc.getOperation()->moveBefore(&parentBlock->back());
|
||||||
|
}
|
||||||
|
|
||||||
|
alloc = allocateMemref;
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
|
||||||
|
rewriter.replaceOp(op, alloc);
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Conversion from Tensor type to the Standard dialect MemRef type.
|
// Conversion from Tensor type to the Standard dialect MemRef type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -754,7 +824,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
|
||||||
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext());
|
ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
|
||||||
|
ONNXReshapeOpLowering>(&getContext());
|
||||||
|
|
||||||
// With the target and rewrite patterns defined, we can now attempt the
|
// With the target and rewrite patterns defined, we can now attempt the
|
||||||
// conversion. The conversion will signal failure if any of our `illegal`
|
// conversion. The conversion will signal failure if any of our `illegal`
|
||||||
|
|
|
@ -23,4 +23,7 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
/// Pass for lowering frontend dialects to Krnl IR dialect.
|
||||||
std::unique_ptr<Pass> createLowerKrnlPass();
|
std::unique_ptr<Pass> createLowerKrnlPass();
|
||||||
|
|
||||||
|
/// Pass for lowering Krnl dialect to LLVM dialect.
|
||||||
|
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
|
@ -110,7 +110,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
|
||||||
op->getName().getStringRef() != "onnx.Min" &&
|
op->getName().getStringRef() != "onnx.Min" &&
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.FullGemm")
|
op->getName().getStringRef() != "onnx.FullGemm" &&
|
||||||
|
op->getName().getStringRef() != "onnx.Reshape")
|
||||||
return false;
|
return false;
|
||||||
return llvm::any_of(op->getResultTypes(),
|
return llvm::any_of(op->getResultTypes(),
|
||||||
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
[](Type result_type) { return !result_type.isa<RankedTensorType>(); });
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
add_library(onnf_transform lower_krnl.cpp)
|
add_library(onnf_transform
|
||||||
|
lower_krnl.cpp
|
||||||
|
lower_to_llvm.cpp)
|
||||||
|
|
||||||
target_include_directories(onnf_transform
|
target_include_directories(onnf_transform
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
|
|
|
@ -142,6 +142,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
|
||||||
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
|
target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
|
||||||
// We expect IR to be free of Krnl Dialect Ops.
|
// We expect IR to be free of Krnl Dialect Ops.
|
||||||
target.addIllegalDialect<KrnlOpsDialect>();
|
target.addIllegalDialect<KrnlOpsDialect>();
|
||||||
|
target.addLegalOp<KrnlMemcpyOp>();
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
//====- LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -----------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The DLC Authors.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "llvm/ADT/Sequence.h"
|
||||||
|
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||||
|
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
|
#include "mlir/Dialect/AffineOps/AffineOps.h"
|
||||||
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "mlir/Dialect/LoopOps/LoopOps.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/compiler/dialect/krnl/krnl_ops.hpp"
|
||||||
|
#include "src/compiler/pass/passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KRNL to LLVM: patterns which need a direct lowering to LLVM.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class KrnlMemcpyOpLowering : public ConversionPattern {
|
||||||
|
public:
|
||||||
|
explicit KrnlMemcpyOpLowering(MLIRContext* context)
|
||||||
|
: ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
auto* context = op->getContext();
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto* llvmDialect =
|
||||||
|
op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
assert(llvmDialect && "expected llvm dialect to be registered");
|
||||||
|
|
||||||
|
// Get a symbol reference to the memcpy function, inserting it if necessary.
|
||||||
|
ModuleOp parentModule = op->getParentOfType<ModuleOp>();
|
||||||
|
auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
|
||||||
|
|
||||||
|
// First operand.
|
||||||
|
Type dstType =
|
||||||
|
operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||||
|
Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
|
loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
|
||||||
|
Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
|
||||||
|
|
||||||
|
// Second operand.
|
||||||
|
Type srcType =
|
||||||
|
operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
|
||||||
|
Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
|
loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
|
||||||
|
Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
|
||||||
|
|
||||||
|
// Size.
|
||||||
|
Value* int64Size = rewriter.create<LLVM::SExtOp>(
|
||||||
|
loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
|
||||||
|
|
||||||
|
// Memcpy call
|
||||||
|
rewriter.create<CallOp>(loc, memcpyRef,
|
||||||
|
LLVM::LLVMType::getVoidTy(llvmDialect),
|
||||||
|
ArrayRef<Value*>(
|
||||||
|
{alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
/// Return a symbol reference to the memcpy function, inserting it into the
|
||||||
|
/// module if necessary.
|
||||||
|
static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
|
||||||
|
ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
|
||||||
|
auto* context = module.getContext();
|
||||||
|
if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
|
||||||
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
|
// Create a function declaration for memcpy, the signature is:
|
||||||
|
// * `void (i8*, i8* , i64, i1)`
|
||||||
|
auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
|
||||||
|
auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
|
||||||
|
auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
|
||||||
|
auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
|
||||||
|
ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
|
||||||
|
false);
|
||||||
|
|
||||||
|
// Insert the memcpy function into the body of the parent module.
|
||||||
|
PatternRewriter::InsertionGuard insertGuard(rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(module.getBody());
|
||||||
|
rewriter.create<LLVM::LLVMFuncOp>(
|
||||||
|
module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
|
||||||
|
return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// KRNL + Standard + Affine dialects lowering to LLVM.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/// Module pass that converts the operations of a module to the LLVM dialect.
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
  /// Entry point of the pass; defined out of line.
  void runOnModule() final;
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void KrnlToLLVMLoweringPass::runOnModule() {
|
||||||
|
// Define the target for this lowering i.e. the LLVM dialect.
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
|
||||||
|
// Lower the MemRef types to a representation in LLVM.
|
||||||
|
LLVMTypeConverter typeConverter(&getContext());
|
||||||
|
|
||||||
|
// We have a combination of `krnl`, `affine`, and `std` operations. We
|
||||||
|
// lower in stages until all the code is in the LLVM dialect.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
populateAffineToStdConversionPatterns(patterns, &getContext());
|
||||||
|
populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||||
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
||||||
|
|
||||||
|
// Lower from the `krnl` dialect i.e. the Reshape operation.
|
||||||
|
patterns.insert<KrnlMemcpyOpLowering>(&getContext());
|
||||||
|
|
||||||
|
// We want to completely lower to LLVM, so we use a `FullConversion`. This
|
||||||
|
// ensures that only legal operations will remain after the conversion.
|
||||||
|
auto module = getModule();
|
||||||
|
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
|
||||||
|
return std::make_unique<KrnlToLLVMLoweringPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<KrnlToLLVMLoweringPass> pass(
|
||||||
|
"lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
|
|
@ -131,7 +131,7 @@ int main(int ac, char* av[]) {
|
||||||
pm.addPass(mlir::createLowerKrnlPass());
|
pm.addPass(mlir::createLowerKrnlPass());
|
||||||
pm.addPass(mlir::createLowerAffinePass());
|
pm.addPass(mlir::createLowerAffinePass());
|
||||||
pm.addPass(mlir::createLowerToCFGPass());
|
pm.addPass(mlir::createLowerToCFGPass());
|
||||||
pm.addPass(mlir::createLowerToLLVMPass());
|
pm.addPass(mlir::createKrnlLowerToLLVMPass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.run(*module);
|
pm.run(*module);
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
// RUN: dlc-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
|
||||||
|
// CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
|
// CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
|
// CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
|
// CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
|
||||||
|
// CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
|
||||||
|
// CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
|
||||||
|
// CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
|
||||||
|
// CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
|
||||||
|
}
|
|
@ -279,6 +279,37 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
// CHECK: return [[RES]] : memref<?x10xf32>
|
// CHECK: return [[RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_reshape
|
||||||
|
// CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
|
||||||
|
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
|
||||||
|
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
|
||||||
|
// CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
|
||||||
|
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
|
||||||
|
// CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
|
||||||
|
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
|
||||||
|
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
|
||||||
|
// CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
|
||||||
|
// CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
|
||||||
|
// CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
|
||||||
|
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
|
||||||
|
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
|
||||||
|
// CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
|
||||||
|
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
|
||||||
|
// CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
|
||||||
|
// CHECK: %[[INDEX_3:.+]] = constant 3 : index
|
||||||
|
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
|
||||||
|
// CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
|
||||||
|
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
|
||||||
|
// CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
|
||||||
|
// CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
|
||||||
|
// CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
|
||||||
|
// CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
|
func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue