From 8665ecd998074f0a60c941112ccd32d427356890 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Mon, 20 Jan 2020 12:30:08 -0500 Subject: [PATCH 1/8] Enable e2e tests (#29) * Sync with latest MLIR. * Enable ONNX backend tests as a means to test ONNF lowering end-to-end. * Install ONNX using quiet mode. * Remove debug comments. * Install ONNX from third_party/onnx. * Check python version and fix pip command for installing ONNX. * Using --user install option to prevent permission denied. * Remove unused imports. * Try using stock ONNX pip package as there are more tests in them. * Pip got stuck building wheels, try sudo. * Use verbose install to debug. * Invalidate cache to build LLVM tools. * Fix mlir installation script location. * Debug to locate ONNF. * Sanity check. * Check out ONNF code first. * Use verbose LIT output. * 1. Update documentation to always use verbose LIT. 2. Update krnl ops to reflect new affine map attribute syntax. * See if conda exists * Install ONNX by manually cloning the repo. * Install cmake first. * Using sudo priviledge when installing. * Limit build parallelism. * Limit parallelism. * Larger memory. * Install onnx package with pip. * Build MLIR tools. * Invalidate cache. * Compile model.so with -fPIC. * Remove module dump to get concise debug output. * Print command before executing. * Use quiet install mode to reduce logging. * Use -relocation-model=pic to generate position independent code. * 1. Remove MAKEFLAGS because now buildbot has enough memory. 2. Run DocCheck as a last step. * 1. Add verbose mode for backtend test. * When dumping to LLVM bitcode, do not dump module IR, but print a message indicating that bitcode has been written to disk. * Do not pass MakeFlags to CMake. * Add more explaination for posible reasons of failing to identify tests. --- .circleci/config.yml | 38 ++++++----- CMakeLists.txt | 3 +- README.md | 4 +- src/main.cpp | 11 ++-- test/CMakeLists.txt | 3 +- test/backend/CMakeLists.txt | 10 +++ .../{onnx_backend_test.py => backend/test.py} | 66 ++++++++++--------- test/backend/test_config.py.in | 3 + test/mlir/krnl/ops.mlir | 18 ++--- utils/install-mlir.sh | 3 +- utils/install-onnf.sh | 1 + 11 files changed, 94 insertions(+), 66 deletions(-) create mode 100644 test/backend/CMakeLists.txt rename test/{onnx_backend_test.py => backend/test.py} (71%) create mode 100644 test/backend/test_config.py.in diff --git a/.circleci/config.yml b/.circleci/config.yml index f0556fd..6fba52f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -3,27 +3,12 @@ jobs: build: docker: - image: circleci/python + resource_class: medium+ steps: - run: name: Installing GCC, CMake, Ninja, Protobuf command: sudo apt-get update && sudo apt-get install -y gcc g++ cmake ninja-build protobuf-compiler - # Use cached mlir installation if possible. - - restore_cache: - key: V2-LLVM-PROJECT-{{ arch }} - - run: - name: Install MLIR - command: | - # Check whether cache restoration succeeds by checking whether - # mlir-opt executable exists. - if [ ! -f llvm-project/build/bin/mlir-opt ]; then - export MAKEFLAGS=-j4 - source utils/install-mlir.sh - fi - - save_cache: - key: V2-LLVM-PROJECT-{{ arch }} - paths: - - llvm-project - checkout: path: ONNF - run: @@ -31,9 +16,30 @@ jobs: command: | cd ONNF git submodule update --init --recursive + # Use cached mlir installation if possible. + - restore_cache: + key: V4-LLVM-PROJECT-{{ arch }} + - run: + name: Install MLIR + command: | + # Check whether cache restoration succeeds by checking whether + # mlir-opt executable exists. + if [ ! -f llvm-project/build/bin/mlir-opt ]; then + source ONNF/utils/install-mlir.sh + fi + - save_cache: + key: V4-LLVM-PROJECT-{{ arch }} + paths: + - llvm-project - run: name: Install ONNF command: source ONNF/utils/install-onnf.sh + - run: + name: Run End-To-End Tests + command: | + sudo pip install -q onnx + cd ONNF/build + cmake --build . --target run-onnx-backend-test - run: name: Run DocCheck command: cd ONNF/build && cmake --build . --target check-doc diff --git a/CMakeLists.txt b/CMakeLists.txt index e672ac2..ab9e1d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,5 +26,4 @@ add_subdirectory(third_party/pybind11) set(CMAKE_CXX_STANDARD 14) add_subdirectory(src) add_subdirectory(doc) -add_subdirectory(test) - +add_subdirectory(test) \ No newline at end of file diff --git a/README.md b/README.md index 286d7c5..19c3d64 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,8 @@ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_RTTI=ON -cmake --build . --target check-mlir -- ${MAKEFLAGS} +cmake --build . --target +cmake --build . --target check-mlir ``` Two environment variables need to be set: @@ -42,6 +43,7 @@ cmake .. cmake --build . --target onnf # Run FileCheck tests: +export LIT_OPTS=-v cmake --build . --target check-mlir-lit ``` diff --git a/src/main.cpp b/src/main.cpp index 002bf08..6e2c8e2 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -135,10 +135,13 @@ int main(int argc, char *argv[]) { if (mlir::failed(pm.run(*module))) return 4; - module->dump(); - // Write LLVM bitcode to disk. - if (emissionTarget == EmitLLVMBC) - EmitLLVMBitCode(module); + if (emissionTarget == EmitLLVMBC) { + // Write LLVM bitcode to disk. + EmitLLVMBitCode(module); + printf("LLVM bitcode written to ./model.bc"); + } else + module->dump(); + return 0; } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 967b5b0..2e49add 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1 +1,2 @@ -add_subdirectory(mlir) \ No newline at end of file +add_subdirectory(mlir) +add_subdirectory(backend) \ No newline at end of file diff --git a/test/backend/CMakeLists.txt b/test/backend/CMakeLists.txt new file mode 100644 index 0000000..641422c --- /dev/null +++ b/test/backend/CMakeLists.txt @@ -0,0 +1,10 @@ +configure_file(test.py test.py COPYONLY) +configure_file(test_config.py.in test_config.py) + +find_package(PythonInterp 3 REQUIRED) +add_custom_target(run-onnx-backend-test + COMMAND ${PYTHON_EXECUTABLE} + ${CMAKE_CURRENT_BINARY_DIR}/test.py) + +add_dependencies(run-onnx-backend-test onnf) +add_dependencies(run-onnx-backend-test pyruntime) diff --git a/test/onnx_backend_test.py b/test/backend/test.py similarity index 71% rename from test/onnx_backend_test.py rename to test/backend/test.py index fe7f158..a9072db 100644 --- a/test/onnx_backend_test.py +++ b/test/backend/test.py @@ -3,46 +3,51 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -import itertools import os +import sys import unittest import onnx.backend.base import onnx.backend.test from onnx.backend.base import Device, DeviceType -import onnx.shape_inference -import onnx.version_converter import subprocess +import test_config + +VERBOSE = bool(os.environ.get("VERBOSE")) + +CXX = test_config.CXX_PATH +ONNF = os.path.join(test_config.ONNF_BUILD_PATH, "bin/onnf") +LLC = os.path.join(test_config.LLVM_PROJ_BUILD_PATH, "bin/llc") + +# Make lib folder under build directory visible in PYTHONPATH +doc_check_base_dir = os.path.dirname(os.path.realpath(__file__)) +RUNTIME_DIR = os.path.join(test_config.ONNF_BUILD_PATH, "lib") +sys.path.append(RUNTIME_DIR) from pyruntime import ExecutionSession -CXX = os.getenv('CXX') -ONNF = os.getenv('ONNF') -LLC = os.getenv('LLC') -RT_DIR = os.getenv('RT_DIR') -assert CXX and ONNF and LLC and RT_DIR, "tools path not set" + +def execute_commands(cmds): + if (VERBOSE): + print(" ".join(cmds)) + subprocess.run(cmds, stdout=subprocess.PIPE) + class DummyBackend(onnx.backend.base.Backend): @classmethod - def prepare( - cls, - model, - device='CPU', - **kwargs - ): + def prepare(cls, model, device='CPU', **kwargs): super(DummyBackend, cls).prepare(model, device, **kwargs) # Save model to disk as temp_model.onnx. onnx.save(model, "temp_model.onnx") # Call frontend to process temp_model.onnx, bit code will be generated. - subprocess.run([ONNF, "temp_model.onnx"], stdout=subprocess.PIPE) + execute_commands([ONNF, "temp_model.onnx"]) # Call llc to generate object file from bitcode. - subprocess.run([LLC, "-filetype=obj", "model.bc"], - stdout=subprocess.PIPE) + execute_commands( + [LLC, "-filetype=obj", "-relocation-model=pic", "model.bc"]) # Generate shared library from object file, linking with c runtime. - subprocess.run([ - CXX, "-shared", "model.o", "-o", "model.so", "-L" + RT_DIR, - "-lcruntime" - ], - stdout=subprocess.PIPE) + execute_commands([ + CXX, "-shared", "-fPIC", "model.o", "-o", "model.so", + "-L" + RUNTIME_DIR, "-lcruntime" + ]) return ExecutionSession("./model.so", "_dyn_entry_point_main_graph") @classmethod @@ -124,7 +129,7 @@ test_to_enable = [ # Sigmoid Op: "test_sigmoid_cpu", "test_sigmoid_example_cpu", - + # Sum Op: #"test_sum_example_cpu", <- error "test_sum_one_input_cpu", @@ -140,18 +145,15 @@ import inspect all_tests = inspect.getmembers( backend_test.test_cases["OnnxBackendNodeModelTest"]) all_test_names = list(map(lambda x: x[0], all_tests)) + +# Ensure that test names specified in test_to_enable actually exist. for test_name in test_to_enable: - assert test_name in all_test_names, "test name {} not found".format(test_name) + assert test_name in all_test_names, "test name {} not found, it is likely " + "that you may have misspelled the test name or the specified test does not " + "exist in the version of onnx package you installed.".format( + test_name) backend_test.include(r"^{}$".format(test_name)) - -def tearDownModule(): - print() - print("*" * 40) - print("A total of {} tests should have run".format(len(test_to_enable))) - print("*" * 40) - - # import all test cases at global scope to make them visible to python.unittest globals().update(backend_test.test_cases) diff --git a/test/backend/test_config.py.in b/test/backend/test_config.py.in new file mode 100644 index 0000000..571e35d --- /dev/null +++ b/test/backend/test_config.py.in @@ -0,0 +1,3 @@ +ONNF_BUILD_PATH = "@CMAKE_BINARY_DIR@" +LLVM_PROJ_BUILD_PATH = "@LLVM_PROJ_BUILD@" +CXX_PATH = "@CMAKE_CXX_COMPILER@" diff --git a/test/mlir/krnl/ops.mlir b/test/mlir/krnl/ops.mlir index 0300a11..a098d66 100644 --- a/test/mlir/krnl/ops.mlir +++ b/test/mlir/krnl/ops.mlir @@ -1,12 +1,12 @@ // RUN: onnf-opt %s -mlir-print-op-generic | FileCheck -check-prefix=GENERIC %s // RUN: onnf-opt %s | FileCheck %s -// GENERIC-DAG: #{{.*}} = () -> (0) -// GENERIC-DAG: #{{.*}} = () -> (10) -// GENERIC-DAG: #{{.*}} = () -> (1) -// GENERIC-DAG: #{{.*}} = () -> (11) -// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 - d1) -// GENERIC-DAG: #{{.*}} = (d0, d1) -> (d0 + d1) +// GENERIC-DAG: #{{.*}} = affine_map<() -> (0)> +// GENERIC-DAG: #{{.*}} = affine_map<() -> (10)> +// GENERIC-DAG: #{{.*}} = affine_map<() -> (1)> +// GENERIC-DAG: #{{.*}} = affine_map<() -> (11)> +// GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 - d1)> +// GENERIC-DAG: #{{.*}} = affine_map<(d0, d1) -> (d0 + d1)> func @simple_iterate(%N : index) { %ii, %ij, %ik = krnl.define_loops 3 @@ -55,18 +55,18 @@ func @affine_map_bound(%N : index) { // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index): // CHECK: krnl.iterate(%{{.*}}, %{{.*}}) with (%{{.*}} -> %{{.*}} = 0 to 10, %{{.*}} -> %{{.*}} = 0 to 10) { - krnl.iterate(%oi, %oj) with (%ii -> %i = ()->(0)() to ()->(10)(), %ij -> %j = 0 to 10) { + krnl.iterate(%oi, %oj) with (%ii -> %i = affine_map<()->(0)>() to affine_map<()->(10)>(), %ij -> %j = 0 to 10) { // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC-NEXT: ^bb0(%{{.*}}: index): // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = #{{.*}}(%{{.*}}, %{{.*}}) to #{{.*}}(%{{.*}}, %{{.*}})) { - krnl.iterate(%ok) with (%ik -> %k = (d0, d1)->(d0 - d1)(%i, %j) to (d0, d1)->(d0 + d1)(%i, %j)) { + krnl.iterate(%ok) with (%ik -> %k = affine_map<(d0, d1)->(d0 - d1)>(%i, %j) to affine_map<(d0, d1)->(d0 + d1)>(%i, %j)) { } // GENERIC: "krnl.iterate"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( { // GENERIC-NEXT: ^bb0(%{{.*}}: index): // CHECK: krnl.iterate(%{{.*}}) with (%{{.*}} -> %{{.*}} = max #map{{.*}}(%{{.*}}, %{{.*}}) to min #map{{.*}}(%{{.*}}, %{{.*}})[%{{.*}}]) { - krnl.iterate(%ok) with (%ik -> %k = max (d0, d1)->(d0 - d1, 0)(%i, %j) to min (d0, d1)[s0]->(d0 + d1, s0)(%i, %j)[%N]) { + krnl.iterate(%ok) with (%ik -> %k = max affine_map<(d0, d1)->(d0 - d1, 0)>(%i, %j) to min affine_map<(d0, d1)[s0]->(d0 + d1, s0)>(%i, %j)[%N]) { } } diff --git a/utils/install-mlir.sh b/utils/install-mlir.sh index 425e57c..d47c37a 100644 --- a/utils/install-mlir.sh +++ b/utils/install-mlir.sh @@ -9,4 +9,5 @@ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_RTTI=ON -cmake --build . --target check-mlir -- ${MAKEFLAGS} \ No newline at end of file +cmake --build . --target +cmake --build . --target check-mlir \ No newline at end of file diff --git a/utils/install-onnf.sh b/utils/install-onnf.sh index e28670f..5f2a98b 100644 --- a/utils/install-onnf.sh +++ b/utils/install-onnf.sh @@ -7,4 +7,5 @@ cmake .. cmake --build . --target onnf # Run FileCheck tests: +export LIT_OPTS=-v cmake --build . --target check-mlir-lit \ No newline at end of file From 9d1078540d37210644681bb3cd47acd2eb080e5f Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 13 Jan 2020 18:08:19 -0500 Subject: [PATCH 2/8] Transpose using perm attribute. --- src/dialect/onnx/onnx_ops.cpp | 24 +++++++++++++++++++----- src/dialect/onnx/onnxop.inc | 4 ++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 985f63d..73d76b5 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -407,11 +407,25 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - // TODO: Once attributes are supported we can handle the case where the - // transposition uses a permutation vector to interchange the axes. - auto arrayTy = getOperand().getType().cast(); - SmallVector dims(llvm::reverse(arrayTy.getShape())); - getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); + auto arrayTy = getOperand()->getType().cast(); + SmallVector dims; + + if (auto permutation = getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + // Perform transposition according to perm attribute. + for (auto perm : permutation.getValue()) { + int32_t index = perm.cast().getInt(); + if (index < 0) + emitError("Cannot tranpose when permutation contains negative index."); + dims.emplace_back(arrayTy.getShape()[index]); + } + } else { + // Default + for (auto shape : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(shape); + } + + getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 16ad979..5d22346 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3098,6 +3098,10 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); + + let extraClassDeclaration = [{ + static StringRef getPermAttrName() { return "perm"; } + }]; } def ONNXUniqueOp:ONNX_Op<"Unique", From f0b484c0bc59f0f39a93e111cee428fb6f596ccf Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Tue, 14 Jan 2020 10:37:05 -0500 Subject: [PATCH 3/8] Add test for transpose with permutation. --- test/mlir/onnx/onnx_shape_inference.mlir | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 4cb9bec..aaa08a7 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -9,4 +9,14 @@ func @test_default_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_default_transpose // CHECK: [[RES:%.+]] = "onnx.Transpose"(%arg0) : (tensor<5x5x1x32xf32>) -> tensor<32x1x5x5xf32> -// CHECK: return [[RES]] : tensor<32x1x5x5xf32> \ No newline at end of file +// CHECK: return [[RES]] : tensor<32x1x5x5xf32> + +/// Test shape inference for transposition when perm attribute is specified. +func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { + %0 = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () +} + +// CHECK-LABEL: test_transpose +// CHECK: [[RES_ATTR:%.+]] = "onnx.Transpose"(%arg0) {perm = [2, 0, 3, 1]} : (tensor<5x5x1x32xf32>) -> tensor<1x5x32x5xf32> +// CHECK: return [[RES_ATTR]] : tensor<1x5x32x5xf32> \ No newline at end of file From bd44d8402e09c8b1c6068550abc640fdf5c41ef9 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 20 Jan 2020 14:46:54 -0500 Subject: [PATCH 4/8] Add verifier function for checking negative perms. --- src/dialect/onnx/onnx_ops.cpp | 28 ++++++++++++++++++++-------- src/dialect/onnx/onnxop.inc | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 73d76b5..cef90cb 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" @@ -413,21 +414,32 @@ void ONNXTransposeOp::inferShapes() { if (auto permutation = getAttrOfType( ONNXTransposeOp::getPermAttrName())) { // Perform transposition according to perm attribute. - for (auto perm : permutation.getValue()) { - int32_t index = perm.cast().getInt(); - if (index < 0) - emitError("Cannot tranpose when permutation contains negative index."); - dims.emplace_back(arrayTy.getShape()[index]); - } + for (auto perm : permutation.getValue()) + dims.emplace_back(arrayTy.getShape()[perm.cast().getInt()]); } else { // Default - for (auto shape : llvm::reverse(arrayTy.getShape())) - dims.emplace_back(shape); + for (auto dim : llvm::reverse(arrayTy.getShape())) + dims.emplace_back(dim); } getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); } +LogicalResult verify(ONNXTransposeOp op) { + auto module = op.getParentOfType(); + if (!module) + op.emitError("Expected to belong to a module."); + + if (auto permutation = op.getAttrOfType( + ONNXTransposeOp::getPermAttrName())) { + for (auto perm : permutation.getValue()) + if (perm.cast().getInt() < 0) + op.emitError("Cannot tranpose, permuation contains negative index."); + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index 5d22346..fc2714e 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -3102,6 +3102,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", let extraClassDeclaration = [{ static StringRef getPermAttrName() { return "perm"; } }]; + + let verifier = [{ return ::verify(*this); }]; } def ONNXUniqueOp:ONNX_Op<"Unique", From 6b55bb43c7fb1c7eab0104b29139caf4dedd2aa3 Mon Sep 17 00:00:00 2001 From: Doru Bercea Date: Mon, 20 Jan 2020 15:48:16 -0500 Subject: [PATCH 5/8] Fix operand type access. --- src/dialect/onnx/onnx_ops.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index cef90cb..7e1675d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -408,7 +408,7 @@ void ONNXTransposeOp::inferShapes() { // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). - auto arrayTy = getOperand()->getType().cast(); + auto arrayTy = getOperand().getType().cast(); SmallVector dims; if (auto permutation = getAttrOfType( @@ -422,7 +422,7 @@ void ONNXTransposeOp::inferShapes() { dims.emplace_back(dim); } - getResult()->setType(RankedTensorType::get(dims, arrayTy.getElementType())); + getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); } LogicalResult verify(ONNXTransposeOp op) { From e89e51699bac092899c5c4121a9c442bb13e2a1c Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Tue, 21 Jan 2020 11:57:32 +0900 Subject: [PATCH 6/8] Lowering softmax (#14) * Rebase * Use max normalization * Handle axis * Add tests * Update SharingWork.md * Remove redundant spaces * Format code * Rebase * Change from the use of Value* to Value * Add end-to-end tests Co-authored-by: Tian Jin --- SharingWork.md | 1 + src/dialect/onnx/gen_doc.py | 2 +- src/dialect/onnx/onnx_ops.cpp | 8 + src/dialect/onnx/onnxop.inc | 2 +- src/pass/lower_frontend_to_krnl.cpp | 222 +++++++++++++++++++++++++++- src/pass/shape_inference_pass.cpp | 3 +- test/backend/test.py | 8 + test/mlir/onnx/onnx_lowering.mlir | 46 ++++++ 8 files changed, 288 insertions(+), 4 deletions(-) diff --git a/SharingWork.md b/SharingWork.md index 6b6c063..fe43494 100644 --- a/SharingWork.md +++ b/SharingWork.md @@ -27,6 +27,7 @@ ONNX operations for which some work is needed. | Selu | Tung | v | v | | | Sigmoid | Tung | v | v | | | Sinh | Tung | v | v | | +| Softmax | Tung | v | v | | | Sub | Tung | v | v | M | | Sum | Tung | v | v | M | | Tanh | Tung | v | v | | diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 6d986c2..4141556 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -267,7 +267,7 @@ def gen_schema(schema) : 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', - 'Identity', 'Cos', 'Log', 'Transpose'] + 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax'] CanonicalList=['Add', 'Identity'] line_indent = ' ' diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index 7e1675d..53e463d 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -158,6 +158,14 @@ void ONNXReciprocalOp::inferShapes() { getResult().setType(getOperand().getType()); } +//===----------------------------------------------------------------------===// +// Softmax +/// Infer the output shape of the ONNXSoftmaxOp. This method is required by +/// the shape inference interface. +void ONNXSoftmaxOp::inferShapes() { + getResult().setType(getOperand().getType()); +} + //===----------------------------------------------------------------------===// // Add /// Infer the output shape of the ONNXAddOp. This method is required by the diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc index fc2714e..e87a01a 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnxop.inc @@ -2831,7 +2831,7 @@ def ONNXSliceOp:ONNX_Op<"Slice", } def ONNXSoftmaxOp:ONNX_Op<"Softmax", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Softmax operation"; let description = [{ "The operator computes the softmax (normalized exponential) values for each layer in the batch" diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index a578479..3d899ee 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -824,6 +824,225 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern { } }; +struct ONNXSoftmaxOpLowering : public ConversionPattern { + ONNXSoftmaxOpLowering(MLIRContext *ctx) + : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {} + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + // softmax(x) = let max_x = max(x) in + // let exp_x = exp(x - max_x) in + // let sum = sum(exp_x) in + // exp_x / sum + auto tensorType = (*op->result_type_begin()).cast(); + int64_t rank = tensorType.getRank(); + int64_t axis = op->getAttrOfType("Softmax.axis").getInt(); + axis = axis >= 0 ? axis : rank + axis; + assert(axis >= -rank && axis <= rank - 1); + + auto loc = op->getLoc(); + + // Insert an allocation and deallocation for the result of this operation. + auto memRefType = convertTensorToMemRef(tensorType); + auto elementType = memRefType.getElementType(); + + Value alloc; + bool insertDealloc = checkInsertDealloc(op); + if (hasAllConstantDimensions(memRefType)) + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); + else + alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc, + operands[0]); + + // Shape of the result + auto memRefShape = memRefType.getShape(); + + // Insert allocations and deallocations for sum and max. + MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0); + Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true); + Value zero = + rewriter.create(loc, FloatAttr::get(elementType, 0)); + Value negInfinity = rewriter.create( + loc, + FloatAttr::get(elementType, -std::numeric_limits::infinity())); + + // Define loops. + auto loopsOp = rewriter.create(loc, rank); + std::vector originalLoops; + originalLoops.reserve(rank); + for (auto result : loopsOp.getResults()) { + originalLoops.push_back(result); + } + + // Define loop optimization. + auto optimizedLoopsOp = rewriter.create(loc, rank); + std::vector optimizedLoops; + optimizedLoops.reserve(rank); + for (auto result : optimizedLoopsOp.getResults()) { + optimizedLoops.push_back(result); + } + Block &optimizationBlock = optimizedLoopsOp.region().front(); + + // Coerce the input into a 2-D tensor. `axis` will be the coercing point. + // This coercing follows the softmax definition in ONNX: + // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax + // Here, we create an outer loop and inner loop for handling the two + // dimensions. The outer loop is only created once `axis` is not zero. + + // Define an outer loop with respect to axis. + std::vector outerLoops, optimizedOuterLoops; + outerLoops.reserve(axis); + optimizedOuterLoops.reserve(axis); + for (int i = 0; i < axis; ++i) { + outerLoops.push_back(originalLoops[i]); + optimizedOuterLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops); + for (int i = 0; i < axis; ++i) { + if (memRefShape[i] < 0) { + outerPack.pushConstantBound(0); + outerPack.pushOperandBound( + rewriter.create(loc, operands[0], i).getResult()); + } else { + outerPack.pushConstantBound(0); + outerPack.pushConstantBound(memRefShape[i]); + } + } + // Define an inner loop with respect to axis. + std::vector innerLoops, optimizedInnerLoops; + innerLoops.reserve(rank - axis); + optimizedInnerLoops.reserve(rank - axis); + for (int i = axis; i < rank; ++i) { + innerLoops.push_back(originalLoops[i]); + optimizedInnerLoops.push_back(optimizedLoops[i]); + } + KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops); + for (int i = axis; i < rank; ++i) { + if (memRefShape[i] < 0) { + innerPack.pushConstantBound(0); + innerPack.pushOperandBound( + rewriter.create(loc, operands[0], i).getResult()); + } else { + innerPack.pushConstantBound(0); + innerPack.pushConstantBound(memRefShape[i]); + } + } + + KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp; + SmallVector outerLoopIVs; + if (axis != 0) { + outerIterateOp = rewriter.create(loc, outerPack); + + // No optimization + rewriter.setInsertionPointToEnd(&optimizationBlock); + rewriter.create(loc, originalLoops); + rewriter.setInsertionPoint(optimizedLoopsOp); + + // Insert instructions inside the outer loop. + Block &outerIterationBlock = outerIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&outerIterationBlock); + for (auto arg : outerIterationBlock.getArguments()) + outerLoopIVs.push_back(arg); + + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + } else { + // Reset accumulators. + rewriter.create(loc, zero, sumOp); + rewriter.create(loc, negInfinity, maxOp); + + // Create an inner loop to compute max. + maxIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute sum. + sumIterateOp = rewriter.create(loc, innerPack); + // Create an inner loop to compute softmax. + softmaxIterateOp = rewriter.create(loc, innerPack); + + // No optimization + rewriter.setInsertionPointToEnd(&optimizationBlock); + rewriter.create(loc, originalLoops); + rewriter.setInsertionPoint(optimizedLoopsOp); + } + + // Insert instructions inside the max loop. + Block &maxIterationBlock = maxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&maxIterationBlock); + + // Get induction variables. + SmallVector maxLoopIVs; + for (auto arg : outerLoopIVs) + maxLoopIVs.push_back(arg); + for (auto arg : maxIterationBlock.getArguments()) + maxLoopIVs.push_back(arg); + + // Compute the max value. + Value max = rewriter.create(loc, maxOp); + Value nextMax = rewriter.create(loc, operands[0], maxLoopIVs); + auto maxCond = + rewriter.create(loc, CmpFPredicate::OGT, max, nextMax); + max = rewriter.create(loc, maxCond, max, nextMax); + rewriter.create(loc, max, maxOp); + + // Get the max. + rewriter.setInsertionPoint(sumIterateOp); + max = rewriter.create(loc, maxOp); + + // Insert instructions inside the sum loop. + Block &sumIterationBlock = sumIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&sumIterationBlock); + + // Get induction variables. + SmallVector sumLoopIVs; + for (auto arg : outerLoopIVs) + sumLoopIVs.push_back(arg); + for (auto arg : sumIterationBlock.getArguments()) + sumLoopIVs.push_back(arg); + + // Sum up values. + Value sum = rewriter.create(loc, sumOp); + Value next = rewriter.create(loc, operands[0], sumLoopIVs); + Value sub = rewriter.create(loc, next, max); + Value exp = rewriter.create(loc, sub); + sum = rewriter.create(loc, sum, exp); + rewriter.create(loc, sum, sumOp); + // Store intermediate values in the result to avoid recomputation. + rewriter.create(loc, exp, alloc, sumLoopIVs); + + // Get the sum. + rewriter.setInsertionPoint(softmaxIterateOp); + sum = rewriter.create(loc, sumOp); + + // Insert instructions inside the softmax loop. + Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front(); + rewriter.setInsertionPointToStart(&softmaxIterationBlock); + + // Get induction variables. + SmallVector softmaxLoopIVs; + for (auto arg : outerLoopIVs) + softmaxLoopIVs.push_back(arg); + for (auto arg : softmaxIterationBlock.getArguments()) + softmaxLoopIVs.push_back(arg); + + // Compute softmax. + Value expLoadedVal = rewriter.create(loc, alloc, softmaxLoopIVs); + Value result = rewriter.create(loc, expLoadedVal, sum); + rewriter.create(loc, result, alloc, softmaxLoopIVs); + + rewriter.replaceOp(op, alloc); + + return matchSuccess(); + } +}; + struct ONNXReshapeOpLowering : public ConversionPattern { ONNXReshapeOpLowering(MLIRContext *ctx) : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {} @@ -1005,7 +1224,8 @@ void FrontendToKrnlLoweringPass::runOnModule() { ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, ONNXElementwiseVariadicOpLowering, - ONNXReshapeOpLowering, ONNXEntryPointLowering>(&getContext()); + ONNXReshapeOpLowering, ONNXEntryPointLowering, + ONNXSoftmaxOpLowering>(&getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 5ccb9a4..3226f16 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -116,7 +116,8 @@ public: op->getName().getStringRef() != "onnx.Gemm" && op->getName().getStringRef() != "onnx.GemmNoBias" && op->getName().getStringRef() != "onnx.Reshape" && - op->getName().getStringRef() != "onnx.Transpose") + op->getName().getStringRef() != "onnx.Transpose" && + op->getName().getStringRef() != "onnx.Softmax") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa(); diff --git a/test/backend/test.py b/test/backend/test.py index a9072db..60ca4a8 100644 --- a/test/backend/test.py +++ b/test/backend/test.py @@ -130,6 +130,14 @@ test_to_enable = [ "test_sigmoid_cpu", "test_sigmoid_example_cpu", + # Softmax Op: + "test_softmax_axis_0_cpu", + "test_softmax_axis_1_cpu", + "test_softmax_axis_2_cpu", + "test_softmax_default_axis_cpu", + "test_softmax_example_cpu", + "test_softmax_large_number_cpu", + # Sum Op: #"test_sum_example_cpu", <- error "test_sum_one_input_cpu", diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 123e6a1..3ffce9a 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -533,3 +533,49 @@ func @test_add_with_broadcasting(%arg0 : tensor, %arg1 : tensor // CHECK: } // CHECK: return [[RES]] : memref } + +func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { + %0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_softmax + // CHECK: [[MAX:%.+]] = alloc() : memref + // CHECK: [[SUM:%.+]] = alloc() : memref + // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32> + // CHECK: [[CST:%.+]] = constant 0.000000e+00 : f32 + // CHECK: [[CST_0:%.+]] = constant 0xFF800000 : f32 + // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2 + // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops { + // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, %3#1 + // CHECK: } : () -> (!krnl.loop, !krnl.loop) + // CHECK: krnl.iterate([[OPT_LOOPS]]#0) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to 10) { + // CHECK: store [[CST]], [[SUM]][] : memref + // CHECK: store [[CST_0]], [[MAX]][] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD1:%.+]] = load [[MAX]][] : memref + // CHECK: [[LOAD2:%.+]] = load %arg0[%arg1, %arg2] : memref<10x10xf32> + // CHECK: [[COND:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32 + // CHECK: [[SELECT:%.+]] = select [[COND]], [[LOAD1]], [[LOAD2]] : f32 + // CHECK: store [[SELECT]], [[MAX]][] : memref + // CHECK: } + // CHECK: %5 = load [[MAX]][] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD1]] = load [[SUM]][] : memref + // CHECK: [[LOAD2]] = load %arg0[%arg1, %arg2] : memref<10x10xf32> + // CHECK: [[SUB:%.+]] = subf [[LOAD2]], %5 : f32 + // CHECK: [[EXP:%.+]] = exp [[SUB]] : f32 + // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[EXP]] : f32 + // CHECK: store [[ADD]], [[SUM]][] : memref + // CHECK: store %10, [[RES]][%arg1, %arg2] : memref<10x10xf32> + // CHECK: } + // CHECK: %6 = load [[SUM]][] : memref + // CHECK: krnl.iterate([[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) { + // CHECK: [[LOAD1]] = load [[RES]][%arg1, %arg2] : memref<10x10xf32> + // CHECK: [[DIV:%.+]] = divf [[LOAD1]], %6 : f32 + // CHECK: store [[DIV]], [[RES]][%arg1, %arg2] : memref<10x10xf32> + // CHECK: } + // CHECK: } + // CHECK: dealloc [[SUM]] : memref + // CHECK: dealloc [[MAX]] : memref + // CHECK: return [[RES]] : memref<10x10xf32> +} From 0231bb83a212258b46b0a66319bc90b984cc694c Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Tue, 21 Jan 2020 11:08:16 -0500 Subject: [PATCH 7/8] Properly link with ZLIB. (#40) --- src/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b537820..73c0ffe 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -69,8 +69,9 @@ add_subdirectory(runtime) add_executable(onnf main.cpp) target_link_libraries(onnf builder ${MLIRLibs} onnf_transform onnf_shape_inference onnf_lower_frontend) -set_target_properties(onnf PROPERTIES LINK_FLAGS "-lz") whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs}) +find_package(ZLIB REQUIRED) +target_link_libraries(onnf ${ZLIB_LIBRARIES}) target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR}) From 51b0f4c9dd8988d0a64498e537a9bd2dd8b1d613 Mon Sep 17 00:00:00 2001 From: Tian Jin Date: Tue, 21 Jan 2020 19:36:21 -0700 Subject: [PATCH 8/8] Chentong319 attribute with variant (#25) * change the read-in of attribute, using variant * Use backported variant. * Reduce code duplication. * 1. Make array attribute parsing more clear. 2. int -> int64_t. * 1. Fix how array attributes are imported. * 1. Fix clang-tidy warnings. * 1. Nit: fix clang-tidy warnings. * Fix MaxPool node construction. * Fix call to MaxPool. * Comment out backend tests that fail. * Add path to variant submodule to enable include file detection. * Allow unused argument to avoid special casing generator. * Address attribute related e2e test failures for Hard sigmoid,Elu,LeakyRelu,Selu,Softmax Co-authored-by: chentong319 Co-authored-by: Gheorghe-Teodor Bercea --- .gitmodules | 3 + CMakeLists.txt | 1 + src/builder/CMakeLists.txt | 3 +- src/builder/frontend_dialect_transformer.cpp | 549 ++++++------------ src/builder/op_build_table.inc | 346 ++++------- src/dialect/onnx/gen_doc.py | 87 +-- src/pass/lower_frontend_to_krnl.cpp | 14 +- test/mlir/onnx/onnx_lowering.mlir | 10 +- .../mlir/onnx/onnx_lowering_with_dealloc.mlir | 16 +- third_party/variant | 1 + 10 files changed, 375 insertions(+), 655 deletions(-) create mode 160000 third_party/variant diff --git a/.gitmodules b/.gitmodules index 285a7ac..2293919 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "third_party/pybind11"] path = third_party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "third_party/variant"] + path = third_party/variant + url = git@github.com:mpark/variant.git diff --git a/CMakeLists.txt b/CMakeLists.txt index ab9e1d7..7ec7054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ include(MLIR.cmake) add_subdirectory(third_party/onnx) add_subdirectory(third_party/benchmark) add_subdirectory(third_party/pybind11) +add_subdirectory(third_party/variant) set(CMAKE_CXX_STANDARD 14) add_subdirectory(src) diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt index 7d96296..6033e52 100644 --- a/src/builder/CMakeLists.txt +++ b/src/builder/CMakeLists.txt @@ -7,8 +7,9 @@ add_library(builder target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR}) -target_link_libraries(builder compiler onnx ${MLIRLibs} curses) +target_link_libraries(builder compiler onnx ${MLIRLibs} curses mpark_variant) target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR}/third_party/onnx + ${CMAKE_SOURCE_DIR}/third_party/variant ${CMAKE_SOURCE_DIR}) diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp index 7dd52c0..6065cd3 100644 --- a/src/builder/frontend_dialect_transformer.cpp +++ b/src/builder/frontend_dialect_transformer.cpp @@ -1,6 +1,6 @@ //===- frontend_dialect_transformer.cpp - MLIR Operations -----------------===// // -// Copyright 2019 The IBM Research Authors. +// Copyright 2019 The IBM Research Authors. // // ============================================================================= // @@ -14,11 +14,16 @@ // //===----------------------------------------------------------------------===// +#include #include #include #include #include -#include + +// Using backported variant. +// bstd = backported standard library. +#include +namespace bstd = mpark; #include "mlir/Analysis/Verifier.h" #include "mlir/Dialect/StandardOps/Ops.h" @@ -42,15 +47,15 @@ namespace onnf { namespace { -void replaceAll( - std::string& str, const std::string& from, const std::string& to) { +void replaceAll(std::string &str, const std::string &from, + const std::string &to) { if (from.empty()) return; size_t start_pos = 0; while ((start_pos = str.find(from, start_pos)) != std::string::npos) { str.replace(start_pos, from.length(), to); - start_pos += to.length(); // In case 'to' contains 'from', like replacing - // 'x' with 'yx' + start_pos += to.length(); // In case 'to' contains 'from', like replacing + // 'x' with 'yx' } } @@ -71,10 +76,10 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @return onnf tensor corresponding to `name`. */ - mlir::Value GetTensorByOnnxName(std::string name) { + mlir::Value GetTensorByOnnxName(const std::string &name) { assert(onnx_name2onnf_tensor.find(legalize_name(name)) != - onnx_name2onnf_tensor.end() && - "Tensor not found"); + onnx_name2onnf_tensor.end() && + "Tensor not found"); return onnx_name2onnf_tensor.at(legalize_name(name)); } @@ -83,9 +88,9 @@ struct OnnxOnnfSymbolMapping { * @param name onnx tensor name. * @param tensor MLIR Value pointer. */ - void AddMapping(std::string name, mlir::Value tensor) { + void AddMapping(const std::string &name, mlir::Value tensor) { assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 && - "Tensor already exists."); + "Tensor already exists."); onnx_name2onnf_tensor.emplace(legalize_name(name), tensor); } @@ -124,34 +129,34 @@ private: // Convert type to MLIR type. // A complete list of types can be found in: // /third_party/onnx/onnx/onnx.pb.h - mlir::Type TypeConvert(onnx::TensorProto_DataType intype) { - switch (intype) { - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: - return builder_.getF16Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: - return builder_.getF32Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: - return builder_.getF64Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_INT8: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: - return builder_.getIntegerType(8); - case onnx::TensorProto_DataType::TensorProto_DataType_INT16: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: - return builder_.getIntegerType(16); - case onnx::TensorProto_DataType::TensorProto_DataType_INT32: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: - return builder_.getIntegerType(32); - case onnx::TensorProto_DataType::TensorProto_DataType_INT64: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: - return builder_.getIntegerType(64); - case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: - return builder_.getI1Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_STRING: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: - case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: - assert(false && "Unsupported data type encountered."); - return nullptr; + mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) { + switch (onnxType) { + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: + return builder_.getF16Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: + return builder_.getF32Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: + return builder_.getF64Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: + return builder_.getIntegerType(8); + case onnx::TensorProto_DataType::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: + return builder_.getIntegerType(16); + case onnx::TensorProto_DataType::TensorProto_DataType_INT32: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: + return builder_.getIntegerType(32); + case onnx::TensorProto_DataType::TensorProto_DataType_INT64: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: + return builder_.getIntegerType(64); + case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: + return builder_.getI1Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_STRING: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: + case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: + assert(false && "Unsupported data type encountered."); + return nullptr; } } @@ -169,8 +174,8 @@ private: for (int i = 0; i < shape_proto.dim_size(); i++) { if (shape_proto.dim()[i].dim_value()) { int dim_numeric_size = shape_proto.dim()[i].dim_value(); - assert( - dim_numeric_size != 0 && "Parsed an input tensor with a dimension size of zero"); + assert(dim_numeric_size != 0 && + "Parsed an input tensor with a dimension size of zero"); if (dim_numeric_size > 0) { dims.push_back(dim_numeric_size); } else { // If dim_value < 0, then dim is parametric. @@ -184,7 +189,7 @@ private: } mlir::Type elementType = - TypeConvert(input.type().tensor_type().elem_type()); + convertONNXTypeToMLIRType(input.type().tensor_type().elem_type()); llvm::ArrayRef tensor_dims(dims.data(), dims.size()); arg_types.emplace_back( mlir::RankedTensorType::get(tensor_dims, elementType)); @@ -200,288 +205,111 @@ private: void ImportInputTensorSymbol(const onnx::ValueInfoProto &input, mlir::Value symbol) { auto input_tensor_legalized_name = legalize_name(input.name()); - assert( - !frontend_symbols_.ContainKey(input_tensor_legalized_name) && - "Found duplicate legalized input tensor names."); + assert(!frontend_symbols_.ContainKey(input_tensor_legalized_name) && + "Found duplicate legalized input tensor names."); frontend_symbols_.AddMapping(input_tensor_legalized_name, symbol); } - template - T get_attr_generic(onnx::NodeProto &node, std::string name, - std::function attr_getter, - T default_val) { + typedef bstd::variant, float, + std::vector, std::string, + std::vector> + AttrValueType; + + struct ONNXAttrVisitor { + ONNXAttrVisitor(std::string name, mlir::OpBuilder &builder) + : _builder(builder), _name(std::move(name)) {} + + // Op builder. + mlir::OpBuilder &_builder; + + // Name of the attribute being inspected. + std::string _name; + + mlir::NamedAttribute operator()(int64_t const &r) { + auto val = _builder.getI32IntegerAttr(r); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &ints) { + auto val = _builder.getI64ArrayAttr(ints); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(float const &r) { + auto val = _builder.getF32FloatAttr(r); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &floats) { + auto val = _builder.getF32ArrayAttr(floats); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::string const &s) { + auto val = _builder.getStringAttr(s); + return _builder.getNamedAttr(_name, val); + } + + mlir::NamedAttribute operator()(std::vector const &r) { + assert(false && "type of attribute value is not implemented"); + auto val = _builder.getI32IntegerAttr(1); + return _builder.getNamedAttr(_name, val); + }; + }; + + mlir::NamedAttribute convertNameValuePairToNamedAttribute( + std::pair nameAndVal) { + auto visitor = ONNXAttrVisitor(nameAndVal.first, builder_); + return mpark::visit(visitor, nameAndVal.second); + } + + static std::pair + convertAttributeProtoToNameValuePair(onnx::AttributeProto &attr) { + AttrValueType val; + switch (attr.type()) { + case onnx::AttributeProto::FLOAT: + return std::make_pair(attr.name(), AttrValueType(attr.f())); + case onnx::AttributeProto::INT: + return std::make_pair(attr.name(), AttrValueType(attr.i())); + case onnx::AttributeProto::STRING: + return std::make_pair(attr.name(), AttrValueType(attr.s())); + case onnx::AttributeProto::FLOATS: + val = AttrValueType( + std::vector(attr.floats().begin(), attr.floats().end())); + return std::make_pair(attr.name(), val); + case onnx::AttributeProto::INTS: + val = AttrValueType( + std::vector(attr.ints().begin(), attr.ints().end())); + return std::make_pair(attr.name(), val); + default: + assert(false && "datatype for attribute is not implemented"); + break; + } + } + + std::vector ImportNodeAttributes( + const onnx::NodeProto &node, + std::initializer_list> + defaultAttrList) { + std::vector attributes; + std::set definedAttributeSet; for (int i = 0; i < node.attribute_size(); ++i) { auto attr = node.attribute(i); - if (attr.name() == name) { - return attr_getter(attr); - } + auto nameValPair = convertAttributeProtoToNameValuePair(attr); + attributes.push_back(convertNameValuePairToNamedAttribute(nameValPair)); + definedAttributeSet.insert(attr.name()); } - return default_val; - } - - template - T get_attr_generic(onnx::NodeProto &node, std::string name, - std::function attr_getter) { - for (int i = 0; i < node.attribute_size(); ++i) { - auto attr = node.attribute(i); - if (attr.name() == name) { - return attr_getter(attr); - } + for (const auto &defaultAttr : defaultAttrList) { + if (definedAttributeSet.find(defaultAttr.first) == + definedAttributeSet.end()) + attributes.push_back(convertNameValuePairToNamedAttribute(defaultAttr)); } - assert(false && "ONNX Node Attribute Not Found!"); + return attributes; } - auto get_attr_ints(onnx::NodeProto &node, std::string name, - std::vector default_val) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector ints(attr.ints_size()); - std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); - return ints; - }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_ints(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector ints(attr.ints_size()); - std::copy(attr.ints().begin(), attr.ints().end(), ints.begin()); - return ints; - }; - auto r = get_attr_generic(node, name, attr_getter); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getIntegerType(32)); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_floats(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector floats(attr.floats_size()); - std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); - return floats; - }; - auto r = get_attr_generic(node, name, attr_getter); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_floats(onnx::NodeProto &node, std::string name, - std::vector default_val) { - std::function(onnx::AttributeProto &)> attr_getter = - [](onnx::AttributeProto &attr) { - std::vector floats(attr.floats_size()); - std::copy(attr.floats().begin(), attr.floats().end(), floats.begin()); - return floats; - }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.getF32Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(r)); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_int(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.i(); }; - int r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getI32IntegerAttr(r); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_int(onnx::NodeProto &node, std::string name, int default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.i(); }; - int r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getI32IntegerAttr(r); - auto aname = node.op_type() + "." + name; - auto attr_output = builder_.getNamedAttr(aname, attr_v); - return attr_output; - } - - auto get_attr_float(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.f(); }; - auto r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getF32FloatAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_float(onnx::NodeProto &node, std::string name, - float default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.f(); }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getF32FloatAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_string(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.s(); }; - auto r = get_attr_generic(node, name, attr_getter); - auto attr_v = builder_.getStringAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - auto get_attr_string(onnx::NodeProto &node, std::string name, - std::string default_val) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.s(); }; - auto r = get_attr_generic(node, name, attr_getter, default_val); - auto attr_v = builder_.getStringAttr(r); - auto aname = node.op_type() + "." + name; - return builder_.getNamedAttr(aname, attr_v); - } - - /* - auto get_attr_strings(onnx::NodeProto &node, std::string name) { - std::function(onnx::AttributeProto &)> - attr_getter = - [](onnx::AttributeProto &attr) { - std::vector strings(attr.strings_size()); - std::copy(attr.strings().begin(), attr.strings().end(), - strings.begin()); return strings; - }; - auto r = get_attr_generic(node, name, attr_getter); - return r; - return builder_.getNamedAttr(aname, attr_v); - auto dataType = - mlir::RankedTensorType::get(r.size(), builder_.get???Type()); - auto attr_v = mlir::DenseElementsAttr::get(dataType, - llvm::makeArrayRef(r)); auto aname = node.op_type() + "." + name; auto - attr_output = builder_.getNamedAttr(aname, attr_v); return attr_output; - } - */ - - auto get_default_ints(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(std::stoi(default_str.substr(start + 1, end))); - } - break; - } else { - r.push_back(std::stoi(default_str.substr(start + 1, end))); - } - start = end + 1; - } - return r; - } - - auto get_default_floats(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(std::stof(default_str.substr(start + 1, end))); - } - break; - } else { - r.push_back(std::stof(default_str.substr(start + 1, end))); - } - start = end + 1; - } - return r; - } - - auto get_default_strings(std::string default_str) { - std::vector r; - auto start = default_str.find("{"); - while (true) { - auto end = default_str.find(",", start + 1); - if (end == std::string::npos) { - end = default_str.find("}", start + 1); - if (end != std::string::npos && end > start + 1) { - r.push_back(default_str.substr(start + 1, end)); - } - break; - } else { - r.push_back(default_str.substr(start + 1, end)); - } - start = end + 1; - } - return r; - } - - onnx::TensorProto get_attr_tensor(onnx::NodeProto &node, std::string name) { - std::function attr_getter = - [](onnx::AttributeProto &attr) { return attr.t(); }; - return get_attr_generic(node, name, attr_getter); - } - - auto ImportNodeAttr(onnx::NodeProto node, std::string attr_name, - std::string type_name, std::string default_str) { - if (default_str == "") { - if (type_name == "int") { - return get_attr_int(node, attr_name); - } else if (type_name == "float") { - return get_attr_float(node, attr_name); - } else if (type_name == "str") { - return get_attr_string(node, attr_name); - } else if (type_name == "ints") { - return get_attr_ints(node, attr_name); - } else if (type_name == "floats") { - return get_attr_floats(node, attr_name); - } else { - assert( - false && - "Got an empty initializer or initializer for this " - "datatype is not implemented. Something is wrong."); - } - } else { - // with default value - if (type_name == "int") { - return get_attr_int(node, attr_name, std::stoi(default_str)); - } else if (type_name == "float") { - return get_attr_float(node, attr_name, std::stof(default_str)); - } else if (type_name == "str") { - return get_attr_string(node, attr_name, default_str); - } else if (type_name == "ints") { - return get_attr_ints(node, attr_name, get_default_ints(default_str)); - } else if (type_name == "floats") { - return get_attr_floats(node, attr_name, - get_default_floats(default_str)); - } else { - assert( - false && - "Got an empty initializer or initializer for this " - "datatype is not implemented. Something is wrong."); - } - } - } - - void ImportNodeGeneric(onnx::NodeProto node) { + void ImportNodeGeneric(const onnx::NodeProto &node) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -511,12 +339,12 @@ private: * default} */ template - void ImportNodeOneOut( - onnx::NodeProto node, int nIn, int nOut, - std::initializer_list> - attrs) { + void + ImportNodeOneOut(const onnx::NodeProto &node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -528,22 +356,7 @@ private: mlir::UnrankedTensorType::get(builder_.getF32Type())); } - std::vector attributes; - // for (auto [attr_name, attr_type, attr_default] : attrs) { - for (auto oneAttr : attrs) { - std::string attr_name; - std::string attr_type; - std::string attr_default; - std::tie(attr_name, attr_type, attr_default) = oneAttr; - if (attr_type != "") { - auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); - attributes.push_back(attr); - } else { - // TODO: the attributes need special handling - // std::cout << "missing " << node.op_type() << " " << attr_name << - // std::endl; - } - } + auto attributes = ImportNodeAttributes(node, defaultAttrList); llvm::StringRef OpName = node.op_type(); @@ -559,11 +372,11 @@ private: template void ImportNodeMultipleOuts( - onnx::NodeProto node, int nIn, int nOut, - std::initializer_list> - attrs) { + const onnx::NodeProto &node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -575,21 +388,7 @@ private: mlir::UnrankedTensorType::get(builder_.getF32Type())); } - std::vector attributes; - for (auto oneAttr : attrs) { - std::string attr_name; - std::string attr_type; - std::string attr_default; - std::tie(attr_name, attr_type, attr_default) = oneAttr; - if (attr_type != "") { - auto attr = ImportNodeAttr(node, attr_name, attr_type, attr_default); - attributes.push_back(attr); - } else { - // TODO: the attributes need special handling - // std::cout << "missing " << node.op_type() << " " << attr_name << - // std::endl; - } - } + auto attributes = ImportNodeAttributes(node, defaultAttrList); llvm::StringRef OpName = node.op_type(); @@ -610,10 +409,10 @@ private: * c++ does not allow template specialization inside a class scope * a specialized function is used */ - void ImportNodeConv( - onnx::NodeProto node, int nOut, - std::initializer_list> - attrs) { + void + ImportNodeConv(onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { // Conv has attribute dilations, kernel_shape, pads, the default value of // which is determined by the shape of first argument. However, since the // shape is unknown now, these attributes can be not generated auto @@ -627,29 +426,32 @@ private: int nOps = node.input().size(); if (nOps == 2) - ImportNodeOneOut(node, nOps, nOut, attrs); + ImportNodeOneOut( + node, nOps, nOut, defaultAttrList); else - ImportNodeOneOut(node, nOps, nOut, attrs); + ImportNodeOneOut(node, nOps, nOut, defaultAttrList); } /*! * Special handle for MaxPool operations. */ void ImportNodeMaxPool( - onnx::NodeProto node, int nIn, - std::initializer_list> - attrs) { + onnx::NodeProto node, int nIn, int nOut, + std::initializer_list> + defaultAttrList) { int nOuts = node.output().size(); if (nOuts == 1) { - ImportNodeOneOut(node, nIn, nOuts, attrs); + ImportNodeOneOut( + node, nIn, nOuts, defaultAttrList); } else { - ImportNodeMultipleOuts(node, nIn, nOuts, attrs); + ImportNodeMultipleOuts( + node, nIn, nOuts, defaultAttrList); } } - void ImportNode(onnx::NodeProto node) { + void ImportNode(const onnx::NodeProto &node) { std::vector inputs; - for (auto item : node.input()) { + for (const auto &item : node.input()) { if (frontend_symbols_.ContainKey(legalize_name(item))) { inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item)); } @@ -689,9 +491,8 @@ private: llvm::SmallVectorImpl &ret_types, llvm::SmallVectorImpl &ret_vals) { auto output_tensor_legalized_name = legalize_name(output.name()); - assert( - frontend_symbols_.ContainKey(output_tensor_legalized_name) && - "Output tensor not found"); + assert(frontend_symbols_.ContainKey(output_tensor_legalized_name) && + "Output tensor not found"); auto tensor_val = frontend_symbols_.GetTensorByOnnxName(output_tensor_legalized_name); @@ -750,9 +551,9 @@ private: funcType = builder_.getFunctionType(arg_types, ret_types); mainFunc.setType(funcType); } -}; // FrontendGenImpl class -} // namespace -} // namespace onnf +}; // FrontendGenImpl class +} // namespace +} // namespace onnf namespace onnf { @@ -775,4 +576,4 @@ void ImportFrontendModelFile(std::string model_fname, FrontendGenImpl myONNXGen(context); module = myONNXGen.ImportONNXModel(model); } -} // namespace onnf +} // namespace onnf diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc index 29a48cc..0e7f20e 100644 --- a/src/builder/op_build_table.inc +++ b/src/builder/op_build_table.inc @@ -16,13 +16,13 @@ }); }else if (OpName == "ArgMax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "ArgMin") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "Asin") { ImportNodeOneOut(node, 1, 1, { @@ -38,25 +38,22 @@ }); }else if (OpName == "AveragePool") { ImportNodeOneOut(node, 1, 1, { - {"auto_pad","str","NOTSET"} - ,{"ceil_mode","int","0"} - ,{"count_include_pad","int","0"} - ,{"kernel_shape","ints", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"ceil_mode", 0} + ,{"count_include_pad", 0} + ,{"kernel_shape", std::vector {}} }); }else if (OpName == "BatchNormalization") { ImportNodeMultipleOuts(node, 5, 5, { - {"epsilon","float","1e-05"} - ,{"momentum","float","0.9"} + {"epsilon", (float)1e-05} + ,{"momentum", (float)0.9} }); }else if (OpName == "BitShift") { ImportNodeOneOut(node, 2, 1, { - {"direction","", ""} }); }else if (OpName == "Cast") { ImportNodeOneOut(node, 1, 1, { - {"to","int", "0"} + {"to", 0} }); }else if (OpName == "Ceil") { ImportNodeOneOut(node, 1, 1, { @@ -66,54 +63,35 @@ }); }else if (OpName == "Compress") { ImportNodeOneOut(node, 2, 1, { - {"axis","", ""} }); }else if (OpName == "Concat") { ImportNodeOneOut(node, 1, 1, { - {"axis","int", "0"} + {"axis", 0} }); }else if (OpName == "ConcatFromSequence") { ImportNodeOneOut(node, 1, 1, { - {"axis","", ""} - ,{"new_axis","int","0"} + {"new_axis", 0} }); }else if (OpName == "Constant") { ImportNodeOneOut(node, 0, 1, { - {"sparse_value","", ""} - ,{"value","", ""} }); }else if (OpName == "ConstantOfShape") { ImportNodeOneOut(node, 1, 1, { - {"value","", ""} }); }else if (OpName == "Conv") { - ImportNodeConv(node, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int", "1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + ImportNodeConv(node, 3, 1, { + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "ConvInteger") { ImportNodeOneOut(node, 4, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "ConvTranspose") { ImportNodeOneOut(node, 3, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"output_padding","", ""} - ,{"output_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "Cos") { ImportNodeOneOut(node, 1, 1, { @@ -123,13 +101,12 @@ }); }else if (OpName == "CumSum") { ImportNodeOneOut(node, 2, 1, { - {"exclusive","int","0"} - ,{"reverse","int","0"} + {"exclusive", 0} + ,{"reverse", 0} }); }else if (OpName == "DepthToSpace") { ImportNodeOneOut(node, 1, 1, { - {"blocksize","", ""} - ,{"mode","str","DCR"} + {"mode", "DCR"} }); }else if (OpName == "DequantizeLinear") { ImportNodeOneOut(node, 3, 1, { @@ -142,14 +119,14 @@ }); }else if (OpName == "Dropout") { ImportNodeMultipleOuts(node, 1, 2, { - {"ratio","float","0.5"} + {"ratio", (float)0.5} }); }else if (OpName == "DynamicQuantizeLinear") { ImportNodeMultipleOuts(node, 1, 3, { }); }else if (OpName == "Elu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.0"} + {"alpha", (float)1.0} }); }else if (OpName == "Equal") { ImportNodeOneOut(node, 2, 1, { @@ -165,50 +142,44 @@ }); }else if (OpName == "EyeLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"k","int","0"} + {"k", 0} }); }else if (OpName == "Flatten") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Floor") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "GRU") { ImportNodeMultipleOuts(node, 6, 2, { - {"activation_alpha","", ""} - ,{"activation_beta","", ""} - ,{"activations","", ""} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} - ,{"linear_before_reset","int","0"} + {"direction", "forward"} + ,{"linear_before_reset", 0} }); }else if (OpName == "Gather") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "GatherElements") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "GatherND") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "Gemm") { ImportNodeOneOut(node, 3, 1, { - {"alpha","float","1.0"} - ,{"beta","float","1.0"} - ,{"transA","int","0"} - ,{"transB","int","0"} + {"alpha", (float)1.0} + ,{"beta", (float)1.0} + ,{"transA", 0} + ,{"transB", 0} }); }else if (OpName == "GlobalAveragePool") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "GlobalLpPool") { ImportNodeOneOut(node, 1, 1, { - {"p","int","2"} + {"p", 2} }); }else if (OpName == "GlobalMaxPool") { ImportNodeOneOut(node, 1, 1, { @@ -218,53 +189,45 @@ }); }else if (OpName == "HardSigmoid") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.2"} - ,{"beta","float","0.5"} + {"alpha", (float)0.2} + ,{"beta", (float)0.5} }); }else if (OpName == "Hardmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Identity") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "If") { ImportNodeOneOut(node, 1, 1, { - {"else_branch","", ""} - ,{"then_branch","", ""} }); }else if (OpName == "InstanceNormalization") { ImportNodeOneOut(node, 3, 1, { - {"epsilon","float","1e-05"} + {"epsilon", (float)1e-05} }); }else if (OpName == "IsInf") { ImportNodeOneOut(node, 1, 1, { - {"detect_negative","int","1"} - ,{"detect_positive","int","1"} + {"detect_negative", 1} + ,{"detect_positive", 1} }); }else if (OpName == "IsNaN") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "LRN") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.0001"} - ,{"beta","float","0.75"} - ,{"bias","float","1.0"} - ,{"size","int", ""} + {"alpha", (float)0.0001} + ,{"beta", (float)0.75} + ,{"bias", (float)1.0} }); }else if (OpName == "LSTM") { ImportNodeMultipleOuts(node, 8, 3, { - {"activation_alpha","", ""} - ,{"activation_beta","", ""} - ,{"activations","", ""} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} - ,{"input_forget","int","0"} + {"direction", "forward"} + ,{"input_forget", 0} }); }else if (OpName == "LeakyRelu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","0.01"} + {"alpha", (float)0.01} }); }else if (OpName == "Less") { ImportNodeOneOut(node, 2, 1, { @@ -274,24 +237,20 @@ }); }else if (OpName == "LogSoftmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Loop") { ImportNodeOneOut(node, 3, 1, { - {"body","", ""} }); }else if (OpName == "LpNormalization") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","-1"} - ,{"p","int","2"} + {"axis", -1} + ,{"p", 2} }); }else if (OpName == "LpPool") { ImportNodeOneOut(node, 1, 1, { - {"auto_pad","str","NOTSET"} - ,{"kernel_shape","", ""} - ,{"p","int","2"} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"p", 2} }); }else if (OpName == "MatMul") { ImportNodeOneOut(node, 2, 1, { @@ -303,55 +262,47 @@ ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MaxPool") { - ImportNodeMaxPool(node, 1, { - {"auto_pad","str","NOTSET"} - ,{"ceil_mode","int","0"} - ,{"dilations","", ""} - ,{"kernel_shape","ints", ""} - ,{"pads","", ""} - ,{"storage_order","int","0"} - ,{"strides","", ""} + ImportNodeMaxPool(node, 1, 2, { + {"auto_pad", "NOTSET"} + ,{"ceil_mode", 0} + ,{"kernel_shape", std::vector {}} + ,{"storage_order", 0} }); }else if (OpName == "MaxRoiPool") { ImportNodeOneOut(node, 2, 1, { - {"pooled_shape","", ""} - ,{"spatial_scale","float","1.0"} + {"spatial_scale", (float)1.0} }); }else if (OpName == "MaxUnpool") { ImportNodeOneOut(node, 3, 1, { - {"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} }); }else if (OpName == "Mean") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "MeanVarianceNormalization") { ImportNodeOneOut(node, 1, 1, { - {"axes","ints","{'0', '2', '3'}"} + {"axes", std::vector{0, 2, 3}} }); }else if (OpName == "Min") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Mod") { ImportNodeOneOut(node, 2, 1, { - {"fmod","int","0"} + {"fmod", 0} }); }else if (OpName == "Mul") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "Multinomial") { ImportNodeOneOut(node, 1, 1, { - {"dtype","int","6"} - ,{"sample_size","int","1"} - ,{"seed","", ""} + {"dtype", 6} + ,{"sample_size", 1} }); }else if (OpName == "Neg") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "NonMaxSuppression") { ImportNodeOneOut(node, 5, 1, { - {"center_point_box","int","0"} + {"center_point_box", 0} }); }else if (OpName == "NonZero") { ImportNodeOneOut(node, 1, 1, { @@ -361,7 +312,7 @@ }); }else if (OpName == "OneHot") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","-1"} + {"axis", -1} }); }else if (OpName == "Or") { ImportNodeOneOut(node, 2, 1, { @@ -371,19 +322,15 @@ }); }else if (OpName == "Pad") { ImportNodeOneOut(node, 3, 1, { - {"mode","str","constant"} + {"mode", "constant"} }); }else if (OpName == "Pow") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "QLinearConv") { ImportNodeOneOut(node, 9, 1, { - {"auto_pad","str","NOTSET"} - ,{"dilations","", ""} - ,{"group","int","1"} - ,{"kernel_shape","", ""} - ,{"pads","", ""} - ,{"strides","", ""} + {"auto_pad", "NOTSET"} + ,{"group", 1} }); }else if (OpName == "QLinearMatMul") { ImportNodeOneOut(node, 8, 1, { @@ -393,42 +340,32 @@ }); }else if (OpName == "RNN") { ImportNodeMultipleOuts(node, 6, 2, { - {"activation_alpha","floats", "{}"} - ,{"activation_beta","floats", "{}"} - ,{"activations","", "{Tannh, Tanh}"} - ,{"clip","", ""} - ,{"direction","str","forward"} - ,{"hidden_size","", ""} + {"activation_alpha", std::vector {}} + ,{"activation_beta", std::vector {}} + ,{"activations", std::vector{"Tanh", "Tanh"}} + ,{"direction", "forward"} }); }else if (OpName == "RandomNormal") { ImportNodeOneOut(node, 0, 1, { - {"dtype","int","1"} - ,{"mean","float","0.0"} - ,{"scale","float","1.0"} - ,{"seed","", ""} - ,{"shape","", ""} + {"dtype", 1} + ,{"mean", (float)0.0} + ,{"scale", (float)1.0} }); }else if (OpName == "RandomNormalLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"mean","float","0.0"} - ,{"scale","float","1.0"} - ,{"seed","", ""} + {"mean", (float)0.0} + ,{"scale", (float)1.0} }); }else if (OpName == "RandomUniform") { ImportNodeOneOut(node, 0, 1, { - {"dtype","int","1"} - ,{"high","float","1.0"} - ,{"low","float","0.0"} - ,{"seed","", ""} - ,{"shape","", ""} + {"dtype", 1} + ,{"high", (float)1.0} + ,{"low", (float)0.0} }); }else if (OpName == "RandomUniformLike") { ImportNodeOneOut(node, 1, 1, { - {"dtype","", ""} - ,{"high","float","1.0"} - ,{"low","float","0.0"} - ,{"seed","", ""} + {"high", (float)1.0} + ,{"low", (float)0.0} }); }else if (OpName == "Range") { ImportNodeOneOut(node, 3, 1, { @@ -438,53 +375,43 @@ }); }else if (OpName == "ReduceL1") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceL2") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceLogSum") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceLogSumExp") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMax") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMean") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceMin") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceProd") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceSum") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "ReduceSumSquare") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} - ,{"keepdims","int","1"} + {"keepdims", 1} }); }else if (OpName == "Relu") { ImportNodeOneOut(node, 1, 1, { @@ -494,53 +421,47 @@ }); }else if (OpName == "Resize") { ImportNodeOneOut(node, 4, 1, { - {"coordinate_transformation_mode","str","half_pixel"} - ,{"cubic_coeff_a","float","-0.75"} - ,{"exclude_outside","int","0"} - ,{"extrapolation_value","float","0.0"} - ,{"mode","str","nearest"} - ,{"nearest_mode","str","round_prefer_floor"} + {"coordinate_transformation_mode", "half_pixel"} + ,{"cubic_coeff_a", (float)-0.75} + ,{"exclude_outside", 0} + ,{"extrapolation_value", (float)0.0} + ,{"mode", "nearest"} + ,{"nearest_mode", "round_prefer_floor"} }); }else if (OpName == "ReverseSequence") { ImportNodeOneOut(node, 2, 1, { - {"batch_axis","int","1"} - ,{"time_axis","int","0"} + {"batch_axis", 1} + ,{"time_axis", 0} }); }else if (OpName == "RoiAlign") { ImportNodeOneOut(node, 3, 1, { - {"mode","str","avg"} - ,{"output_height","int","1"} - ,{"output_width","int","1"} - ,{"sampling_ratio","int","0"} - ,{"spatial_scale","float","1.0"} + {"mode", "avg"} + ,{"output_height", 1} + ,{"output_width", 1} + ,{"sampling_ratio", 0} + ,{"spatial_scale", (float)1.0} }); }else if (OpName == "Round") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Scan") { ImportNodeOneOut(node, 1, 1, { - {"body","", ""} - ,{"num_scan_inputs","", ""} - ,{"scan_input_axes","", ""} - ,{"scan_input_directions","", ""} - ,{"scan_output_axes","", ""} - ,{"scan_output_directions","", ""} }); }else if (OpName == "Scatter") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "ScatterElements") { ImportNodeOneOut(node, 3, 1, { - {"axis","int","0"} + {"axis", 0} }); }else if (OpName == "ScatterND") { ImportNodeOneOut(node, 3, 1, { }); }else if (OpName == "Selu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.67326"} - ,{"gamma","float","1.0507"} + {"alpha", (float)1.67326} + ,{"gamma", (float)1.0507} }); }else if (OpName == "SequenceAt") { ImportNodeOneOut(node, 2, 1, { @@ -550,7 +471,6 @@ }); }else if (OpName == "SequenceEmpty") { ImportNodeOneOut(node, 0, 1, { - {"dtype","", ""} }); }else if (OpName == "SequenceErase") { ImportNodeOneOut(node, 2, 1, { @@ -566,8 +486,8 @@ }); }else if (OpName == "Shrink") { ImportNodeOneOut(node, 1, 1, { - {"bias","float","0.0"} - ,{"lambd","float","0.5"} + {"bias", (float)0.0} + ,{"lambd", (float)0.5} }); }else if (OpName == "Sigmoid") { ImportNodeOneOut(node, 1, 1, { @@ -589,7 +509,7 @@ }); }else if (OpName == "Softmax") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","1"} + {"axis", 1} }); }else if (OpName == "Softplus") { ImportNodeOneOut(node, 1, 1, { @@ -599,31 +519,26 @@ }); }else if (OpName == "SpaceToDepth") { ImportNodeOneOut(node, 1, 1, { - {"blocksize","", ""} }); }else if (OpName == "Split") { ImportNodeOneOut(node, 1, 1, { - {"axis","int","0"} - ,{"split","", ""} + {"axis", 0} }); }else if (OpName == "SplitToSequence") { ImportNodeOneOut(node, 2, 1, { - {"axis","int","0"} - ,{"keepdims","int","1"} + {"axis", 0} + ,{"keepdims", 1} }); }else if (OpName == "Sqrt") { ImportNodeOneOut(node, 1, 1, { }); }else if (OpName == "Squeeze") { ImportNodeOneOut(node, 1, 1, { - {"axes","", ""} }); }else if (OpName == "StringNormalizer") { ImportNodeOneOut(node, 1, 1, { - {"case_change_action","str","NONE"} - ,{"is_case_sensitive","int","0"} - ,{"locale","", ""} - ,{"stopwords","", ""} + {"case_change_action", "NONE"} + ,{"is_case_sensitive", 0} }); }else if (OpName == "Sub") { ImportNodeOneOut(node, 2, 1, { @@ -639,45 +554,34 @@ }); }else if (OpName == "TfIdfVectorizer") { ImportNodeOneOut(node, 1, 1, { - {"max_gram_length","", ""} - ,{"max_skip_count","", ""} - ,{"min_gram_length","", ""} - ,{"mode","", ""} - ,{"ngram_counts","", ""} - ,{"ngram_indexes","", ""} - ,{"pool_int64s","", ""} - ,{"pool_strings","", ""} - ,{"weights","", ""} }); }else if (OpName == "ThresholdedRelu") { ImportNodeOneOut(node, 1, 1, { - {"alpha","float","1.0"} + {"alpha", (float)1.0} }); }else if (OpName == "Tile") { ImportNodeOneOut(node, 2, 1, { }); }else if (OpName == "TopK") { ImportNodeMultipleOuts(node, 2, 2, { - {"axis","int","-1"} - ,{"largest","int","1"} - ,{"sorted","int","1"} + {"axis", -1} + ,{"largest", 1} + ,{"sorted", 1} }); }else if (OpName == "Transpose") { ImportNodeOneOut(node, 1, 1, { - {"perm","", ""} }); }else if (OpName == "Unique") { ImportNodeMultipleOuts(node, 1, 4, { - {"axis","", ""} - ,{"sorted","int","1"} + {"sorted", 1} }); }else if (OpName == "Unsqueeze") { ImportNodeOneOut(node, 1, 1, { - {"axes","ints", ""} + {"axes", std::vector {}} }); }else if (OpName == "Upsample") { ImportNodeOneOut(node, 2, 1, { - {"mode","str","nearest"} + {"mode", "nearest"} }); }else if (OpName == "Where") { ImportNodeOneOut(node, 3, 1, { @@ -685,4 +589,4 @@ }else if (OpName == "Xor") { ImportNodeOneOut(node, 2, 1, { }); - } \ No newline at end of file + } diff --git a/src/dialect/onnx/gen_doc.py b/src/dialect/onnx/gen_doc.py index 4141556..ed99e57 100644 --- a/src/dialect/onnx/gen_doc.py +++ b/src/dialect/onnx/gen_doc.py @@ -368,17 +368,17 @@ def gen_code(schema,fefile) : ("MaxPool", "ImportNodeMaxPool"), #("Transpose", "ImportNodeTranspose") ]) - special_type = dict([ - ("AveragePool "+"kernel_shape", '"ints", ""'), - ("MaxPool "+"kernel_shape", '"ints", ""'), - ("Cast "+"to", '"int", "0"'), - ("Concat "+"axis", '"int", "0"'), - ("Conv "+"group", '"int", "1"'), - ("Unsqueeze "+"axes", '"ints", ""'), - ("RNN "+"activation_alpha", '"floats", "{}"'), - ("RNN "+"activation_beta", '"floats", "{}"'), - ("RNN "+"activations", '"", "{Tannh, Tanh}"'), - ("LRN "+"size", '"int", ""') + list_str = 'std::vector' + empty_ints = list_str+' {}' + empty_floats = list_str+' {}' + special_default = dict([ + ("AveragePool "+"kernel_shape", empty_ints), + ("MaxPool "+"kernel_shape", empty_ints), + ("Cast "+"to", '0'), + ("Concat "+"axis", '0'), + ("Unsqueeze "+"axes", empty_ints), + ("RNN "+"activation_alpha", empty_floats), + ("RNN "+"activation_beta", empty_floats) ]) line_indent = ' ' fefile.write(' '+'}else if (OpName == "'+schema.name+'") {\n') @@ -400,21 +400,9 @@ def gen_code(schema,fefile) : if schema.attributes: first_attr = True for _, attr in sorted(schema.attributes.items()): - attr_line = line_indent+line_indent+line_indent+line_indent - if not first_attr: - attr_line += ',{' - else : - attr_line += ' {' - first_attr = False - - attr_line += '"'+attr.name+'",' - - if schema.name+' '+attr.name in special_type: - attr_line += special_type[schema.name+' '+attr.name] - # option holds either required or default value - elif attr.required: - attr_line += '"", ""' - + #only generate default attr list + if schema.name+' '+attr.name in special_default: + attr_value = special_default[schema.name+' '+attr.name] elif attr.default_value.name: default_value = helper.get_attribute_value(attr.default_value) @@ -430,28 +418,35 @@ def gen_code(schema,fefile) : return str(value) if isinstance(default_value, list): + value = default_value[0] default_value = [format_value(val) for val in default_value] + attr_option_str = '{}'.format(default_value) + attr_option_str = attr_option_str.replace('[', '{', 1) + attr_option_str = attr_option_str.replace(']', '}', 1) # TODO the list type is homogenous or htergeneous? if isinstance(value, float) : - attr_type_str = '"floats"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '') elif isinstance(value, int) : - attr_type_str = '"ints"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '') elif isinstance(value, str) : - attr_type_str = '"strs"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '"') elif isinstance(value, (bytes, bytearray)) : - attr_type_str = '"strs"' + attr_type_str = list_str+'' + attr_option_str = attr_option_str.replace("'", '"') else : attr_type_str = '"unknowns"' - attr_option_str = '"{}"'.format(default_value) - attr_option_str = attr_option_str.replace('[', '{', 1) - attr_option_str = attr_option_str.replace(']', '}', 1) else: if isinstance(default_value, float) : - attr_type_str = '"float"' + attr_type_str = '(float)' + attr_option_str = default_value elif isinstance(default_value, int) : - attr_type_str = '"int"' + attr_option_str = default_value + attr_type_str='' elif isinstance(default_value, str) : attr_type_str = '"str"' elif isinstance(default_value, (bytes, bytearray)) : @@ -459,11 +454,25 @@ def gen_code(schema,fefile) : else : attr_type_str = '"unknown"' default_value = format_value(default_value) - attr_option_str = '"{}"'.format(default_value) - attr_line += attr_type_str+','+attr_option_str + if attr_type_str == '"str"' : + attr_option_str = '"'+default_value+'"' + attr_type_str='' + else : + attr_option_str = default_value + attr_value = attr_type_str+attr_option_str else: - #TODO why? - attr_line += '"", ""' + #no default value + continue + + attr_line = line_indent+line_indent+line_indent+line_indent + if not first_attr: + attr_line += ',{' + else : + attr_line += ' {' + first_attr = False + + attr_line += '"'+attr.name+'", ' + attr_line += attr_value attr_line += '}\n' fefile.write(attr_line) fefile.write(line_indent+line_indent+line_indent+'});\n') diff --git a/src/pass/lower_frontend_to_krnl.cpp b/src/pass/lower_frontend_to_krnl.cpp index 3d899ee..58d603a 100644 --- a/src/pass/lower_frontend_to_krnl.cpp +++ b/src/pass/lower_frontend_to_krnl.cpp @@ -420,8 +420,8 @@ Value mapToLowerScalarOp( // Constant 1) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("HardSigmoid.alpha"); - auto betaAttr = op->getAttrOfType("HardSigmoid.beta"); + auto alphaAttr = op->getAttrOfType("alpha"); + auto betaAttr = op->getAttrOfType("beta"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); @@ -455,7 +455,7 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("Elu.alpha"); + auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto one = rewriter.create(loc, FloatAttr::get(elementType, 1)); auto alpha = rewriter.create(loc, alphaAttr); @@ -508,7 +508,7 @@ Value mapToLowerScalarOp(Operation *op, Value operand = operands[0]; auto elementType = result_types[0]; - auto alphaAttr = op->getAttrOfType("LeakyRelu.alpha"); + auto alphaAttr = op->getAttrOfType("alpha"); auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); auto alpha = rewriter.create(loc, alphaAttr); auto lessThanZero = @@ -533,8 +533,8 @@ Value mapToLowerScalarOp(Operation *op, ArrayRef result_types, // alpha))) auto loc = op->getLoc(); Value operand = operands[0]; - auto alphaAttr = op->getAttrOfType("Selu.alpha"); - auto gammaAttr = op->getAttrOfType("Selu.gamma"); + auto alphaAttr = op->getAttrOfType("alpha"); + auto gammaAttr = op->getAttrOfType("gamma"); auto elementType = result_types[0]; auto zero = rewriter.create(loc, FloatAttr::get(elementType, 0)); @@ -836,7 +836,7 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern { // exp_x / sum auto tensorType = (*op->result_type_begin()).cast(); int64_t rank = tensorType.getRank(); - int64_t axis = op->getAttrOfType("Softmax.axis").getInt(); + int64_t axis = op->getAttrOfType("axis").getInt(); axis = axis >= 0 ? axis : rank + axis; assert(axis >= -rank && axis <= rank - 1); diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 3ffce9a..c6c5927 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -385,7 +385,7 @@ func @test_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<* } func @test_elu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_elu @@ -411,7 +411,7 @@ func @test_elu(%arg0 : tensor) -> tensor<*xf32> { } func @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_leakyrelu @@ -434,7 +434,7 @@ func @test_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { } func @test_selu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_selu @@ -461,7 +461,7 @@ func @test_selu(%arg0 : tensor) -> tensor<*xf32> { } func @test_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> + %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_hardsigmoid @@ -535,7 +535,7 @@ func @test_add_with_broadcasting(%arg0 : tensor, %arg1 : tensor } func @test_softmax(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> { - %0 = "onnx.Softmax"(%arg0) {Softmax.axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> + %0 = "onnx.Softmax"(%arg0) {axis=1:i32} : (tensor<10x10xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_softmax diff --git a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir index 385dc3c..1286041 100644 --- a/test/mlir/onnx/onnx_lowering_with_dealloc.mlir +++ b/test/mlir/onnx/onnx_lowering_with_dealloc.mlir @@ -648,8 +648,8 @@ func @test_min_min(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tens } func @test_elu_elu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.Elu"(%arg0) {alpha=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Elu"(%0) {alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_elu_elu @@ -701,8 +701,8 @@ func @test_elu_elu(%arg0 : tensor) -> tensor<*xf32> { } func @test_leakyrelu_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.LeakyRelu"(%arg0) {alpha=1.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.LeakyRelu"(%0) {alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_leakyrelu_leakyrelu @@ -748,8 +748,8 @@ func @test_leakyrelu_leakyrelu(%arg0 : tensor) -> tensor<*xf32> { } func @test_selu_selu(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.Selu"(%arg0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.Selu"(%0) {alpha=1.0:f32, gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_selu_selu @@ -803,8 +803,8 @@ func @test_selu_selu(%arg0 : tensor) -> tensor<*xf32> { } func @test_hardsigmoid_hardsigmoid(%arg0 : tensor) -> tensor<*xf32> { - %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor) -> tensor<*xf32> - %1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> + %0 = "onnx.HardSigmoid"(%arg0) {alpha=1.0:f32, beta=2.0:f32} : (tensor) -> tensor<*xf32> + %1 = "onnx.HardSigmoid"(%0) {alpha=1.0:f32, beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32> "std.return"(%1) : (tensor<*xf32>) -> () // CHECK-LABEL: test_hardsigmoid_hardsigmoid diff --git a/third_party/variant b/third_party/variant new file mode 160000 index 0000000..3c7fc82 --- /dev/null +++ b/third_party/variant @@ -0,0 +1 @@ +Subproject commit 3c7fc8266bb46046b42c2dc2663f9f505f0cec28