//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file implements the lowering of frontend operations to a combination of // Krnl IR and standard operations. // //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Sequence.h" #include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "src/compiler/dialect/krnl/krnl_ops.hpp" #include "src/compiler/dialect/onnx/onnx_ops.hpp" #include "src/compiler/pass/passes.hpp" using namespace mlir; //===----------------------------------------------------------------------===// // FrontendToAffine RewritePatterns //===----------------------------------------------------------------------===// /// Check is all dimensions are known at compile time. static bool hasAllConstantDimensions(MemRefType type) { auto memRefShape = type.getShape(); for (int i = 0; i < memRefShape.size(); ++i) if (memRefShape[i] < 0) return false; return true; } /// Convert the given TensorType into the corresponding MemRefType. static MemRefType convertTensorToMemRef(TensorType type) { assert(type.hasRank() && "expected only ranked shapes"); return MemRefType::get(type.getShape(), type.getElementType()); } /// Insert an allocation and deallocation for the given MemRefType. static Value* insertAllocAndDealloc( MemRefType type, Location loc, PatternRewriter& rewriter, Value *oldMemRef = nullptr) { // Put together alloc operands for any dynamic dimensions of the memref. AllocOp alloc; if (oldMemRef) { SmallVector allocOperands; auto memRefShape = type.getShape(); for (int i = 0; i < memRefShape.size(); ++i) if (memRefShape[i] < 0) allocOperands.push_back(rewriter.create(loc, oldMemRef, i)); alloc = rewriter.create(loc, type, allocOperands); } else { alloc = rewriter.create(loc, type); } // Make sure to allocate at the beginning of the block if // all dimensions are known. auto* parentBlock = alloc.getOperation()->getBlock(); if (hasAllConstantDimensions(type)) alloc.getOperation()->moveBefore(&parentBlock->front()); return alloc; } namespace { //===----------------------------------------------------------------------===// // AddOp lowering to Krnl dialect. //===----------------------------------------------------------------------===// struct ONNXAddOpLowering : public ConversionPattern { ONNXAddOpLowering(MLIRContext* ctx) : ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {} PatternMatchResult matchAndRewrite(Operation* op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { // TODO: Check that the types are valid. // Add is an operation that must have all operands and the result of // the same type. This should have been verified by the verifier. auto tensorType = (*op->result_type_begin()).cast(); auto loc = op->getLoc(); // Insert an allocation and deallocation for the result of this operation. auto memRefType = convertTensorToMemRef(tensorType); // If the output has a dynamic dimension, pass the operands required for // each dynamic dimension to the AllocOp. The first operand of the Add // operation is used. The operands of the Add need to match in terms of // dimensions with the result at this pre-optimization phase. // TODO: verify that dimensions match. // TODO: can the dimension of the result differ after optimizations? Value *alloc; if (hasAllConstantDimensions(memRefType)) alloc = insertAllocAndDealloc(memRefType, loc, rewriter); else alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]); // Number of loops auto memRefShape = memRefType.getShape(); int64_t rank = memRefShape.size(); // Define loops. auto loopsOp = rewriter.create(loc, rank); std::vector originalLoops; originalLoops.reserve(rank); for (auto result : loopsOp.getResults()) { originalLoops.push_back(result); } // Define loop optimization. auto optimizedLoopsOp = rewriter.create(loc, rank); std::vector optimizedLoops; optimizedLoops.reserve(rank); for (auto result : optimizedLoopsOp.getResults()) { optimizedLoops.push_back(result); } Block& optimizationBlock = optimizedLoopsOp.region().front(); // Iterate over the loop nest. // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape // to KrnlIterateOp instead. SmallVector operandBounds; SmallVector constBounds; SmallVector boundTypes; for (int i = 0; i < rank; ++i) { if (memRefShape[i] < 0) { // This is a dynamic value, hence use operands. // Lower bound constBounds.push_back(0); boundTypes.push_back(0); // Upper bound operandBounds.push_back( rewriter.create(loc, operands[0], i).getResult()); boundTypes.push_back(1); } else { // Lower bound constBounds.push_back(0); boundTypes.push_back(0); // Upper bound constBounds.push_back(memRefShape[i]); boundTypes.push_back(0); } } auto iterateOp = rewriter.create(loc, originalLoops, optimizedLoops, operandBounds, constBounds, boundTypes); Block& iterationBlock = iterateOp.bodyRegion().front(); // Now perform the insertions into the body of the // just generated instructions: // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body. rewriter.setInsertionPointToEnd(&optimizationBlock); // Return from KrnlOptimizeLoopsOp body. // When no optimizations are present we just return the loops // unchaged. rewriter.create(loc, originalLoops); rewriter.setInsertionPoint(optimizedLoopsOp); // 2. Insert instructions inside the KernelIterateOp body. rewriter.setInsertionPointToStart(&iterationBlock); // Handle AddOp: SmallVector loopIVs; for (auto arg : iterationBlock.getArguments()) loopIVs.push_back(arg); auto loadedFirstVal = rewriter.create(loc, operands[0], loopIVs); auto loadedSecondVal = rewriter.create(loc, operands[1], loopIVs); // TODO: Choose type of the Add for now use the Float Add. auto addOpResult = rewriter.create( loc, loadedFirstVal, loadedSecondVal); // Store result in the resulting array. rewriter.create(loc, addOpResult, alloc, loopIVs); rewriter.replaceOp(op, alloc); return matchSuccess(); } }; //===----------------------------------------------------------------------===// // Conversion from Tensor type to the Standard dialect MemRef type. //===----------------------------------------------------------------------===// struct TensorTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; LogicalResult convertType(Type t, SmallVectorImpl& results) override { if (auto tensor_type = t.dyn_cast()) { results.push_back(convertTensorToMemRef(tensor_type)); return success(); } results.push_back(t); return success(); } /// Return true if the inputs and outputs of the given function type are /// legal. [Taken from MLIR and adapted to only check the legality of the /// inputs. Once unranked results can be handled gracefully this /// override needs to be removed in favour of the original MLIR one.] bool isSignatureLegal(FunctionType funcType) { return llvm::all_of(funcType.getInputs(), [this](Type type) { return isLegal(type); }); } }; } // end anonymous namespace. //===----------------------------------------------------------------------===// // Frontend to Krnl Dialect lowering pass //===----------------------------------------------------------------------===// /// This is a partial lowering to Krnl loops of the ONNX operations. namespace { struct FrontendToKrnlLoweringPass : public ModulePass { void runOnModule() final; }; } // end anonymous namespace. void FrontendToKrnlLoweringPass::runOnModule() { auto module = getModule(); // The first thing to define is the conversion target. This will define the // final target for this lowering. ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. target .addLegalDialect(); // TODO: enable this once more ops are supported. // We also define the ONNX dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. // target.addIllegalDialect(); // TODO: add any other ops which are considered legal. // Some operations can be marked as being still legal. // Example: target.addLegalOp(); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the frontend operations. OwningRewritePatternList patterns; // Convert TensorType to MemRef TensorTypeConverter tensor_to_memref_converter; target.addDynamicallyLegalOp([&](FuncOp op) { // FuncOp is legal only if types have been converted to Std types. return tensor_to_memref_converter.isSignatureLegal(op.getType()); }); // Type conversion for function signatures. // Call MLIR FuncOp signature conversion when result type is // a ranked tensor. populateFuncOpTypeConversionPattern( patterns, &getContext(), tensor_to_memref_converter); // Frontent operation lowering. patterns.insert(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (failed(applyPartialConversion( module, target, patterns))) signalPassFailure(); } std::unique_ptr mlir::createLowerToKrnlPass() { return std::make_unique(); }