[MLIR] Add support for reshape (#390)
* Add reshape op handling.
* Lower reshape to KRNL dialect.
* Add comments.
* Propagate reshape to KRNL IR.
* Lower KRNL reshape to affine and standard ops level dialects.
* Add lowering of reshape operation to Krnl and LLVM Dialects.
* Add test for LLVM IR dialect output for reshape.
* Fix rebase.
* Fix test variable.
* Emit errors during reshape shape inference. Address other reviewer comments.
parent 5ed79083d5
commit e81a7654f9
@@ -181,3 +181,16 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
   // Fully specified by traits.
   let verifier = ?;
 }
+
+def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
+  let summary = "Krnl memcpy operation";
+  let description = [{
+    In the KRNL dialect the reshape op doesn't generate a new memory entry and
+    treats a reshape like a cast.
+  }];
+
+  let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);
+
+  let parser = ?;
+  let printer = ?;
+}
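For orientation, this is the textual form the new op takes in Krnl IR, matching the lowering test added later in this commit (the value names here are illustrative):

    "krnl.memcpy"(%dest, %src, %size) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()

$dest and $src are the destination and source memrefs, and $size is the number of bytes to copy.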
@@ -266,7 +266,7 @@ def gen_schema(schema) :
 ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                     'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                     'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-                    'Elu', 'Selu', 'HardSigmoid']
+                    'Elu', 'Selu', 'HardSigmoid', 'Reshape']
 CanonicalList=['Add', 'Identity']
 line_indent = '  '

@@ -42,9 +42,7 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
 // Exp
 /// Infer the output shape of the ONNXExpOp. This method is required by the
 /// shape inference interface.
-void ONNXExpOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }

 //===----------------------------------------------------------------------===//
 // Tanh
@@ -90,9 +88,7 @@ void ONNXSigmoidOp::inferShapes() {
 // Elu
 /// Infer the output shape of the ONNXEluOp. This method is required by the
 /// shape inference interface.
-void ONNXEluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }

 //===----------------------------------------------------------------------===//
 // Relu
@@ -162,9 +158,7 @@ void ONNXAndOp::inferShapes() {
 // Or
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
-void ONNXOrOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
-}
+void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }

 //===----------------------------------------------------------------------===//
 // Xor
@@ -257,6 +251,36 @@ void ONNXFullGemmOp::inferShapes() {
   // Verify that matrix sizes are valid for multiplication and addition.
   // Take into account the dimensionality of the matrix.

+//===----------------------------------------------------------------------===//
+
+// Reshape
+
+void ONNXReshapeOp::inferShapes() {
+  // Cannot infer shape if no shape tensor is specified.
+  if (!getOperand(1)->getType().isa<RankedTensorType>())
+    emitError("Shape tensor not ranked.");
+
+  auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
+  auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();
+
+  // Only rank 1 shape tensors are supported.
+  if (shapeTensorTy.getShape().size() != 1)
+    emitError("Shape tensor must have rank one.");
+
+  int64_t outputRank = shapeTensorTy.getShape()[0];
+
+  // Shape tensor must have constant shape.
+  if (outputRank < 0)
+    emitError("Shape tensor must have constant shape.");
+
+  SmallVector<int64_t, 2> dims;
+  for (int i = 0; i < outputRank; ++i)
+    dims.emplace_back(-1);
+
+  getResult()->setType(
+      RankedTensorType::get(dims, inputTensorTy.getElementType()));
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
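A concrete instance of this inference, taken from the tests added below: the shape operand %arg1 : tensor<4xi32> is a rank-1 tensor of static length 4, so outputRank is 4 and every output dimension is set to -1 (dynamic), since the shape values themselves are unknown at compile time:

    // Before shape inference the result is unranked:
    %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
    // After inference the result type becomes tensor<?x?x?x?xf32>.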
@@ -2197,7 +2197,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
 }

 def ONNXReshapeOp:ONNX_Op<"Reshape",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Reshape operation";
   let description = [{
     "Reshape the input tensor similar to numpy.reshape."
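Declaring ShapeInferenceOpInterface on the op is what attaches the inferShapes() hook defined above to the generated ONNXReshapeOp class, letting the shape inference pass dispatch to it.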
@@ -86,7 +86,7 @@ static bool checkInsertDealloc(Operation* currentOp) {
   // If there is at least one result to investigate.
   if (currentOp->getNumResults() > 0) {
     auto result = currentOp->getResult(0);
-    for (auto operand : op.getOperands())
+    for (const auto& operand : op.getOperands())
       if (operand == result)
         insertDealloc = false;
   }
@@ -95,6 +95,20 @@ static bool checkInsertDealloc(Operation* currentOp) {
   return insertDealloc;
 }

+unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
+  auto elementType = memRefType.getElementType();
+
+  unsigned sizeInBits;
+  if (elementType.isIntOrFloat()) {
+    sizeInBits = elementType.getIntOrFloatBitWidth();
+  } else {
+    auto vectorType = elementType.cast<VectorType>();
+    sizeInBits =
+        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+  }
+  return llvm::divideCeil(sizeInBits, 8);
+}
+
 namespace {

 template <typename ElementwiseNaryOp>
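As a quick check on the arithmetic: for an f32 element, sizeInBits is 32, so the helper returns llvm::divideCeil(32, 8) = 4; this is the constant 4 : i64 byte size that appears in the reshape lowering test below.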
@@ -655,6 +669,62 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
   }
 };

+struct ONNXReshapeOpLowering : public ConversionPattern {
+  ONNXReshapeOpLowering(MLIRContext* ctx)
+      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+    Value* alloc;
+
+    // Compute size in bytes.
+    Value* tensorSize = rewriter.create<ConstantOp>(loc,
+        rewriter.getIntegerAttr(
+            rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
+    bool insertDealloc = checkInsertDealloc(op);
+    if (hasAllConstantDimensions(memRefType)) {
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    } else {
+      auto memRefShape = memRefType.getShape();
+      SmallVector<Value*, 4> allocOperands;
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        // The shape array can always be used to construct shape information of
+        // the result.
+        Value* index = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
+        Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
+        Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
+            loc, loadedVal, rewriter.getIntegerType(64));
+        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
+        allocOperands.push_back(rewriter.create<IndexCastOp>(
+            loc, loadedVal, rewriter.getIndexType()));
+      }
+      AllocOp allocateMemref =
+          rewriter.create<AllocOp>(loc, memRefType, allocOperands);
+
+      // Make sure to allocate at the beginning of the block if
+      // all dimensions are known.
+      auto* parentBlock = allocateMemref.getOperation()->getBlock();
+      if (insertDealloc) {
+        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
+        dealloc.getOperation()->moveBefore(&parentBlock->back());
+      }
+
+      alloc = allocateMemref;
+    }
+
+    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
 //===----------------------------------------------------------------------===//
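A sketch of the IR this pattern emits for the first dynamic dimension, simplified from the Krnl-level test added below (value names are illustrative): each entry of the shape operand is loaded, zero-extended to i64, and folded into a running byte-size product, while an index_cast of the same value becomes an operand of the alloc:

    %c0 = constant 0 : index
    %d0 = load %shape[%c0] : memref<4xi32>
    %e0 = zexti %d0 : i32 to i64
    %size0 = muli %bytes, %e0 : i64
    %i0 = index_cast %d0 : i32 to index
    // ... repeated for the remaining dimensions ...
    %res = alloc(%i0, %i1, %i2, %i3) : memref<?x?x?x?xf32>
    "krnl.memcpy"(%res, %arg0, %size3) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()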
@@ -754,7 +824,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
       ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext());
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
+      ONNXReshapeOpLowering>(&getContext());

   // With the target and rewrite patterns defined, we can now attempt the
   // conversion. The conversion will signal failure if any of our `illegal`
@@ -23,4 +23,7 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
 /// Pass for lowering frontend dialects to Krnl IR dialect.
 std::unique_ptr<Pass> createLowerKrnlPass();

+/// Pass for lowering Krnl dialect to LLVM dialect.
+std::unique_ptr<Pass> createKrnlLowerToLLVMPass();
+
 } // end namespace mlir
@@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
     if (auto terminator_op = f.getBody().back().getTerminator()) {
       auto results = terminator_op->getOperandTypes();
       f.setType(FunctionType::get(f.getType().getInputs(),
-          std::vector<Type>(results.begin(), results.end()), f.getContext()));
+          std::vector<Type>(results.begin(), results.end()), f.getContext()));
     }
   }

@@ -110,7 +110,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
         op->getName().getStringRef() != "onnx.Min" &&
         op->getName().getStringRef() != "onnx.MatMul" &&
         op->getName().getStringRef() != "onnx.Gemm" &&
-        op->getName().getStringRef() != "onnx.FullGemm")
+        op->getName().getStringRef() != "onnx.FullGemm" &&
+        op->getName().getStringRef() != "onnx.Reshape")
       return false;
     return llvm::any_of(op->getResultTypes(),
         [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
@@ -1,4 +1,6 @@
-add_library(onnf_transform lower_krnl.cpp)
+add_library(onnf_transform
+  lower_krnl.cpp
+  lower_to_llvm.cpp)

 target_include_directories(onnf_transform
   PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
@@ -142,6 +142,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
   target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
   // We expect IR to be free of Krnl Dialect Ops.
   target.addIllegalDialect<KrnlOpsDialect>();
+  target.addLegalOp<KrnlMemcpyOp>();

   OwningRewritePatternList patterns;
   patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -0,0 +1,146 @@
+//====- LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -----------===//
+//
+// Copyright 2019 The DLC Authors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Sequence.h"
+#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
+#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/compiler/dialect/krnl/krnl_ops.hpp"
+#include "src/compiler/pass/passes.hpp"
+
+using namespace mlir;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// KRNL to LLVM: patterns which need a direct lowering to LLVM.
+//===----------------------------------------------------------------------===//
+
+class KrnlMemcpyOpLowering : public ConversionPattern {
+ public:
+  explicit KrnlMemcpyOpLowering(MLIRContext* context)
+      : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const override {
+    auto* context = op->getContext();
+    auto loc = op->getLoc();
+    auto* llvmDialect =
+        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    assert(llvmDialect && "expected llvm dialect to be registered");
+
+    // Get a symbol reference to the memcpy function, inserting it if necessary.
+    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
+    auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);
+
+    // First operand.
+    Type dstType =
+        operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+    Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
+        loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
+    Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);
+
+    // Second operand.
+    Type srcType =
+        operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
+    Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
+        loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
+    Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
+        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);
+
+    // Size.
+    Value* int64Size = rewriter.create<LLVM::SExtOp>(
+        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);
+
+    // Memcpy call.
+    rewriter.create<CallOp>(loc, memcpyRef,
+        LLVM::LLVMType::getVoidTy(llvmDialect),
+        ArrayRef<Value*>(
+            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));
+
+    rewriter.eraseOp(op);
+    return matchSuccess();
+  }
+
+ private:
+  /// Return a symbol reference to the memcpy function, inserting it into the
+  /// module if necessary.
+  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
+      ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
+    auto* context = module.getContext();
+    if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
+      return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
+    // Create a function declaration for memcpy; the signature is:
+    //   * `void (i8*, i8*, i64)`
+    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
+    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
+    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
+    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
+        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
+        false);
+
+    // Insert the memcpy function into the body of the parent module.
+    PatternRewriter::InsertionGuard insertGuard(rewriter);
+    rewriter.setInsertionPointToStart(module.getBody());
+    rewriter.create<LLVM::LLVMFuncOp>(
+        module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
+    return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
+  }
+};
+} // end namespace
+
+//===----------------------------------------------------------------------===//
+// KRNL + Standard + Affine dialects lowering to LLVM.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
+  void runOnModule() final;
+};
+} // end anonymous namespace
+
+void KrnlToLLVMLoweringPass::runOnModule() {
+  // Define the target for this lowering, i.e. the LLVM dialect.
+  ConversionTarget target(getContext());
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
+
+  // Lower the MemRef types to a representation in LLVM.
+  LLVMTypeConverter typeConverter(&getContext());
+
+  // We have a combination of `krnl`, `affine`, and `std` operations. We
+  // lower in stages until all the code is in the LLVM dialect.
+  OwningRewritePatternList patterns;
+  populateAffineToStdConversionPatterns(patterns, &getContext());
+  populateLoopToStdConversionPatterns(patterns, &getContext());
+  populateStdToLLVMConversionPatterns(typeConverter, patterns);
+
+  // Lower from the `krnl` dialect, i.e. the KrnlMemcpyOp emitted for Reshape.
+  patterns.insert<KrnlMemcpyOpLowering>(&getContext());
+
+  // We want to completely lower to LLVM, so we use a `FullConversion`. This
+  // ensures that only legal operations will remain after the conversion.
+  auto module = getModule();
+  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
+    signalPassFailure();
+}
+
+/// Create the pass for lowering `Krnl`, `Affine` and `Std` dialects to LLVM.
+std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
+  return std::make_unique<KrnlToLLVMLoweringPass>();
+}
+
+static PassRegistration<KrnlToLLVMLoweringPass> pass(
+    "lower-all-llvm", "Lower the Krnl Affine and Std dialects to LLVM.");
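One detail worth calling out: getStructElementType(1) and the extractvalue at position 1 select the aligned data pointer from the memref descriptor struct produced by the standard-to-LLVM conversion, e.g. !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">. The emitted sequence, as checked by the new LLVM test below, looks like:

    %ptr = llvm.extractvalue %res[1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
    %dst = llvm.bitcast %ptr : !llvm<"float*"> to !llvm<"i8*">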
@@ -131,7 +131,7 @@ int main(int ac, char* av[]) {
   pm.addPass(mlir::createLowerKrnlPass());
   pm.addPass(mlir::createLowerAffinePass());
   pm.addPass(mlir::createLowerToCFGPass());
-  pm.addPass(mlir::createLowerToLLVMPass());
+  pm.addPass(mlir::createKrnlLowerToLLVMPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.run(*module);

@@ -0,0 +1,16 @@
+// RUN: dlc-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s
+
+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
+  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
+  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
+  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
+  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
+  // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+}
@@ -279,6 +279,37 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: return [[RES]] : memref<?x10xf32>
 }

+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_reshape
+  // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
+  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
+  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
+  // CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
+  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
+  // CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
+  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
+  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
+  // CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
+  // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
+  // CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
+  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
+  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
+  // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
+  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
+  // CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
+  // CHECK: %[[INDEX_3:.+]] = constant 3 : index
+  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
+  // CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
+  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
+  // CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
+  // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
+  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
+  // CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
+}
+
 func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
   %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()