diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6cb106c..488f77a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -12,10 +12,7 @@ add_library(compiler pass/onnx_combine.cpp pass/onnx_rewrite.cpp pass/onnx_decompose.cpp - pass/passes.hpp - dialect/onnx/const_operands_interface.hpp - dialect/onnx/attribute_promotion.cpp - ) + pass/passes.hpp) # Include root src directory. target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT}) @@ -50,18 +47,12 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters) add_public_tablegen_target(gen_onnx_rewrite) add_dependencies(compiler gen_onnx_rewrite) -add_subdirectory(dialect/onnx) - set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) add_public_tablegen_target(gen_onnx) - -add_dependencies(gen_onnx gen_shape_inference) -add_dependencies(gen_onnx gen_const_operands_interface) add_dependencies(compiler gen_onnx) - add_onnf_dialect_doc(onnx dialect/onnx/onnx.td) set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) @@ -75,7 +66,7 @@ target_include_directories(onnf_onnx_decompose PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT} ${ONNF_SRC_ROOT}) target_link_libraries(onnf_onnx_decompose ${MLIRLibs}) -add_dependencies(onnf_onnx_decompose gen_onnx) +add_dependencies(onnf_onnx_decompose gen_krnl_ops) add_library(onnf_shape_inference pass/shape_inference_pass.cpp) target_include_directories(onnf_shape_inference diff --git a/src/dialect/onnx/CMakeLists.txt b/src/dialect/onnx/CMakeLists.txt deleted file mode 100644 index 0eda4b6..0000000 --- a/src/dialect/onnx/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp) -#target_include_directories(onnf_fold_const_operand_to_attribute -# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}) -#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs}) - -set(LLVM_TARGET_DEFINITIONS const_operands_interface.td) -onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls) -onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs) -message("TBLGENOUT" ${TABLEGEN_OUTPUT}) -add_public_tablegen_target(gen_const_operands_interface) -add_dependencies(compiler gen_const_operands_interface) \ No newline at end of file diff --git a/src/dialect/onnx/attribute_promotion.cpp b/src/dialect/onnx/attribute_promotion.cpp deleted file mode 100644 index b076dfe..0000000 --- a/src/dialect/onnx/attribute_promotion.cpp +++ /dev/null @@ -1,79 +0,0 @@ -//===----- shape_inference_pass.cpp - Shape Inference ---------------------===// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -// This file implements a Function level pass performing propagation of array -// shapes through function specialization. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Pass/Pass.h" -#include "llvm/Support/raw_ostream.h" - -#include "src/dialect/onnx/const_operands_interface.hpp" -#include "src/dialect/onnx/onnx_ops.hpp" -#include "src/pass/passes.hpp" - -using namespace mlir; - -#include "src/dialect/onnx/const_operands_interface.cpp.inc" - -namespace { -/*! - * FunctionPass that performs shape inference by iterating over a list of - * candidate operations and propagating the shape information until the list - * of operations is empty [credit MLIR authors]. - */ -class AttributePromotionPass - : public mlir::FunctionPass { -public: - void runOnFunction() override { - auto f = getFunction(); - - llvm::Optional none; - f.walk([&](mlir::Operation *op) { - if (op->getNumResults() && op->getOpResult(0).getType().isa()) - none = op->getOpResult(0); - }); - - if (!none) { - OpBuilder builder(f.getContext()); - builder.setInsertionPointToStart(&f.front()); - none = - builder.create(f.getLoc(), builder.getUnitAttr()); - } - - // Iterate on the operations that need shape inference i.e the operations - // that return a dynamic shape. - f.walk([&](mlir::Operation *op) { - if (IdentifyConstOperandsOpInterface opWithConstOperand = - dyn_cast(op)) { - auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands(); - for (const auto& operandNameToIdx : promotableOperands) { - auto name = operandNameToIdx.first; - auto idx = operandNameToIdx.second; - - auto operandToPromote = op->getOperand(idx); - if (auto constantOp = - dyn_cast_or_null(operandToPromote.getDefiningOp())) { - op->setAttr(name, constantOp.value()); - op->setOperand(idx, *none); - } - } - } - }); - } -}; -} // end anonymous namespace - -/*! - * Create a Shape Inference pass. - */ -std::unique_ptr mlir::createAttributePromotionPass() { - return std::make_unique(); -} - -static PassRegistration pass( - "attribute-promotion", "Shape inference for frontend dialects."); diff --git a/src/dialect/onnx/const_operands_interface.hpp b/src/dialect/onnx/const_operands_interface.hpp deleted file mode 100644 index 66c01f0..0000000 --- a/src/dialect/onnx/const_operands_interface.hpp +++ /dev/null @@ -1,24 +0,0 @@ -//===- shape_inference_interface.hpp - Definition for ShapeInference --------=// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -// This file contains the declarations of the shape inference interfaces defined -// in ShapeInferenceInterface.td. -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include -#include - -#include "mlir/IR/OpDefinition.h" - -namespace mlir { - -/// Include the auto-generated declarations. -#include "src/dialect/onnx/const_operands_interface.hpp.inc" - -} // end namespace mlir \ No newline at end of file diff --git a/src/dialect/onnx/const_operands_interface.td b/src/dialect/onnx/const_operands_interface.td deleted file mode 100644 index 080ee7f..0000000 --- a/src/dialect/onnx/const_operands_interface.td +++ /dev/null @@ -1,32 +0,0 @@ -//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==// -// -// Copyright 2019 The IBM Research Authors. -// -// ============================================================================= -// -// Defines the operations of the Shape Inference Op Interface. -// -//===----------------------------------------------------------------------===// - -#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE -#else -#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE - -#ifdef OP_BASE -#else -include "mlir/IR/OpBase.td" -#endif // OP_BASE - -def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> { - let description = [{ - Interface to access a registered method to infer the return types for an - operation that can be used during type inference. - }]; - - let methods = [ - InterfaceMethod<"Infer and set the output shape for the current operation.", - "std::map", "identifyPromotableConstOperands"> - ]; -} - -#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 09acf6c..68dbbf6 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -22,11 +22,6 @@ include "mlir/IR/OpBase.td" include "pass/shape_inference_interface.td" #endif // SHAPE_INFERENCE_INTERFACE -#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE -#else -include "dialect/onnx/const_operands_interface.td" -#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE - def ONNX_Dialect : Dialect { let name = "onnx"; let cppNamespace = ""; diff --git a/src/dialect/onnx/onnx_ops.hpp b/src/dialect/onnx/onnx_ops.hpp index ab4ca65..1ba9669 100644 --- a/src/dialect/onnx/onnx_ops.hpp +++ b/src/dialect/onnx/onnx_ops.hpp @@ -10,9 +10,6 @@ #pragma once -#include -#include - #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -20,7 +17,6 @@ #include "mlir/IR/StandardTypes.h" #include "src/pass/shape_inference_interface.hpp" -#include "src/dialect/onnx/const_operands_interface.hpp" namespace mlir { diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 9091556..c64e1e6 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -26,7 +26,7 @@ def ONNXAbsOp:ONNX_Op<"Abs", outputTypes.emplace_back(UnrankedTensorType::get(elementType)); build(builder, state, outputTypes, operands, attributes); }]> - ]; + ]; } def ONNXAcosOp:ONNX_Op<"Acos", @@ -2501,9 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu", } def ONNXReshapeOp:ONNX_Op<"Reshape", - [NoSideEffect, - DeclareOpInterfaceMethods, - OpInterface<"IdentifyConstOperandsOpInterface">]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Reshape operation"; let description = [{ "Reshape the input tensor similar to numpy.reshape." @@ -2514,11 +2512,8 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", "from the input tensor)." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); + AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); - let extraClassDeclaration = [{ - std::map identifyPromotableConstOperands() { return {{"shape", 1}}; } - }]; } def ONNXResizeOp:ONNX_Op<"Resize", diff --git a/src/main.cpp b/src/main.cpp index b5f5cf0..9500d31 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -126,7 +126,6 @@ int main(int argc, char *argv[]) { pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createShapeInferencePass()); - pm.addPass(mlir::createAttributePromotionPass()); if (emissionTarget >= EmitMLIR) { pm.addPass(mlir::createLowerToKrnlPass()); diff --git a/src/pass/passes.hpp b/src/pass/passes.hpp index cbe9961..b7bdc96 100644 --- a/src/pass/passes.hpp +++ b/src/pass/passes.hpp @@ -20,8 +20,6 @@ std::unique_ptr createDecomposeONNXToONNXPass(); std::unique_ptr createShapeInferencePass(); -std::unique_ptr createAttributePromotionPass(); - /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 9c3ab85..8df3d66 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -12,9 +12,9 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/IR/StandardTypes.h" #include "shape_inference_interface.hpp" +#include "src/dialect/onnx/onnx_ops.hpp" #include "passes.hpp"