diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 488f77a..6cb106c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,7 +12,10 @@ add_library(compiler pass/onnx_combine.cpp pass/onnx_rewrite.cpp pass/onnx_decompose.cpp - pass/passes.hpp) + pass/passes.hpp + dialect/onnx/const_operands_interface.hpp + dialect/onnx/attribute_promotion.cpp + ) # Include root src directory. target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}) @@ -47,12 +50,18 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters) add_public_tablegen_target(gen_onnx_rewrite) add_dependencies(compiler gen_onnx_rewrite) +add_subdirectory(dialect/onnx) + set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) add_public_tablegen_target(gen_onnx) + +add_dependencies(gen_onnx gen_shape_inference) +add_dependencies(gen_onnx gen_const_operands_interface) add_dependencies(compiler gen_onnx) + add_onnf_dialect_doc(onnx dialect/onnx/onnx.td) set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) @@ -66,7 +75,7 @@ target_include_directories(onnf_onnx_decompose PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} ${ONNF_SRC_ROOT}) target_link_libraries(onnf_onnx_decompose ${MLIRLibs}) -add_dependencies(onnf_onnx_decompose gen_krnl_ops) +add_dependencies(onnf_onnx_decompose gen_onnx) add_library(onnf_shape_inference pass/shape_inference_pass.cpp) target_include_directories(onnf_shape_inference diff --git a/src/dialect/onnx/CMakeLists.txt b/src/dialect/onnx/CMakeLists.txt new file mode 100644 index 0000000..0eda4b6 --- /dev/null +++ b/src/dialect/onnx/CMakeLists.txt @@ -0,0 +1,11 @@ +#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp) +#target_include_directories(onnf_fold_const_operand_to_attribute +# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}) +#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs}) + +set(LLVM_TARGET_DEFINITIONS const_operands_interface.td) +onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls) +onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs) +message("TBLGENOUT" ${TABLEGEN_OUTPUT}) +add_public_tablegen_target(gen_const_operands_interface) +add_dependencies(compiler gen_const_operands_interface) \ No newline at end of file diff --git a/src/dialect/onnx/attribute_promotion.cpp b/src/dialect/onnx/attribute_promotion.cpp new file mode 100644 index 0000000..b076dfe --- /dev/null +++ b/src/dialect/onnx/attribute_promotion.cpp @@ -0,0 +1,79 @@ +//===----- shape_inference_pass.cpp - Shape Inference ---------------------===// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a Function level pass performing propagation of array +// shapes through function specialization. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "llvm/Support/raw_ostream.h" + +#include "src/dialect/onnx/const_operands_interface.hpp" +#include "src/dialect/onnx/onnx_ops.hpp" +#include "src/pass/passes.hpp" + +using namespace mlir; + +#include "src/dialect/onnx/const_operands_interface.cpp.inc" + +namespace { +/*! + * FunctionPass that performs shape inference by iterating over a list of + * candidate operations and propagating the shape information until the list + * of operations is empty [credit MLIR authors]. + */ +class AttributePromotionPass + : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + llvm::Optional none; + f.walk([&](mlir::Operation *op) { + if (op->getNumResults() && op->getOpResult(0).getType().isa()) + none = op->getOpResult(0); + }); + + if (!none) { + OpBuilder builder(f.getContext()); + builder.setInsertionPointToStart(&f.front()); + none = + builder.create(f.getLoc(), builder.getUnitAttr()); + } + + // Iterate on the operations that need shape inference i.e the operations + // that return a dynamic shape. + f.walk([&](mlir::Operation *op) { + if (IdentifyConstOperandsOpInterface opWithConstOperand = + dyn_cast(op)) { + auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands(); + for (const auto& operandNameToIdx : promotableOperands) { + auto name = operandNameToIdx.first; + auto idx = operandNameToIdx.second; + + auto operandToPromote = op->getOperand(idx); + if (auto constantOp = + dyn_cast_or_null(operandToPromote.getDefiningOp())) { + op->setAttr(name, constantOp.value()); + op->setOperand(idx, *none); + } + } + } + }); + } +}; +} // end anonymous namespace + +/*! + * Create a Shape Inference pass. + */ +std::unique_ptr mlir::createAttributePromotionPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "attribute-promotion", "Shape inference for frontend dialects."); diff --git a/src/dialect/onnx/const_operands_interface.hpp b/src/dialect/onnx/const_operands_interface.hpp new file mode 100644 index 0000000..66c01f0 --- /dev/null +++ b/src/dialect/onnx/const_operands_interface.hpp @@ -0,0 +1,24 @@ +//===- shape_inference_interface.hpp - Definition for ShapeInference --------=// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the declarations of the shape inference interfaces defined +// in ShapeInferenceInterface.td. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +/// Include the auto-generated declarations. +#include "src/dialect/onnx/const_operands_interface.hpp.inc" + +} // end namespace mlir \ No newline at end of file diff --git a/src/dialect/onnx/const_operands_interface.td b/src/dialect/onnx/const_operands_interface.td new file mode 100644 index 0000000..080ee7f --- /dev/null +++ b/src/dialect/onnx/const_operands_interface.td @@ -0,0 +1,32 @@ +//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==// +// +// Copyright 2019 The IBM Research Authors. +// +// ============================================================================= +// +// Defines the operations of the Shape Inference Op Interface. +// +//===----------------------------------------------------------------------===// + +#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#else +#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "std::map", "identifyPromotableConstOperands"> + ]; +} + +#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 68dbbf6..09acf6c 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td" include "pass/shape_inference_interface.td" #endif // SHAPE_INFERENCE_INTERFACE +#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#else +include "dialect/onnx/const_operands_interface.td" +#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE + def ONNX_Dialect : Dialect { let name = "onnx"; let cppNamespace = ""; diff --git a/src/dialect/onnx/onnx_ops.hpp b/src/dialect/onnx/onnx_ops.hpp index 1ba9669..ab4ca65 100644 --- a/src/dialect/onnx/onnx_ops.hpp +++ b/src/dialect/onnx/onnx_ops.hpp @@ -10,6 +10,9 @@ #pragma once +#include +#include + #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -17,6 +20,7 @@ #include "mlir/IR/StandardTypes.h" #include "src/pass/shape_inference_interface.hpp" +#include "src/dialect/onnx/const_operands_interface.hpp" namespace mlir { diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index c64e1e6..9091556 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -26,7 +26,7 @@ def ONNXAbsOp:ONNX_Op<"Abs", outputTypes.emplace_back(UnrankedTensorType::get(elementType)); build(builder, state, outputTypes, operands, attributes); }]> - ]; + ]; } def ONNXAcosOp:ONNX_Op<"Acos", @@ -2501,7 +2501,9 @@ def ONNXReluOp:ONNX_Op<"Relu", } def ONNXReshapeOp:ONNX_Op<"Reshape", - [NoSideEffect, DeclareOpInterfaceMethods]> { + [NoSideEffect, + DeclareOpInterfaceMethods, + OpInterface<"IdentifyConstOperandsOpInterface">]> { let summary = "ONNX Reshape operation"; let description = [{ "Reshape the input tensor similar to numpy.reshape." @@ -2512,8 +2514,11 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", "from the input tensor)." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); + AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); + let extraClassDeclaration = [{ + std::map identifyPromotableConstOperands() { return {{"shape", 1}}; } + }]; } def ONNXResizeOp:ONNX_Op<"Resize", diff --git a/src/main.cpp b/src/main.cpp index 9500d31..b5f5cf0 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -126,6 +126,7 @@ int main(int argc, char *argv[]) { pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createAttributePromotionPass()); if (emissionTarget >= EmitMLIR) { pm.addPass(mlir::createLowerToKrnlPass()); diff --git a/src/pass/passes.hpp b/src/pass/passes.hpp index b7bdc96..cbe9961 100644 --- a/src/pass/passes.hpp +++ b/src/pass/passes.hpp @@ -20,6 +20,8 @@ std::unique_ptr createDecomposeONNXToONNXPass(); std::unique_ptr createShapeInferencePass(); +std::unique_ptr createAttributePromotionPass(); + /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 8df3d66..9c3ab85 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -12,9 +12,9 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/StandardTypes.h" #include "shape_inference_interface.hpp" -#include "src/dialect/onnx/onnx_ops.hpp" #include "passes.hpp"