//====------ ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering //--------===// // // Copyright 2019 The IBM Research Authors. // // ============================================================================= // // This file implements the lowering of frontend operations to a combination of // Krnl IR and standard operations. // //===----------------------------------------------------------------------===// #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" using namespace mlir; //===----------------------------------------------------------------------===// // EntryPoint Op lowering to Krnl Entry Point. //===----------------------------------------------------------------------===// class ONNXEntryPointLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite( ONNXEntryPointOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getAttrOfType( ONNXEntryPointOp::getEntryPointFuncAttrName()), op.getAttrOfType(ONNXEntryPointOp::getNumInputsAttrName()), op.getAttrOfType( ONNXEntryPointOp::getNumOutputsAttrName())); return success(); } }; //===----------------------------------------------------------------------===// // Frontend to Krnl Dialect lowering pass //===----------------------------------------------------------------------===// /// This is a partial lowering to Krnl loops of the ONNX operations. namespace { struct FrontendToKrnlLoweringPass : public PassWrapper> { void runOnOperation() final; }; } // end anonymous namespace. void FrontendToKrnlLoweringPass::runOnOperation() { ModuleOp module = getOperation(); // The first thing to define is the conversion target. This will define the // final target for this lowering. ConversionTarget target(getContext()); // We define the specific operations, or dialects, that are legal targets for // this lowering. target.addLegalDialect(); // TODO: enable this once more ops are supported. // We also define the ONNX dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. // target.addIllegalDialect(); // TODO: add any other ops which are considered legal. // Some operations can be marked as being still legal. // Example: target.addLegalOp(); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the frontend operations. OwningRewritePatternList patterns; // Convert TensorType to MemRef TensorTypeConverter tensor_to_memref_converter; target.addDynamicallyLegalOp([&](FuncOp op) { // FuncOp is legal only if types have been converted to Std types. return tensor_to_memref_converter.isSignatureLegal(op.getType()); }); // Type conversion for function signatures. // Call MLIR FuncOp signature conversion when result type is // a ranked tensor. populateFuncOpTypeConversionPattern( patterns, &getContext(), tensor_to_memref_converter); // Frontend operation lowering. // Math populateLoweringONNXElementwiseOpPattern(patterns, &getContext()); populateLoweringONNXGemmOpPattern(patterns, &getContext()); populateLoweringONNXReductionOpPattern(patterns, &getContext()); populateLoweringONNXSoftmaxOpPattern(patterns, &getContext()); populateLoweringONNXMatMulOpPattern(patterns, &getContext()); // Tensor populateLoweringONNXReshapeOpPattern(patterns, &getContext()); populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext()); populateLoweringONNXPadOpPattern(patterns, &getContext()); populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext()); populateLoweringONNXTransposeOpPattern(patterns, &getContext()); populateLoweringONNXIdentityOpPattern(patterns, &getContext()); populateLoweringONNXConstantOpPattern(patterns, &getContext()); populateLoweringONNXConcatOpPattern(patterns, &getContext()); populateLoweringONNXSqueezeOpPattern(patterns, &getContext()); populateLoweringONNXSplitOpPattern(patterns, &getContext()); // Neural network populateLoweringONNXConvOpPattern(patterns, &getContext()); populateLoweringONNXNormalizationOpPattern(patterns, &getContext()); populateLoweringONNXPoolingOpPattern(patterns, &getContext()); // Recurrent neural network populateLoweringONNXLSTMOpPattern(patterns, &getContext()); // Entry point patterns.insert(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` // operations were not converted successfully. if (failed(applyPartialConversion(module, target, patterns))) signalPassFailure(); } std::unique_ptr mlir::createLowerToKrnlPass() { return std::make_unique(); }