onnx-mlir/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp

123 lines
5.0 KiB
C++

//===-- ConvertONNXToKrnl.cpp - ONNX dialects to Krnl lowering ------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file implements the lowering of frontend operations to a combination of
// Krnl IR and standard operations.
//
//===----------------------------------------------------------------------===//
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
using namespace mlir;
//===----------------------------------------------------------------------===//
// EntryPoint Op lowering to Krnl Entry Point.
//===----------------------------------------------------------------------===//
/// Rewrite pattern that replaces an `ONNXEntryPointOp` with a
/// `KrnlEntryPointOp` built from the entry-point function symbol, the number
/// of inputs, and the number of outputs read off the original operation.
class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
public:
  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;

  /// Unconditionally rewrites `op`: looks up its three attributes by name,
  /// replaces `op` with a new `KrnlEntryPointOp` constructed from them, and
  /// returns success.
  LogicalResult matchAndRewrite(
      ONNXEntryPointOp op, PatternRewriter &rewriter) const override {
    // Symbol reference naming the function that serves as the entry point.
    SymbolRefAttr entryPointFunc = op.getAttrOfType<SymbolRefAttr>(
        ONNXEntryPointOp::getEntryPointFuncAttrName());
    // Input and output counts recorded on the ONNX entry point.
    IntegerAttr numInputs = op.getAttrOfType<IntegerAttr>(
        ONNXEntryPointOp::getNumInputsAttrName());
    IntegerAttr numOutputs = op.getAttrOfType<IntegerAttr>(
        ONNXEntryPointOp::getNumOutputsAttrName());

    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
        op, entryPointFunc, numInputs, numOutputs);
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Frontend to Krnl Dialect lowering pass
//===----------------------------------------------------------------------===//
namespace {
/// Pass operating on a `ModuleOp` that performs a partial lowering of the
/// ONNX operations to Krnl loops. The lowering itself is implemented in the
/// out-of-line definition of `runOnOperation`.
struct FrontendToKrnlLoweringPass
    : public PassWrapper<FrontendToKrnlLoweringPass,
          OperationPass<ModuleOp>> {
  /// Entry point invoked by the pass manager on the module being lowered.
  void runOnOperation() final;
};
} // namespace
void FrontendToKrnlLoweringPass::runOnOperation() {
ModuleOp module = getOperation();
// The first thing to define is the conversion target. This will define the
// final target for this lowering.
ConversionTarget target(getContext());
// We define the specific operations, or dialects, that are legal targets for
// this lowering.
target.addLegalDialect<KrnlOpsDialect, AffineDialect, StandardOpsDialect>();
// TODO: enable this once more ops are supported.
// We also define the ONNX dialect as Illegal so that the conversion will fail
// if any of these operations are *not* converted.
// target.addIllegalDialect<mlir::ONNXOpsDialect>();
// TODO: add any other ops which are considered legal.
// Some operations can be marked as being still legal.
// Example: target.addLegalOp<mlir::OpName>();
// Now that the conversion target has been defined, we just need to provide
// the set of patterns that will lower the frontend operations.
OwningRewritePatternList patterns;
// Convert TensorType to MemRef
TensorTypeConverter tensor_to_memref_converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
// FuncOp is legal only if types have been converted to Std types.
return tensor_to_memref_converter.isSignatureLegal(op.getType());
});
// Type conversion for function signatures.
// Call MLIR FuncOp signature conversion when result type is
// a ranked tensor.
populateFuncOpTypeConversionPattern(
patterns, &getContext(), tensor_to_memref_converter);
// Frontend operation lowering.
// Math
populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
populateLoweringONNXGemmOpPattern(patterns, &getContext());
populateLoweringONNXReductionOpPattern(patterns, &getContext());
populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
populateLoweringONNXMatMulOpPattern(patterns, &getContext());
// Tensor
populateLoweringONNXReshapeOpPattern(patterns, &getContext());
populateLoweringONNXPadConstantValuePadOpPattern(patterns, &getContext());
populateLoweringONNXPadOpPattern(patterns, &getContext());
populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
populateLoweringONNXTransposeOpPattern(patterns, &getContext());
populateLoweringONNXIdentityOpPattern(patterns, &getContext());
populateLoweringONNXConstantOpPattern(patterns, &getContext());
populateLoweringONNXConcatOpPattern(patterns, &getContext());
populateLoweringONNXSqueezeOpPattern(patterns, &getContext());
populateLoweringONNXSplitOpPattern(patterns, &getContext());
// Neural network
populateLoweringONNXConvOpPattern(patterns, &getContext());
populateLoweringONNXNormalizationOpPattern(patterns, &getContext());
populateLoweringONNXPoolingOpPattern(patterns, &getContext());
// Recurrent neural network
populateLoweringONNXLSTMOpPattern(patterns, &getContext());
// Entry point
patterns.insert<ONNXEntryPointLowering>(&getContext());
// With the target and rewrite patterns defined, we can now attempt the
// conversion. The conversion will signal failure if any of our `illegal`
// operations were not converted successfully.
if (failed(applyPartialConversion(module, target, patterns)))
signalPassFailure();
}
/// Factory for the ONNX-to-Krnl lowering pass: allocates a new
/// `FrontendToKrnlLoweringPass` and hands ownership to the caller through a
/// `std::unique_ptr<Pass>`.
std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
  std::unique_ptr<Pass> loweringPass =
      std::make_unique<FrontendToKrnlLoweringPass>();
  return loweringPass;
}