[MLIR] Add support for reshape (#390)
* Add reshape op handling.
* Lower reshape to KRNL dialect.
* Add comments.
* Propagate reshape to KRNL IR.
* Lower KRNL reshape to affine and standard ops level dialects.
* Add lowering of reshape operation to Krnl and LLVM Dialects.
* Add test for LLVM IR dialect output for reshape.
* Fix rebase.
* Fix test variable.
* Emit errors during reshape shape inference. Address other reviewer comments.
This commit is contained in:
parent 5ed79083d5
commit e81a7654f9
@@ -181,3 +181,16 @@ def KrnlTerminatorOp : Op<Krnl_Dialect, "terminate", [Terminator]> {
  // Fully specified by traits.
  let verifier = ?;
}

def KrnlMemcpyOp : Op<Krnl_Dialect, "memcpy"> {
  let summary = "Krnl memcpy operation";
  let description = [{
    In the KRNL dialect the reshape op doesn't generate a new memory entry and
    treats a reshape like a cast.
  }];

  let arguments = (ins AnyMemRef:$dest, AnyMemRef:$src, AnyInteger:$size);

  let parser = ?;
  let printer = ?;
}
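
// For illustration only: a minimal sketch of how this op appears in KRNL IR
// once the ONNX Reshape lowering has run (SSA names and memref shapes here
// are hypothetical, following the lit tests below):
//
//   "krnl.memcpy"(%dst, %src, %size)
//       : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()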
@@ -266,7 +266,7 @@ def gen_schema(schema) :
    ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                        'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
                        'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
-                        'Elu', 'Selu', 'HardSigmoid']
+                        'Elu', 'Selu', 'HardSigmoid', 'Reshape']
    CanonicalList=['Add', 'Identity']
    line_indent = '  '
@@ -42,9 +42,7 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
-void ONNXExpOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXExpOp::inferShapes() { getResult()->setType(getOperand()->getType()); }

//===----------------------------------------------------------------------===//
// Tanh
@@ -90,9 +88,7 @@ void ONNXSigmoidOp::inferShapes() {
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
-void ONNXEluOp::inferShapes() {
-  getResult()->setType(getOperand()->getType());
-}
+void ONNXEluOp::inferShapes() { getResult()->setType(getOperand()->getType()); }

//===----------------------------------------------------------------------===//
// Relu
@@ -162,9 +158,7 @@ void ONNXAndOp::inferShapes() {
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
-void ONNXOrOp::inferShapes() {
-  getResult()->setType(getOperand(0)->getType());
-}
+void ONNXOrOp::inferShapes() { getResult()->setType(getOperand(0)->getType()); }

//===----------------------------------------------------------------------===//
// Xor
@@ -257,6 +251,36 @@ void ONNXFullGemmOp::inferShapes() {
//   Verify that matrix sizes are valid for multiplication and addition.
//   Take into account the dimensionality of the matrix.

//===----------------------------------------------------------------------===//

// Reshape

void ONNXReshapeOp::inferShapes() {
  // Cannot infer the output shape if the shape operand is not a ranked tensor.
  if (!getOperand(1)->getType().isa<RankedTensorType>())
    emitError("Shape tensor not ranked.");

  auto inputTensorTy = getOperand(0)->getType().cast<RankedTensorType>();
  auto shapeTensorTy = getOperand(1)->getType().cast<RankedTensorType>();

  // Only rank-1 shape tensors are supported.
  if (shapeTensorTy.getShape().size() != 1)
    emitError("Shape tensor must have rank one.");

  int64_t outputRank = shapeTensorTy.getShape()[0];

  // The shape tensor must itself have a static (constant) shape.
  if (outputRank < 0)
    emitError("Shape tensor must have constant shape.");

  // The output rank is known, but every dimension stays dynamic (-1) until
  // the shape values themselves are known.
  SmallVector<int64_t, 2> dims;
  for (int i = 0; i < outputRank; ++i)
    dims.emplace_back(-1);

  getResult()->setType(
      RankedTensorType::get(dims, inputTensorTy.getElementType()));
}
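
// As a hedged illustration of what this inference produces (the operand names
// and rank-4 shape below are hypothetical, mirroring the lit tests): with a
// rank-1 shape operand of static size 4, an unranked result is refined to a
// rank-4 tensor whose dimensions are all dynamic:
//
//   %0 = "onnx.Reshape"(%data, %shape)
//       : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>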

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
@@ -2197,7 +2197,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
}

def ONNXReshapeOp:ONNX_Op<"Reshape",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  let summary = "ONNX Reshape operation";
  let description = [{
    "Reshape the input tensor similar to numpy.reshape."
@@ -86,7 +86,7 @@ static bool checkInsertDealloc(Operation* currentOp) {
    // If there is at least one result to investigate.
    if (currentOp->getNumResults() > 0) {
      auto result = currentOp->getResult(0);
-      for (auto operand : op.getOperands())
+      for (const auto& operand : op.getOperands())
        if (operand == result)
          insertDealloc = false;
    }
@@ -95,6 +95,20 @@ static bool checkInsertDealloc(Operation* currentOp) {
  return insertDealloc;
}

unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
  auto elementType = memRefType.getElementType();

  unsigned sizeInBits;
  if (elementType.isIntOrFloat()) {
    sizeInBits = elementType.getIntOrFloatBitWidth();
  } else {
    auto vectorType = elementType.cast<VectorType>();
    sizeInBits =
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
  }
  return llvm::divideCeil(sizeInBits, 8);
}
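
// For example: f32 -> 4 bytes, i64 -> 8 bytes, i1 -> 1 byte (divideCeil rounds
// the bit width up to whole bytes), and vector<4xf32> -> 16 bytes.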

namespace {

template <typename ElementwiseNaryOp>
@@ -655,6 +669,62 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  }
};

struct ONNXReshapeOpLowering : public ConversionPattern {
  ONNXReshapeOpLowering(MLIRContext* ctx)
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}

  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    Value* alloc;

    // Compute the result size in bytes, starting from the element size.
    Value* tensorSize = rewriter.create<ConstantOp>(loc,
        rewriter.getIntegerAttr(
            rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType)));
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType)) {
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    } else {
      auto memRefShape = memRefType.getShape();
      SmallVector<Value*, 4> allocOperands;
      for (int i = 0; i < memRefShape.size(); ++i) {
        // The shape array can always be used to construct the shape
        // information of the result.
        Value* index = rewriter.create<ConstantOp>(
            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
        Value* loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
        Value* int64LoadedVal = rewriter.create<ZeroExtendIOp>(
            loc, loadedVal, rewriter.getIntegerType(64));
        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
        allocOperands.push_back(rewriter.create<IndexCastOp>(
            loc, loadedVal, rewriter.getIndexType()));
      }
      AllocOp allocateMemref =
          rewriter.create<AllocOp>(loc, memRefType, allocOperands);

      // If a deallocation is required, insert it just before the terminator
      // of the enclosing block.
      auto* parentBlock = allocateMemref.getOperation()->getBlock();
      if (insertDealloc) {
        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
        dealloc.getOperation()->moveBefore(&parentBlock->back());
      }

      alloc = allocateMemref;
    }

    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
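
// For a dynamically shaped result, the pattern above emits roughly the
// following standard + krnl IR (a sketch condensed from the lit test; SSA
// names are illustrative, and the load/zext/mul/index_cast sequence repeats
// for every output dimension, accumulating the byte count):
//
//   %eltSize = constant 4 : i64
//   %c0 = constant 0 : index
//   %dim0_i32 = load %shape[%c0] : memref<4xi32>
//   %dim0_i64 = zexti %dim0_i32 : i32 to i64
//   %size0 = muli %eltSize, %dim0_i64 : i64
//   %dim0 = index_cast %dim0_i32 : i32 to index
//   ... (repeated for the remaining output dimensions) ...
//   %alloc = alloc(%dim0, %dim1, %dim2, %dim3) : memref<?x?x?x?xf32>
//   "krnl.memcpy"(%alloc, %data, %size3) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()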

//===----------------------------------------------------------------------===//
// Conversion from Tensor type to the Standard dialect MemRef type.
//===----------------------------------------------------------------------===//
@@ -754,7 +824,8 @@ void FrontendToKrnlLoweringPass::runOnModule() {
      ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
      ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
-      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext());
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
+      ONNXReshapeOpLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
@@ -23,4 +23,7 @@ std::unique_ptr<Pass> createLowerToKrnlPass();
/// Pass for lowering frontend dialects to Krnl IR dialect.
std::unique_ptr<Pass> createLowerKrnlPass();

/// Pass for lowering Krnl dialect to LLVM dialect.
std::unique_ptr<Pass> createKrnlLowerToLLVMPass();

}  // end namespace mlir
@@ -75,7 +75,7 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
    if (auto terminator_op = f.getBody().back().getTerminator()) {
      auto results = terminator_op->getOperandTypes();
      f.setType(FunctionType::get(f.getType().getInputs(),
-          std::vector<Type>(results.begin(), results.end()), f.getContext()));
+                std::vector<Type>(results.begin(), results.end()), f.getContext()));
    }
  }

@@ -110,7 +110,8 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
        op->getName().getStringRef() != "onnx.Min" &&
        op->getName().getStringRef() != "onnx.MatMul" &&
        op->getName().getStringRef() != "onnx.Gemm" &&
-        op->getName().getStringRef() != "onnx.FullGemm")
+        op->getName().getStringRef() != "onnx.FullGemm" &&
+        op->getName().getStringRef() != "onnx.Reshape")
      return false;
    return llvm::any_of(op->getResultTypes(),
        [](Type result_type) { return !result_type.isa<RankedTensorType>(); });
@@ -1,4 +1,6 @@
-add_library(onnf_transform lower_krnl.cpp)
+add_library(onnf_transform
+            lower_krnl.cpp
+            lower_to_llvm.cpp)

target_include_directories(onnf_transform
                           PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
@@ -142,6 +142,7 @@ void KrnlToAffineLoweringPass::runOnFunction() {
  target.addLegalDialect<AffineOpsDialect, StandardOpsDialect>();
  // We expect IR to be free of Krnl Dialect Ops.
  target.addIllegalDialect<KrnlOpsDialect>();
  target.addLegalOp<KrnlMemcpyOp>();

  OwningRewritePatternList patterns;
  patterns.insert<KrnlIterateOpLowering, KrnlTerminatorLowering,
@@ -0,0 +1,146 @@
//====- LowerToLLVM.cpp - Lowering from KRNL+Affine+Std to LLVM -----------===//
//
// Copyright 2019 The DLC Authors.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Sequence.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/compiler/dialect/krnl/krnl_ops.hpp"
#include "src/compiler/pass/passes.hpp"

using namespace mlir;

namespace {

//===----------------------------------------------------------------------===//
// KRNL to LLVM: patterns which need a direct lowering to LLVM.
//===----------------------------------------------------------------------===//

class KrnlMemcpyOpLowering : public ConversionPattern {
 public:
  explicit KrnlMemcpyOpLowering(MLIRContext* context)
      : ConversionPattern(KrnlMemcpyOp::getOperationName(), 1, context) {}

  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
      ConversionPatternRewriter& rewriter) const override {
    auto* context = op->getContext();
    auto loc = op->getLoc();
    auto* llvmDialect =
        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    assert(llvmDialect && "expected llvm dialect to be registered");

    // Get a symbol reference to the memcpy function, inserting it if necessary.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    auto memcpyRef = getOrInsertMemcpy(rewriter, parentModule, llvmDialect);

    // First operand: the aligned destination pointer, bitcast to i8*.
    Type dstType =
        operands[0]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
    Value* alignedDstMemory = rewriter.create<LLVM::ExtractValueOp>(
        loc, dstType, operands[0], rewriter.getI64ArrayAttr(1));
    Value* alignedInt8PtrDstMemory = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedDstMemory);

    // Second operand: the aligned source pointer, bitcast to i8*.
    Type srcType =
        operands[1]->getType().cast<LLVM::LLVMType>().getStructElementType(1);
    Value* alignedSrcMemory = rewriter.create<LLVM::ExtractValueOp>(
        loc, srcType, operands[1], rewriter.getI64ArrayAttr(1));
    Value* alignedInt8PtrSrcMemory = rewriter.create<LLVM::BitcastOp>(
        loc, LLVM::LLVMType::getInt8PtrTy(llvmDialect), alignedSrcMemory);

    // Size operand, sign-extended to i64.
    Value* int64Size = rewriter.create<LLVM::SExtOp>(
        loc, LLVM::LLVMType::getInt64Ty(llvmDialect), operands[2]);

    // Memcpy call.
    rewriter.create<CallOp>(loc, memcpyRef,
        LLVM::LLVMType::getVoidTy(llvmDialect),
        ArrayRef<Value*>(
            {alignedInt8PtrDstMemory, alignedInt8PtrSrcMemory, int64Size}));

    rewriter.eraseOp(op);
    return matchSuccess();
  }

 private:
  /// Return a symbol reference to the memcpy function, inserting it into the
  /// module if necessary.
  static FlatSymbolRefAttr getOrInsertMemcpy(PatternRewriter& rewriter,
      ModuleOp module, LLVM::LLVMDialect* llvmDialect) {
    auto* context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>("llvm.memcpy.p0i8.p0i8.i64"))
      return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
    // Create a function declaration for memcpy; the declared signature is:
    //   * `void (i8*, i8*, i64)`
    auto llvmVoidTy = LLVM::LLVMType::getVoidTy(llvmDialect);
    auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
    auto llvmI64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmVoidTy,
        ArrayRef<mlir::LLVM::LLVMType>({llvmI8PtrTy, llvmI8PtrTy, llvmI64Ty}),
        false);

    // Insert the memcpy function declaration into the body of the parent module.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(
        module.getLoc(), "llvm.memcpy.p0i8.p0i8.i64", llvmFnType);
    return SymbolRefAttr::get("llvm.memcpy.p0i8.p0i8.i64", context);
  }
};
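
// For illustration, the IR produced by this pattern looks roughly as follows
// (a sketch assembled from the FileCheck expectations in the lit test; SSA
// names are illustrative):
//
//   llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
//   ...
//   %dst = llvm.bitcast %dstAligned : !llvm<"float*"> to !llvm<"i8*">
//   %src = llvm.bitcast %srcAligned : !llvm<"float*"> to !llvm<"i8*">
//   %size = llvm.sext %i64Size : !llvm.i64 to !llvm.i64
//   llvm.call @llvm.memcpy.p0i8.p0i8.i64(%dst, %src, %size)
//       : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void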
}  // end namespace

//===----------------------------------------------------------------------===//
// KRNL + Standard + Affine dialects lowering to LLVM.
//===----------------------------------------------------------------------===//

namespace {
struct KrnlToLLVMLoweringPass : public ModulePass<KrnlToLLVMLoweringPass> {
  void runOnModule() final;
};
}  // end anonymous namespace

void KrnlToLLVMLoweringPass::runOnModule() {
  // Define the target for this lowering, i.e. the LLVM dialect.
  ConversionTarget target(getContext());
  target.addLegalDialect<LLVM::LLVMDialect>();
  target.addLegalOp<ModuleOp, ModuleTerminatorOp>();

  // Lower the MemRef types to a representation in LLVM.
  LLVMTypeConverter typeConverter(&getContext());

  // We have a combination of `krnl`, `affine`, and `std` operations. We
  // lower in stages until all the code is in the LLVM dialect.
  OwningRewritePatternList patterns;
  populateAffineToStdConversionPatterns(patterns, &getContext());
  populateLoopToStdConversionPatterns(patterns, &getContext());
  populateStdToLLVMConversionPatterns(typeConverter, patterns);

  // Lower the remaining `krnl` dialect operations, i.e. the memcpy emitted
  // for the Reshape operation.
  patterns.insert<KrnlMemcpyOpLowering>(&getContext());

  // We want to completely lower to LLVM, so we use a `FullConversion`. This
  // ensures that only legal operations will remain after the conversion.
  auto module = getModule();
  if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
    signalPassFailure();
}

/// Create the pass for lowering the `Krnl`, `Affine` and `Std` dialects to LLVM.
std::unique_ptr<mlir::Pass> mlir::createKrnlLowerToLLVMPass() {
  return std::make_unique<KrnlToLLVMLoweringPass>();
}

static PassRegistration<KrnlToLLVMLoweringPass> pass(
    "lower-all-llvm", "Lower the Krnl, Affine and Std dialects to LLVM.");
@@ -131,7 +131,7 @@ int main(int ac, char* av[]) {
  pm.addPass(mlir::createLowerKrnlPass());
  pm.addPass(mlir::createLowerAffinePass());
  pm.addPass(mlir::createLowerToCFGPass());
-  pm.addPass(mlir::createLowerToLLVMPass());
+  pm.addPass(mlir::createKrnlLowerToLLVMPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.run(*module);
@@ -0,0 +1,16 @@
// RUN: dlc-opt --shape-inference --lower-frontend --lower-krnl --lower-all-llvm %s -split-input-file | FileCheck %s

func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64)
  // CHECK: [[RES:%.+]] = llvm.insertvalue {{.*}}[4, 3] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
  // CHECK: [[EXT_VAL_0:%.+]] = llvm.extractvalue [[RES]][1] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
  // CHECK: [[DST:%.+]] = llvm.bitcast [[EXT_VAL_0]] : !llvm<"float*"> to !llvm<"i8*">
  // CHECK: [[EXT_VAL_1:%.+]] = llvm.extractvalue %0[1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
  // CHECK: [[SRC:%.+]] = llvm.bitcast [[EXT_VAL_1]] : !llvm<"float*"> to !llvm<"i8*">
  // CHECK: [[SIZE:%.+]] = llvm.sext %{{.*}} : !llvm.i64 to !llvm.i64
  // CHECK: llvm.call @llvm.memcpy.p0i8.p0i8.i64([[DST]], [[SRC]], [[SIZE]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i64) -> !llvm.void
  // CHECK: llvm.return [[RES]] : !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
}
@@ -279,6 +279,37 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
  // CHECK: return [[RES]] : memref<?x10xf32>
}

func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()

  // CHECK-LABEL: test_reshape
  // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
  // CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
  // CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
  // CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
  // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
  // CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
  // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
  // CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
  // CHECK: %[[INDEX_3:.+]] = constant 3 : index
  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
  // CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
  // CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
  // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
  // CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
}

func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
  %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
  "std.return"(%0) : (tensor<*xf32>) -> ()