parent
955968b750
commit
c25831094e
|
@ -12,10 +12,7 @@ add_library(compiler
|
||||||
pass/onnx_combine.cpp
|
pass/onnx_combine.cpp
|
||||||
pass/onnx_rewrite.cpp
|
pass/onnx_rewrite.cpp
|
||||||
pass/onnx_decompose.cpp
|
pass/onnx_decompose.cpp
|
||||||
pass/passes.hpp
|
pass/passes.hpp)
|
||||||
dialect/onnx/const_operands_interface.hpp
|
|
||||||
dialect/onnx/attribute_promotion.cpp
|
|
||||||
)
|
|
||||||
|
|
||||||
# Include root src directory.
|
# Include root src directory.
|
||||||
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
|
||||||
|
@ -50,18 +47,12 @@ onnf_tablegen(onnx_rewrite.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(gen_onnx_rewrite)
|
add_public_tablegen_target(gen_onnx_rewrite)
|
||||||
add_dependencies(compiler gen_onnx_rewrite)
|
add_dependencies(compiler gen_onnx_rewrite)
|
||||||
|
|
||||||
add_subdirectory(dialect/onnx)
|
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td)
|
||||||
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
onnf_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass")
|
||||||
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md)
|
||||||
add_public_tablegen_target(gen_onnx)
|
add_public_tablegen_target(gen_onnx)
|
||||||
|
|
||||||
add_dependencies(gen_onnx gen_shape_inference)
|
|
||||||
add_dependencies(gen_onnx gen_const_operands_interface)
|
|
||||||
add_dependencies(compiler gen_onnx)
|
add_dependencies(compiler gen_onnx)
|
||||||
|
|
||||||
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
|
add_onnf_dialect_doc(onnx dialect/onnx/onnx.td)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td)
|
||||||
|
@ -75,7 +66,7 @@ target_include_directories(onnf_onnx_decompose
|
||||||
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
|
||||||
${ONNF_SRC_ROOT})
|
${ONNF_SRC_ROOT})
|
||||||
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
target_link_libraries(onnf_onnx_decompose ${MLIRLibs})
|
||||||
add_dependencies(onnf_onnx_decompose gen_onnx)
|
add_dependencies(onnf_onnx_decompose gen_krnl_ops)
|
||||||
|
|
||||||
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
add_library(onnf_shape_inference pass/shape_inference_pass.cpp)
|
||||||
target_include_directories(onnf_shape_inference
|
target_include_directories(onnf_shape_inference
|
||||||
|
|
|
@ -1,11 +0,0 @@
|
||||||
#add_library(onnf_fold_const_operand_to_attribute fold_const_operand_to_attribute.hpp)
|
|
||||||
#target_include_directories(onnf_fold_const_operand_to_attribute
|
|
||||||
# PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT})
|
|
||||||
#target_link_libraries(onnf_fold_const_operand_to_attribute ${MLIRLibs})
|
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS const_operands_interface.td)
|
|
||||||
onnf_tablegen(const_operands_interface.hpp.inc -gen-op-interface-decls)
|
|
||||||
onnf_tablegen(const_operands_interface.cpp.inc -gen-op-interface-defs)
|
|
||||||
message("TBLGENOUT" ${TABLEGEN_OUTPUT})
|
|
||||||
add_public_tablegen_target(gen_const_operands_interface)
|
|
||||||
add_dependencies(compiler gen_const_operands_interface)
|
|
|
@ -1,79 +0,0 @@
|
||||||
//===----- shape_inference_pass.cpp - Shape Inference ---------------------===//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
// This file implements a Function level pass performing propagation of array
|
|
||||||
// shapes through function specialization.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#include "mlir/Pass/Pass.h"
|
|
||||||
#include "llvm/Support/raw_ostream.h"
|
|
||||||
|
|
||||||
#include "src/dialect/onnx/const_operands_interface.hpp"
|
|
||||||
#include "src/dialect/onnx/onnx_ops.hpp"
|
|
||||||
#include "src/pass/passes.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
|
||||||
|
|
||||||
#include "src/dialect/onnx/const_operands_interface.cpp.inc"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
/*!
|
|
||||||
* FunctionPass that performs shape inference by iterating over a list of
|
|
||||||
* candidate operations and propagating the shape information until the list
|
|
||||||
* of operations is empty [credit MLIR authors].
|
|
||||||
*/
|
|
||||||
class AttributePromotionPass
|
|
||||||
: public mlir::FunctionPass<AttributePromotionPass> {
|
|
||||||
public:
|
|
||||||
void runOnFunction() override {
|
|
||||||
auto f = getFunction();
|
|
||||||
|
|
||||||
llvm::Optional<mlir::Value> none;
|
|
||||||
f.walk([&](mlir::Operation *op) {
|
|
||||||
if (op->getNumResults() && op->getOpResult(0).getType().isa<NoneType>())
|
|
||||||
none = op->getOpResult(0);
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!none) {
|
|
||||||
OpBuilder builder(f.getContext());
|
|
||||||
builder.setInsertionPointToStart(&f.front());
|
|
||||||
none =
|
|
||||||
builder.create<mlir::ConstantOp>(f.getLoc(), builder.getUnitAttr());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iterate on the operations that need shape inference i.e the operations
|
|
||||||
// that return a dynamic shape.
|
|
||||||
f.walk([&](mlir::Operation *op) {
|
|
||||||
if (IdentifyConstOperandsOpInterface opWithConstOperand =
|
|
||||||
dyn_cast<IdentifyConstOperandsOpInterface>(op)) {
|
|
||||||
auto promotableOperands = opWithConstOperand.identifyPromotableConstOperands();
|
|
||||||
for (const auto& operandNameToIdx : promotableOperands) {
|
|
||||||
auto name = operandNameToIdx.first;
|
|
||||||
auto idx = operandNameToIdx.second;
|
|
||||||
|
|
||||||
auto operandToPromote = op->getOperand(idx);
|
|
||||||
if (auto constantOp =
|
|
||||||
dyn_cast_or_null<ConstantOp>(operandToPromote.getDefiningOp())) {
|
|
||||||
op->setAttr(name, constantOp.value());
|
|
||||||
op->setOperand(idx, *none);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Create a Shape Inference pass.
|
|
||||||
*/
|
|
||||||
std::unique_ptr<mlir::Pass> mlir::createAttributePromotionPass() {
|
|
||||||
return std::make_unique<AttributePromotionPass>();
|
|
||||||
}
|
|
||||||
|
|
||||||
static PassRegistration<AttributePromotionPass> pass(
|
|
||||||
"attribute-promotion", "Shape inference for frontend dialects.");
|
|
|
@ -1,24 +0,0 @@
|
||||||
//===- shape_inference_interface.hpp - Definition for ShapeInference --------=//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
// This file contains the declarations of the shape inference interfaces defined
|
|
||||||
// in ShapeInferenceInterface.td.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
|
||||||
|
|
||||||
namespace mlir {
|
|
||||||
|
|
||||||
/// Include the auto-generated declarations.
|
|
||||||
#include "src/dialect/onnx/const_operands_interface.hpp.inc"
|
|
||||||
|
|
||||||
} // end namespace mlir
|
|
|
@ -1,32 +0,0 @@
|
||||||
//=- shape_inference_interface.td - Shape Inference Interface -*- tablegen -==//
|
|
||||||
//
|
|
||||||
// Copyright 2019 The IBM Research Authors.
|
|
||||||
//
|
|
||||||
// =============================================================================
|
|
||||||
//
|
|
||||||
// Defines the operations of the Shape Inference Op Interface.
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
||||||
#else
|
|
||||||
#define IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
||||||
|
|
||||||
#ifdef OP_BASE
|
|
||||||
#else
|
|
||||||
include "mlir/IR/OpBase.td"
|
|
||||||
#endif // OP_BASE
|
|
||||||
|
|
||||||
def IdentifyConstOperandsOpInterface : OpInterface<"IdentifyConstOperandsOpInterface"> {
|
|
||||||
let description = [{
|
|
||||||
Interface to access a registered method to infer the return types for an
|
|
||||||
operation that can be used during type inference.
|
|
||||||
}];
|
|
||||||
|
|
||||||
let methods = [
|
|
||||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
|
||||||
"std::map<std::string, size_t>", "identifyPromotableConstOperands">
|
|
||||||
];
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
|
@ -22,11 +22,6 @@ include "mlir/IR/OpBase.td"
|
||||||
include "pass/shape_inference_interface.td"
|
include "pass/shape_inference_interface.td"
|
||||||
#endif // SHAPE_INFERENCE_INTERFACE
|
#endif // SHAPE_INFERENCE_INTERFACE
|
||||||
|
|
||||||
#ifdef IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
||||||
#else
|
|
||||||
include "dialect/onnx/const_operands_interface.td"
|
|
||||||
#endif // IDENTIFY_PROMOTABLE_CONST_OPERANDS_OP_INTERFACE
|
|
||||||
|
|
||||||
def ONNX_Dialect : Dialect {
|
def ONNX_Dialect : Dialect {
|
||||||
let name = "onnx";
|
let name = "onnx";
|
||||||
let cppNamespace = "";
|
let cppNamespace = "";
|
||||||
|
|
|
@ -10,9 +10,6 @@
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
@ -20,7 +17,6 @@
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
#include "src/pass/shape_inference_interface.hpp"
|
#include "src/pass/shape_inference_interface.hpp"
|
||||||
#include "src/dialect/onnx/const_operands_interface.hpp"
|
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ def ONNXAbsOp:ONNX_Op<"Abs",
|
||||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
||||||
build(builder, state, outputTypes, operands, attributes);
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
}]>
|
}]>
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXAcosOp:ONNX_Op<"Acos",
|
def ONNXAcosOp:ONNX_Op<"Acos",
|
||||||
|
@ -2501,9 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
[NoSideEffect,
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
|
|
||||||
OpInterface<"IdentifyConstOperandsOpInterface">]> {
|
|
||||||
let summary = "ONNX Reshape operation";
|
let summary = "ONNX Reshape operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"Reshape the input tensor similar to numpy.reshape."
|
"Reshape the input tensor similar to numpy.reshape."
|
||||||
|
@ -2514,11 +2512,8 @@ def ONNXReshapeOp:ONNX_Op<"Reshape",
|
||||||
"from the input tensor)."
|
"from the input tensor)."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data,
|
||||||
AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape);
|
AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape);
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped);
|
||||||
let extraClassDeclaration = [{
|
|
||||||
std::map<std::string, size_t> identifyPromotableConstOperands() { return {{"shape", 1}}; }
|
|
||||||
}];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXResizeOp:ONNX_Op<"Resize",
|
def ONNXResizeOp:ONNX_Op<"Resize",
|
||||||
|
|
|
@ -126,7 +126,6 @@ int main(int argc, char *argv[]) {
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
|
||||||
|
|
||||||
if (emissionTarget >= EmitMLIR) {
|
if (emissionTarget >= EmitMLIR) {
|
||||||
pm.addPass(mlir::createLowerToKrnlPass());
|
pm.addPass(mlir::createLowerToKrnlPass());
|
||||||
|
|
|
@ -20,8 +20,6 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
|
||||||
|
|
||||||
/// Add pass for lowering to Krnl IR.
|
/// Add pass for lowering to Krnl IR.
|
||||||
std::unique_ptr<Pass> createLowerToKrnlPass();
|
std::unique_ptr<Pass> createLowerToKrnlPass();
|
||||||
|
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
|
||||||
|
|
||||||
#include "shape_inference_interface.hpp"
|
#include "shape_inference_interface.hpp"
|
||||||
|
#include "src/dialect/onnx/onnx_ops.hpp"
|
||||||
|
|
||||||
#include "passes.hpp"
|
#include "passes.hpp"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue