From b02652dd767d2e3a7a191b6fd957857619a2531f Mon Sep 17 00:00:00 2001
From: GHEORGHE-TEOD BERCEA
Date: Tue, 26 Nov 2019 13:55:44 -0500
Subject: [PATCH] [MLIR] Lowering of frontend dialect to KRNL dialect (#382)

* Partial support for lowering operations to KRNL dialect.

* Attempt to lower to KRNL IR.

* Update file.

* Add lowering.

* Address comments. Fix alloc dynamic dimensions. Correctly link StandardOps.

* Temporarily remove deallocation of locally allocated tensors.
---
 CMakeLists.txt                               |   2 +
 MLIR.cmake                                   | 169 ++---
 src/CMakeLists.txt                           |   8 +-
 src/builder/CMakeLists.txt                   |   3 +-
 src/compiler/CMakeLists.txt                  |   9 +-
 src/compiler/dialect/krnl/krnl_ops.cpp       |  10 +-
 src/compiler/dialect/onnx/onnxop.inc         | 616 +++++++++----------
 src/compiler/pass/lower_frontend_to_krnl.cpp | 282 +++++++++
 src/compiler/pass/passes.hpp                 |   3 +-
 src/compiler/pass/shape_inference_pass.cpp   |   6 +
 src/compiler/tool/onnf_opt/CMakeLists.txt    |  10 +-
 src/main.cpp                                 |   1 +
 test/CMakeLists.txt                          |   3 -
 13 files changed, 709 insertions(+), 413 deletions(-)
 create mode 100644 src/compiler/pass/lower_frontend_to_krnl.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c92e97a..38f64ae 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,3 +46,5 @@
 add_subdirectory(src/builder)
 add_subdirectory(src/compiler)
 add_subdirectory(src)
+add_subdirectory(test)
+
diff --git a/MLIR.cmake b/MLIR.cmake
index 74e3372..a01cbb0 100644
--- a/MLIR.cmake
+++ b/MLIR.cmake
@@ -44,89 +44,95 @@ set(
 )
 include_directories(${MLIR_INCLUDE_PATHS})
 
-find_library(MLIR_LIB_ANALYSIS
-  NAMES MLIRAnalysis
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_IR NAMES MLIRIR PATHS ${LLVM_PROJECT_LIB} NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_PARSER
-  NAMES MLIRParser
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_PASS
-  NAMES MLIRPass
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_TRANSFORMS
-  NAMES MLIRTransforms
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_VECTOR_OPS
-  NAMES MLIRVectorOps
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_SUPPORT
-  NAMES MLIRSupport
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_STANDARD_OPS
-  NAMES MLIRStandardOps
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_OPT_MAIN
-  NAMES MLIROptMain
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LLVM_IR
-  NAMES MLIRLLVMIR
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(MLIR_LIB_TRANSFORM_UTILS
-  NAMES MLIRTransformUtils
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
-find_library(LLVM_LIB_SUPPORT
-  NAMES LLVMSupport
-  PATHS ${LLVM_PROJECT_LIB}
-  NO_DEFAULT_PATH)
-
 # Threading libraries required due to parallel pass execution.
 find_package(Threads REQUIRED)
 
-set(MLIRLIBS
-  ${MLIR_LIB_ANALYSIS}
-  ${MLIR_LIB_IR}
-  ${MLIR_LIB_PARSER}
-  ${MLIR_LIB_PASS}
-  ${MLIR_LIB_TRANSFORMS}
-  ${MLIR_LIB_VECTOR_OPS}
-  ${MLIR_LIB_STANDARD_OPS}
-  ${MLIR_LIB_OPT_MAIN}
-  ${MLIR_LIB_SUPPORT}
-  ${MLIR_LIB_TRANSFORM_UTILS}
-  ${MLIR_LIB_ANALYSIS}
-  ${MLIR_LIB_IR}
-  ${MLIR_LIB_PARSER}
-  ${MLIR_LIB_PASS}
-  ${MLIR_LIB_TRANSFORMS}
-  ${MLIR_LIB_VECTOR_OPS}
-  ${MLIR_LIB_STANDARD_OPS}
-  ${MLIR_LIB_OPT_MAIN}
-  ${MLIR_LIB_SUPPORT}
-  ${MLIR_LIB_TRANSFORM_UTILS}
-  ${LLVM_LIB_SUPPORT}
-  Threads::Threads)
+function(find_mlir_lib lib)
+  find_library(${lib}
+    NAMES ${lib}
+    PATHS ${LLVM_PROJECT_LIB}
+    NO_DEFAULT_PATH)
+endfunction(find_mlir_lib)
+
+find_mlir_lib(MLIRAffineOps)
+find_mlir_lib(MLIRAffineToStandard)
+find_mlir_lib(MLIRAnalysis)
+find_mlir_lib(MLIRExecutionEngine)
+find_mlir_lib(MLIRIR)
+find_mlir_lib(MLIRLLVMIR)
+find_mlir_lib(MLIRLoopToStandard)
+find_mlir_lib(MLIRParser)
+find_mlir_lib(MLIRPass)
+find_mlir_lib(MLIRStandardOps)
+find_mlir_lib(MLIRStandardToLLVM)
+find_mlir_lib(MLIRTargetLLVMIR)
+find_mlir_lib(MLIRTransforms)
+find_mlir_lib(MLIRTransforms)
+find_mlir_lib(MLIRTransformUtils)
+find_mlir_lib(MLIRSupport)
+find_mlir_lib(MLIROptMain)
+
+find_mlir_lib(LLVMCore)
+find_mlir_lib(LLVMSupport)
+find_mlir_lib(LLVMAsmParser)
+find_mlir_lib(LLVMBinaryFormat)
+find_mlir_lib(LLVMRemarks)
+find_mlir_lib(LLVMIRReader)
+find_mlir_lib(LLVMTransformUtils)
+find_mlir_lib(LLVMBitstreamReader)
+
+set(MLIRLibsOnce
+  MLIRAffineOps
+  MLIRAffineToStandard
+  MLIRAnalysis
+  MLIRExecutionEngine
+  MLIRIR
+  MLIRLLVMIR
+  MLIRLoopToStandard
+  MLIRParser
+  MLIRPass
+  MLIRStandardOps
+  MLIRStandardToLLVM
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRAffineOps
+  MLIRAffineToStandard
+  MLIRAnalysis
+  MLIRExecutionEngine
+  MLIRIR
+  MLIRLLVMIR
+  MLIRLoopToStandard
+  MLIRParser
+  MLIRPass
+  MLIRStandardOps
+  MLIRStandardToLLVM
+  MLIRTargetLLVMIR
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRLoopOps
+  MLIRSupport
+  MLIROptMain
+  LLVMCore
+  LLVMSupport
+  LLVMAsmParser
+  LLVMIRReader
+  LLVMTransformUtils
+  LLVMBinaryFormat
+  LLVMRemarks
+  LLVMBitstreamReader)
+
+set(MLIRLibs
+  ${MLIRLibsOnce}
+  ${MLIRLibsOnce}
+  Threads::Threads)
+
+set(MLIRWholeArchiveLibs
+  MLIRAffineToStandard
+  MLIRAffineOps
+  MLIRLLVMIR
+  MLIRStandardOps
+  MLIRStandardToLLVM
+  MLIRLoopToStandard)
 
 function(whole_archive_link target lib_dir)
   get_property(link_flags TARGET ${target} PROPERTY LINK_FLAGS)
@@ -155,6 +161,9 @@ function(whole_archive_link_mlir target)
 endfunction(whole_archive_link_mlir)
 
 function(whole_archive_link_onnf target)
+  foreach(LIB ${ARGN})
+    add_dependencies(${target} ${LIB})
+  endforeach(LIB)
   whole_archive_link(${target} ${CMAKE_BINARY_DIR}/lib ${ARGN})
 endfunction(whole_archive_link_onnf)
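# Note on the whole-archive helpers above: MLIR dialects and passes register
# themselves through static constructors in otherwise-unreferenced object
# files, which a plain static link is free to drop. For the libraries listed
# in MLIRWholeArchiveLibs, whole_archive_link forces every object file in the
# archive into the final binary. On a GNU toolchain the added link flags look
# roughly like the following (library names and paths illustrative):
#
#   -Wl,--whole-archive lib/libMLIRAffineToStandard.a ... -Wl,--no-whole-archive
#
# with an equivalent -force_load spelling on Darwin.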
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 5f24f72..c36eb0b 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,7 +1,9 @@
-
 add_executable(onnf main.cpp)
+
+target_link_libraries(onnf builder compiler ${MLIRLibs} ${Boost_LIBRARIES})
+whole_archive_link_mlir(onnf ${MLIRWholeArchiveLibs})
+
 target_include_directories(onnf PRIVATE ${CMAKE_SOURCE_DIR})
 target_include_directories(onnf PRIVATE ${CMAKE_BINARY_DIR})
-target_link_libraries(onnf builder compiler ${Boost_LIBRARIES})
 
-install(TARGETS onnf DESTINATION bin)
+install(TARGETS onnf DESTINATION bin)
\ No newline at end of file
diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt
index 3d8e603..1cc773b 100644
--- a/src/builder/CMakeLists.txt
+++ b/src/builder/CMakeLists.txt
@@ -5,7 +5,8 @@ add_library(builder
 target_include_directories(builder PRIVATE ${CMAKE_SOURCE_DIR})
 target_include_directories(builder PRIVATE ${CMAKE_BINARY_DIR})
-target_link_libraries(builder compiler onnx ${MLIRLIBS} curses)
+
+target_link_libraries(builder compiler onnx ${MLIRLibs} curses)
 
 target_include_directories(builder PRIVATE
   ${CMAKE_SOURCE_DIR}/third_party/onnx
diff --git a/src/compiler/CMakeLists.txt b/src/compiler/CMakeLists.txt
index dbc9f3b..91dc4e0 100644
--- a/src/compiler/CMakeLists.txt
+++ b/src/compiler/CMakeLists.txt
@@ -10,9 +10,10 @@ add_library(
   dialect/krnl/parser_helper.hpp
   pass/shape_inference_pass.cpp
   pass/shape_inference_interface.hpp
-  pass/passes.hpp
   dialect/onnx/onnxop.inc
-  pass/onnx_combine.cpp)
+  pass/onnx_combine.cpp
+  pass/lower_frontend_to_krnl.cpp
+  pass/passes.hpp)
 
 # Include root src directory.
 target_include_directories(compiler PRIVATE ${ONNF_SRC_ROOT})
@@ -41,7 +42,7 @@ target_link_libraries(compiler
   ${Boost_LIBRARIES}
   ${CMAKE_THREAD_LIBS_INIT}
   ${CMAKE_DL_LIBS}
-  ${MLIRLIBS}
+  ${MLIRLibs}
   curses)
 
 add_subdirectory(tool)
@@ -68,4 +69,4 @@ onnf_tablegen(krnl.hpp.inc -gen-op-decls)
 onnf_tablegen(krnl.cpp.inc -gen-op-defs)
 add_public_tablegen_target(gen_krnl_ops)
 add_dependencies(compiler gen_krnl_ops)
-add_dependencies(onnf-opt gen_krnl_ops)
\ No newline at end of file
+add_dependencies(onnf-opt gen_krnl_ops)
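// Note on the krnl_ops.cpp change below: krnl.iterate stores each loop bound
// either inline as an integer attribute or as an SSA operand, and a parallel
// boundTypes array attribute records which form was used. The second hunk
// swaps the printer's two branches so that a bound-type entry of 0 now prints
// the stored IntegerAttr and any other entry consumes the next bound operand.
// The printed form then reads, roughly (a sketch, not verbatim output):
//
//   krnl.iterate(%opt) with (%loop -> %i = 0 to %ub) { ... }
//
// where the constant lower bound 0 was recovered from an attribute and the
// upper bound %ub from an operand.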
diff --git a/src/compiler/dialect/krnl/krnl_ops.cpp b/src/compiler/dialect/krnl/krnl_ops.cpp
index 3ba483d..5dbd16d 100644
--- a/src/compiler/dialect/krnl/krnl_ops.cpp
+++ b/src/compiler/dialect/krnl/krnl_ops.cpp
@@ -176,7 +176,7 @@ void KrnlIterateOp::build(Builder* builder, OperationState& result,
   result.addAttribute(KrnlIterateOp::getBoundTypesAttrName(),
       builder->getI32ArrayAttr(bound_types));
 
-  // Create a region and a block for the body. The arguments of the region is
+  // Create a region and a block for the body. The arguments of the region are
   // the loop induction variables; there can be multiple induction variables
   // associated with the same krnl.iterate operation.
   Region* bodyRegion = result.addRegion();
@@ -207,16 +207,16 @@ void print(OpAsmPrinter& p, KrnlIterateOp& op) {
   auto print_bound = [&](ArrayRef<Attribute> bound_types, size_t idx) {
     IntegerAttr type = bound_types[idx].dyn_cast<IntegerAttr>();
     if (type.getValue().getSExtValue() == 0) {
-      // Bound is an operand.
-      p.printOperand(*next_operand_bound);
-      next_operand_bound = std::next(next_operand_bound);
-    } else {
       // Bound is an integer attribute.
       auto bound_idx = idx / 2;
      auto is_ub = idx % 2;
       IntegerAttr bound = op.getAttrOfType<IntegerAttr>(
           KrnlIterateOp::getBoundAttrName(bound_idx, is_ub));
       p << bound.getValue().getSExtValue();
+    } else {
+      // Bound is an operand.
+      p.printOperand(*next_operand_bound);
+      next_operand_bound = std::next(next_operand_bound);
     }
   };
 
diff --git a/src/compiler/dialect/onnx/onnxop.inc b/src/compiler/dialect/onnx/onnxop.inc
index 6cfafdb..404c0d5 100644
--- a/src/compiler/dialect/onnx/onnxop.inc
+++ b/src/compiler/dialect/onnx/onnxop.inc
@@ -6,8 +6,8 @@ def ONNXAbsOp:ONNX_Op<"Abs",
     "(Tensor<T>) where the absolute is, y = abs(x), is applied to"
     "the tensor elementwise."
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAcosOp:ONNX_Op<"Acos",
     [NoSideEffect]> {
   let summary = "ONNX Acos operation";
   let description = [{
     "Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
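// The Abs and Acos changes above establish the pattern repeated for every
// operation in this file: operand and result types widen from AnyTensor to
// AnyTypeOf<[AnyMemRef, AnyTensor]>, so an ONNX op remains verifiable while
// the KRNL lowering pass rewrites tensor values into memrefs underneath it.
// In IR terms, the same op may appear in either state during a partial
// lowering (shapes and types below are illustrative only):
//
//   %y = "onnx.Abs"(%x) : (tensor<10x10xf32>) -> tensor<10x10xf32>
//   %y = "onnx.Abs"(%x) : (memref<10x10xf32>) -> memref<10x10xf32>
//
// Variadic operations such as Concat, Max, Mean, and Min wrap the element
// type one level deeper, as Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>.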
 def ONNXAcoshOp:ONNX_Op<"Acosh",
     [NoSideEffect]> {
   let summary = "ONNX Acosh operation";
   let description = [{
     "Calculates the hyperbolic arccosine of the given input tensor element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAddOp:ONNX_Op<"Add",
     [NoSideEffect]> {
   let summary = "ONNX Add operation";
   let description = [{
     "Performs element-wise binary addition (with Numpy-style broadcasting support)."
     ""
     "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)."
   }];
-  let arguments = (ins AnyTensor:$A, AnyTensor:$B);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAndOp:ONNX_Op<"And",
     [NoSideEffect]> {
   let summary = "ONNX And operation";
   let description = [{
     "Returns the tensor resulted from performing the `and` logical operation"
     "elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support)."
     ""
     "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)."
   }];
-  let arguments = (ins AnyTensor:$A, AnyTensor:$B);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXArgMaxOp:ONNX_Op<"ArgMax",
     [NoSideEffect]> {
   let summary = "ONNX ArgMax operation";
   let description = [{
     "Computes the indices of the max elements of the input tensor's element along the "
     "provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
     "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
     "The type of the output tensor is integer."
   }];
-  let arguments = (ins AnyTensor:$data);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXArgMinOp:ONNX_Op<"ArgMin",
     [NoSideEffect]> {
   let summary = "ONNX ArgMin operation";
   let description = [{
     "Computes the indices of the min elements of the input tensor's element along the "
     "provided axis. The resulted tensor has the same rank as the input if keepdims equal 1."
     "If keepdims equal 0, then the resulted tensor have the reduced dimension pruned. "
     "The type of the output tensor is integer."
   }];
-  let arguments = (ins AnyTensor:$data);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAsinOp:ONNX_Op<"Asin",
     [NoSideEffect]> {
   let summary = "ONNX Asin operation";
   let description = [{
     "Calculates the arcsine (inverse of sine) of the given input tensor, element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAsinhOp:ONNX_Op<"Asinh",
     [NoSideEffect]> {
   let summary = "ONNX Asinh operation";
   let description = [{
     "Calculates the hyperbolic arcsine of the given input tensor element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXAtanOp:ONNX_Op<"Atan",
     [NoSideEffect]> {
   let summary = "ONNX Atan operation";
   let description = [{
     "Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise."
}]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXAtanhOp:ONNX_Op<"Atanh", @@ -118,8 +118,8 @@ def ONNXAtanhOp:ONNX_Op<"Atanh", let description = [{ "Calculates the hyperbolic arctangent of the given input tensor element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXAveragePoolOp:ONNX_Op<"AveragePool", @@ -156,8 +156,8 @@ def ONNXAveragePoolOp:ONNX_Op<"AveragePool", " The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero)." " " }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", @@ -175,8 +175,8 @@ def ONNXBatchNormalizationOp:ONNX_Op<"BatchNormalization", "to flatten the input shape to (N x C*D1*D2 ..*Dn) before a BatchNormalization Op." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$scale, AnyTensor:$B, AnyTensor:$mean, AnyTensor:$var); - let results = (outs AnyTensor, AnyTensor, AnyTensor, AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$mean, AnyTypeOf<[AnyMemRef, AnyTensor]>:$var); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXBitShiftOp:ONNX_Op<"BitShift", @@ -196,8 +196,8 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", " not necessarily identical." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$Y); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXCastOp:ONNX_Op<"Cast", @@ -224,8 +224,8 @@ def ONNXCastOp:ONNX_Op<"Cast", "For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting" "an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXCeilOp:ONNX_Op<"Ceil", @@ -236,8 +236,8 @@ def ONNXCeilOp:ONNX_Op<"Ceil", "(Tensor) where the ceil is, y = ceil(x), is applied to" "the tensor elementwise." 
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXClipOp:ONNX_Op<"Clip",
     [NoSideEffect]> {
   let summary = "ONNX Clip operation";
   let description = [{
     "Clip operator limits the given input within an interval. The interval is"
     "specified by the inputs 'min' and 'max'. They default to"
     "numeric_limits::lowest() and numeric_limits::max(), respectively."
   }];
-  let arguments = (ins AnyTensor:$input, AnyTensor:$min, AnyTensor:$max);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$min, AnyTypeOf<[AnyMemRef, AnyTensor]>:$max);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXCompressOp:ONNX_Op<"Compress",
     [NoSideEffect]> {
   let summary = "ONNX Compress operation";
   let description = [{
     "Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index."
     " In case axis is not provided, input is flattened before elements are selected."
     " Compress behaves like numpy.compress: https://docs.scipy.org/doc/numpy/reference/generated/numpy.compress.html"
     " "
   }];
-  let arguments = (ins AnyTensor:$input, AnyTensor:$condition);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$condition);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConcatOp:ONNX_Op<"Concat",
     [NoSideEffect]> {
   let summary = "ONNX Concat operation";
   let description = [{
     "Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
   }];
-  let arguments = (ins Variadic<AnyTensor>:$inputs);
-  let results = (outs AnyTensor);
+  let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
     [NoSideEffect]> {
   let summary = "ONNX ConcatFromSequence operation";
   let description = [{
     "Concatenate a sequence of tensors into a single tensor."
     "All input tensors must have the same shape, except for the dimension size of the axis to concatenate on."
     "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
     "When 'new_axis' is 1, the behavior is similar to numpy.stack."
   }];
-  let arguments = (ins AnyTensor:$input_sequence);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConstantOp:ONNX_Op<"Constant",
     [NoSideEffect]> {
   let summary = "ONNX Constant operation";
   let description = [{
     "A constant tensor. Exactly one of the two attributes, either value or sparse_value,"
     "must be specified."
   }];
   let arguments = (ins );
-  let results = (outs AnyTensor);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
     [NoSideEffect]> {
   let summary = "ONNX ConstantOfShape operation";
   let description = [{
     "Generate a tensor with given value and shape."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConvOp:ONNX_Op<"Conv",
     [NoSideEffect]> {
   let summary = "ONNX Conv operation";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
     "computes the output."
   }];
-  let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
     [NoSideEffect]> {
   let summary = "ONNX ConvInteger operation";
   let description = [{
     "The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point,"
     "and computes the output. 
The production MUST never overflow. The accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTensor:$x, AnyTensor:$w, AnyTensor:$x_zero_point, AnyTensor:$w_zero_point); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", @@ -350,8 +350,8 @@ def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", "" " " }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXCosOp:ONNX_Op<"Cos", @@ -360,8 +360,8 @@ def ONNXCosOp:ONNX_Op<"Cos", let description = [{ "Calculates the cosine of the given input tensor, element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXCoshOp:ONNX_Op<"Cosh", @@ -370,8 +370,8 @@ def ONNXCoshOp:ONNX_Op<"Cosh", let description = [{ "Calculates the hyperbolic cosine of the given input tensor element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXCumSumOp:ONNX_Op<"CumSum", @@ -399,8 +399,8 @@ def ONNXCumSumOp:ONNX_Op<"CumSum", "```" " " }]; - let arguments = (ins AnyTensor:$x, AnyTensor:$axis); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$axis); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", @@ -435,8 +435,8 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", "y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize])" "" }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", @@ -448,8 +448,8 @@ def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", "'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape. In the case of dequantizing int32," "there's no zero point (zero point is supposed to be 0)." }]; - let arguments = (ins AnyTensor:$x, AnyTensor:$x_scale, AnyTensor:$x_zero_point); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDetOp:ONNX_Op<"Det", @@ -462,8 +462,8 @@ def ONNXDetOp:ONNX_Op<"Det", "The output is a tensor of shape `[*]`, containing the determinants of all input submatrices." "e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`)." 
}]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDivOp:ONNX_Op<"Div", @@ -474,8 +474,8 @@ def ONNXDivOp:ONNX_Op<"Div", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDropoutOp:ONNX_Op<"Dropout", @@ -489,8 +489,8 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", "the training phase, so during testing nothing needs to be done." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", @@ -520,8 +520,8 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", "* rounding to nearest ties to even." "```" }]; - let arguments = (ins AnyTensor:$x); - let results = (outs AnyTensor, AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXEluOp:ONNX_Op<"Elu", @@ -533,8 +533,8 @@ def ONNXEluOp:ONNX_Op<"Elu", "0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise." "" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXEqualOp:ONNX_Op<"Equal", @@ -546,8 +546,8 @@ def ONNXEqualOp:ONNX_Op<"Equal", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXErfOp:ONNX_Op<"Erf", @@ -556,8 +556,8 @@ def ONNXErfOp:ONNX_Op<"Erf", let description = [{ "Computes the error function of the given input tensor element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXExpOp:ONNX_Op<"Exp", @@ -566,8 +566,8 @@ def ONNXExpOp:ONNX_Op<"Exp", let description = [{ "Calculates the exponential of the given input tensor, element-wise." 
}]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXExpandOp:ONNX_Op<"Expand", @@ -583,8 +583,8 @@ def ONNXExpandOp:ONNX_Op<"Expand", "It is possible that the output.shape is not equal to shape, when some dimensions in shape is equal to 1," "or the shape.ndim < input.shape.ndim." }]; - let arguments = (ins AnyTensor:$input, AnyTensor:$shape); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXEyeLikeOp:ONNX_Op<"EyeLike", @@ -599,8 +599,8 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", "The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the" "TensorProto message and be valid as an output type." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXFlattenOp:ONNX_Op<"Flatten", @@ -611,8 +611,8 @@ def ONNXFlattenOp:ONNX_Op<"Flatten", "(d_0, d_1, ... d_n) then the output will have shape" "(d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn)." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXFloorOp:ONNX_Op<"Floor", @@ -623,8 +623,8 @@ def ONNXFloorOp:ONNX_Op<"Floor", "(Tensor) where the floor is, y = floor(x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGRUOp:ONNX_Op<"GRU", @@ -705,8 +705,8 @@ def ONNXGRUOp:ONNX_Op<"GRU", " - Ht = (1 - zt) (.) ht + zt (.) Ht-1" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." 
}]; - let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$R, AnyTensor:$B, AnyTensor:$sequence_lens, AnyTensor:$initial_h); - let results = (outs AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens, AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGatherOp:ONNX_Op<"Gather", @@ -771,8 +771,8 @@ def ONNXGatherOp:ONNX_Op<"Gather", " ]" "```" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGatherElementsOp:ONNX_Op<"GatherElements", @@ -835,8 +835,8 @@ def ONNXGatherElementsOp:ONNX_Op<"GatherElements", " ]" "```" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGatherNDOp:ONNX_Op<"GatherND", @@ -909,8 +909,8 @@ def ONNXGatherNDOp:ONNX_Op<"GatherND", " output = [[[2,3]],[[4,5]]] # output_shape = [2, 1, 2] " "" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGemmOp:ONNX_Op<"Gemm", @@ -931,8 +931,8 @@ def ONNXGemmOp:ONNX_Op<"Gemm", "This operator supports **unidirectional broadcasting** (tensor C should be unidirectional broadcastable to tensor A * B); for more details please check [the doc](Broadcasting.md)." "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B, AnyTensor:$C); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$C); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", @@ -943,8 +943,8 @@ def ONNXGlobalAveragePoolOp:ONNX_Op<"GlobalAveragePool", " the values in the same channel. This is equivalent to AveragePool with kernel size" " equal to the spatial dimension of input tensor." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", @@ -955,8 +955,8 @@ def ONNXGlobalLpPoolOp:ONNX_Op<"GlobalLpPool", " the values in the same channel. This is equivalent to LpPool with kernel size" " equal to the spatial dimension of input tensor." 
}]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", @@ -967,8 +967,8 @@ def ONNXGlobalMaxPoolOp:ONNX_Op<"GlobalMaxPool", " the values in the same channel. This is equivalent to MaxPool with kernel size" " equal to the spatial dimension of input tensor." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXGreaterOp:ONNX_Op<"Greater", @@ -980,8 +980,8 @@ def ONNXGreaterOp:ONNX_Op<"Greater", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", @@ -992,8 +992,8 @@ def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", "(Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta))," "is applied to the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXHardmaxOp:ONNX_Op<"Hardmax", @@ -1015,8 +1015,8 @@ def ONNXHardmaxOp:ONNX_Op<"Hardmax", "will throw errors. The output tensor has the same shape" "and contains the hardmax values of the corresponding input." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXIdentityOp:ONNX_Op<"Identity", @@ -1026,8 +1026,8 @@ def ONNXIdentityOp:ONNX_Op<"Identity", let description = [{ "Identity operator" }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXIfOp:ONNX_Op<"If", @@ -1036,8 +1036,8 @@ def ONNXIfOp:ONNX_Op<"If", let description = [{ "If conditional" }]; - let arguments = (ins AnyTensor:$cond); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$cond); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", @@ -1051,8 +1051,8 @@ def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization", "where mean and variance are computed per instance per channel." "" }]; - let arguments = (ins AnyTensor:$input, AnyTensor:$scale, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXIsInfOp:ONNX_Op<"IsInf", @@ -1061,8 +1061,8 @@ def ONNXIsInfOp:ONNX_Op<"IsInf", let description = [{ "Map infinity to true and other values to false." 
}]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXIsNaNOp:ONNX_Op<"IsNaN", @@ -1071,8 +1071,8 @@ def ONNXIsNaNOp:ONNX_Op<"IsNaN", let description = [{ "Returns which elements of the input are NaN." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLRNOp:ONNX_Op<"LRN", @@ -1090,8 +1090,8 @@ def ONNXLRNOp:ONNX_Op<"LRN", "" "Y[n, c, d1, ..., dk] = X[n, c, d1, ..., dk] / (bias + alpha / size * square_sum[n, c, d1, ..., dk] ) ^ beta" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLSTMOp:ONNX_Op<"LSTM", @@ -1180,8 +1180,8 @@ def ONNXLSTMOp:ONNX_Op<"LSTM", " - Ht = ot (.) h(Ct)" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$R, AnyTensor:$B, AnyTensor:$sequence_lens, AnyTensor:$initial_h, AnyTensor:$initial_c, AnyTensor:$P); - let results = (outs AnyTensor, AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens, AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h, AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_c, AnyTypeOf<[AnyMemRef, AnyTensor]>:$P); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", @@ -1192,8 +1192,8 @@ def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu", "output data (Tensor) where the function `f(x) = alpha * x for x < 0`," "`f(x) = x for x >= 0`, is applied to the data tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLessOp:ONNX_Op<"Less", @@ -1205,8 +1205,8 @@ def ONNXLessOp:ONNX_Op<"Less", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLogOp:ONNX_Op<"Log", @@ -1215,8 +1215,8 @@ def ONNXLogOp:ONNX_Op<"Log", let description = [{ "Calculates the natural log of the given input tensor, element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", @@ -1238,8 +1238,8 @@ def ONNXLogSoftmaxOp:ONNX_Op<"LogSoftmax", "will throw errors. 
The output tensor has the same shape" "and contains the logsoftmax values of the corresponding input." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLoopOp:ONNX_Op<"Loop", @@ -1380,8 +1380,8 @@ def ONNXLoopOp:ONNX_Op<"Loop", "the scan_outputs from the previous layer, possibly going through several" "point-wise operators (e.g. dropout, residual connections, linear layer)." }]; - let arguments = (ins AnyTensor:$M, AnyTensor:$cond, AnyTensor:$v_initial); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$M, AnyTypeOf<[AnyMemRef, AnyTensor]>:$cond, AnyTypeOf<[AnyMemRef, AnyTensor]>:$v_initial); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", @@ -1390,8 +1390,8 @@ def ONNXLpNormalizationOp:ONNX_Op<"LpNormalization", let description = [{ "Given a matrix, apply Lp-normalization along the provided axis." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXLpPoolOp:ONNX_Op<"LpPool", @@ -1404,8 +1404,8 @@ def ONNXLpPoolOp:ONNX_Op<"LpPool", " of the input tensor according to the kernel size and downsampling the" " data into the output tensor Y for further processing." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMatMulOp:ONNX_Op<"MatMul", @@ -1414,8 +1414,8 @@ def ONNXMatMulOp:ONNX_Op<"MatMul", let description = [{ "Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html" }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", @@ -1425,8 +1425,8 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger", "Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html." "The production MUST never overflow. The accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B, AnyTensor:$a_zero_point, AnyTensor:$b_zero_point); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_zero_point); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMaxOp:ONNX_Op<"Max", @@ -1437,8 +1437,8 @@ def ONNXMaxOp:ONNX_Op<"Max", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." 
}]; - let arguments = (ins Variadic:$data_0); - let results = (outs AnyTensor); + let arguments = (ins Variadic>:$data_0); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMaxPoolOp:ONNX_Op<"MaxPool", @@ -1475,8 +1475,8 @@ def ONNXMaxPoolOp:ONNX_Op<"MaxPool", " The output of each pooling window is maximum number of elements exclude pad." " " }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", @@ -1487,8 +1487,8 @@ def ONNXMaxRoiPoolOp:ONNX_Op<"MaxRoiPool", " apply max pooling across each RoI, to produce output 4-D tensor of shape" " (num_rois, channels, pooled_shape[0], pooled_shape[1])." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$rois); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rois); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", @@ -1514,8 +1514,8 @@ def ONNXMaxUnpoolOp:ONNX_Op<"MaxUnpool", " which define the exact unpooling op. The attributes typically have the same values as the corrsponding" " pooling op that the unpooling op is trying to invert." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$I, AnyTensor:$output_shape); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$I, AnyTypeOf<[AnyMemRef, AnyTensor]>:$output_shape); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMeanOp:ONNX_Op<"Mean", @@ -1526,8 +1526,8 @@ def ONNXMeanOp:ONNX_Op<"Mean", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic:$data_0); - let results = (outs AnyTensor); + let arguments = (ins Variadic>:$data_0); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", @@ -1537,8 +1537,8 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization", "A MeanVarianceNormalization Function: Perform mean variance normalization" " on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ```" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMinOp:ONNX_Op<"Min", @@ -1549,8 +1549,8 @@ def ONNXMinOp:ONNX_Op<"Min", "All inputs and outputs must have the same data type." "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins Variadic:$data_0); - let results = (outs AnyTensor); + let arguments = (ins Variadic>:$data_0); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXModOp:ONNX_Op<"Mod", @@ -1571,8 +1571,8 @@ def ONNXModOp:ONNX_Op<"Mod", "" " This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMulOp:ONNX_Op<"Mul", @@ -1583,8 +1583,8 @@ def ONNXMulOp:ONNX_Op<"Mul", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXMultinomialOp:ONNX_Op<"Multinomial", @@ -1594,8 +1594,8 @@ def ONNXMultinomialOp:ONNX_Op<"Multinomial", "Generate a tensor of samples from a multinomial distribution according to the probabilities" "of each of the possible outcomes." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXNegOp:ONNX_Op<"Neg", @@ -1606,8 +1606,8 @@ def ONNXNegOp:ONNX_Op<"Neg", "(Tensor) where each element flipped sign, y = -x, is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", @@ -1622,8 +1622,8 @@ def ONNXNonMaxSuppressionOp:ONNX_Op<"NonMaxSuppression", "The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes." "The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation." 
}]; - let arguments = (ins AnyTensor:$boxes, AnyTensor:$scores, AnyTensor:$max_output_boxes_per_class, AnyTensor:$iou_threshold, AnyTensor:$score_threshold); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$boxes, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scores, AnyTypeOf<[AnyMemRef, AnyTensor]>:$max_output_boxes_per_class, AnyTypeOf<[AnyMemRef, AnyTensor]>:$iou_threshold, AnyTypeOf<[AnyMemRef, AnyTensor]>:$score_threshold); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXNonZeroOp:ONNX_Op<"NonZero", @@ -1635,8 +1635,8 @@ def ONNXNonZeroOp:ONNX_Op<"NonZero", " NonZero behaves similar to numpy.nonzero:" " https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXNotOp:ONNX_Op<"Not", @@ -1645,8 +1645,8 @@ def ONNXNotOp:ONNX_Op<"Not", let description = [{ "Returns the negation of the input tensor element-wise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXOneHotOp:ONNX_Op<"OneHot", @@ -1673,8 +1673,8 @@ def ONNXOneHotOp:ONNX_Op<"OneHot", " output[i, j, k, input[i, j, k]] = 1 for all i, j, k and 0 otherwise." "" }]; - let arguments = (ins AnyTensor:$indices, AnyTensor:$depth, AnyTensor:$values); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$depth, AnyTypeOf<[AnyMemRef, AnyTensor]>:$values); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXOrOp:ONNX_Op<"Or", @@ -1686,8 +1686,8 @@ def ONNXOrOp:ONNX_Op<"Or", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXPReluOp:ONNX_Op<"PRelu", @@ -1699,8 +1699,8 @@ def ONNXPReluOp:ONNX_Op<"PRelu", "`f(x) = x for x >= 0`., is applied to the data tensor elementwise." "This operator supports **unidirectional broadcasting** (tensor slope should be unidirectional broadcastable to input tensor X); for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$slope); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$slope); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXPadOp:ONNX_Op<"Pad", @@ -1789,8 +1789,8 @@ def ONNXPadOp:ONNX_Op<"Pad", " ]" "" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$pads, AnyTensor:$constant_value); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$pads, AnyTypeOf<[AnyMemRef, AnyTensor]>:$constant_value); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXPowOp:ONNX_Op<"Pow", @@ -1802,8 +1802,8 @@ def ONNXPowOp:ONNX_Op<"Pow", "is applied to the data tensor elementwise." 
"This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$Y); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", @@ -1816,8 +1816,8 @@ def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", "It means they must be either scalars (per tensor) or 1-D tensors (per output channel)." "Each input or output and its related zero point must have same type." }]; - let arguments = (ins AnyTensor:$x, AnyTensor:$x_scale, AnyTensor:$x_zero_point, AnyTensor:$w, AnyTensor:$w_scale, AnyTensor:$w_zero_point, AnyTensor:$y_scale, AnyTensor:$y_zero_point, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$x_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$w_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", @@ -1833,8 +1833,8 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", "and the number of elements of scale and zero point tensor of input 'b' should be equal to the number of columns of input 'b'." "Production must never overflow, and accumulation may overflow if and only if in 32 bits." }]; - let arguments = (ins AnyTensor:$a, AnyTensor:$a_scale, AnyTensor:$a_zero_point, AnyTensor:$b, AnyTensor:$b_scale, AnyTensor:$b_zero_point, AnyTensor:$y_scale, AnyTensor:$y_zero_point); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$a, AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$a_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$b, AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$b_zero_point, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", @@ -1845,8 +1845,8 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", "The quantization formula is y = saturate ((x / y_scale) + y_zero_point). For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8." "For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type." }]; - let arguments = (ins AnyTensor:$x, AnyTensor:$y_scale, AnyTensor:$y_zero_point); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$x, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_scale, AnyTypeOf<[AnyMemRef, AnyTensor]>:$y_zero_point); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRNNOp:ONNX_Op<"RNN", @@ -1915,8 +1915,8 @@ def ONNXRNNOp:ONNX_Op<"RNN", " - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)" "This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. 
Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$R, AnyTensor:$B, AnyTensor:$sequence_lens, AnyTensor:$initial_h); - let results = (outs AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$R, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B, AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens, AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_h); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", @@ -1932,7 +1932,7 @@ def ONNXRandomNormalOp:ONNX_Op<"RandomNormal", "TensorProto message." }]; let arguments = (ins ); - let results = (outs AnyTensor); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", @@ -1947,8 +1947,8 @@ def ONNXRandomNormalLikeOp:ONNX_Op<"RandomNormalLike", "The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the" "TensorProto message, and be valid as an output type." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", @@ -1963,7 +1963,7 @@ def ONNXRandomUniformOp:ONNX_Op<"RandomUniform", "TensorProto message." }]; let arguments = (ins ); - let results = (outs AnyTensor); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", @@ -1978,8 +1978,8 @@ def ONNXRandomUniformLikeOp:ONNX_Op<"RandomUniformLike", "The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the" "TensorProto message and be valid as an output type." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRangeOp:ONNX_Op<"Range", @@ -2012,8 +2012,8 @@ def ONNXRangeOp:ONNX_Op<"Range", "Output: [10, 8, 6]" "" }]; - let arguments = (ins AnyTensor:$start, AnyTensor:$limit, AnyTensor:$delta); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$start, AnyTypeOf<[AnyMemRef, AnyTensor]>:$limit, AnyTypeOf<[AnyMemRef, AnyTensor]>:$delta); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReciprocalOp:ONNX_Op<"Reciprocal", @@ -2024,8 +2024,8 @@ def ONNXReciprocalOp:ONNX_Op<"Reciprocal", "(Tensor) where the reciprocal is, y = 1/x, is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceL1Op:ONNX_Op<"ReduceL1", @@ -2039,8 +2039,8 @@ def ONNXReduceL1Op:ONNX_Op<"ReduceL1", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." 
}]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceL2Op:ONNX_Op<"ReduceL2", @@ -2054,8 +2054,8 @@ def ONNXReduceL2Op:ONNX_Op<"ReduceL2", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", @@ -2069,8 +2069,8 @@ def ONNXReduceLogSumOp:ONNX_Op<"ReduceLogSum", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", @@ -2084,8 +2084,8 @@ def ONNXReduceLogSumExpOp:ONNX_Op<"ReduceLogSumExp", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", @@ -2099,8 +2099,8 @@ def ONNXReduceMaxOp:ONNX_Op<"ReduceMax", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", @@ -2114,8 +2114,8 @@ def ONNXReduceMeanOp:ONNX_Op<"ReduceMean", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceMinOp:ONNX_Op<"ReduceMin", @@ -2129,8 +2129,8 @@ def ONNXReduceMinOp:ONNX_Op<"ReduceMin", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceProdOp:ONNX_Op<"ReduceProd", @@ -2144,8 +2144,8 @@ def ONNXReduceProdOp:ONNX_Op<"ReduceProd", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceSumOp:ONNX_Op<"ReduceSum", @@ -2159,8 +2159,8 @@ def ONNXReduceSumOp:ONNX_Op<"ReduceSum", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." 
}]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", @@ -2174,8 +2174,8 @@ def ONNXReduceSumSquareOp:ONNX_Op<"ReduceSumSquare", "The above behavior is similar to numpy, with the exception that numpy default keepdims to" "False instead of True." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReluOp:ONNX_Op<"Relu", @@ -2186,8 +2186,8 @@ def ONNXReluOp:ONNX_Op<"Relu", "(Tensor) where the rectified linear function, y = max(0, x), is applied to" "the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReshapeOp:ONNX_Op<"Reshape", @@ -2201,8 +2201,8 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", "could also be 0, in which case the actual dimension value is unchanged (i.e. taken" "from the input tensor)." }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$shape); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXResizeOp:ONNX_Op<"Resize", @@ -2213,8 +2213,8 @@ def ONNXResizeOp:ONNX_Op<"Resize", "Each dimension value of the output tensor is:" " output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$roi, AnyTensor:$scales, AnyTensor:$sizes); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$roi, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales, AnyTypeOf<[AnyMemRef, AnyTensor]>:$sizes); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", @@ -2255,8 +2255,8 @@ def ONNXReverseSequenceOp:ONNX_Op<"ReverseSequence", " [10.0, 9.0, 8.0, 11.0]," " [15.0, 14.0, 13.0, 12.0]]" }]; - let arguments = (ins AnyTensor:$input, AnyTensor:$sequence_lens); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$sequence_lens); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", @@ -2275,8 +2275,8 @@ def ONNXRoiAlignOp:ONNX_Op<"RoiAlign", "the value of the sampled locations are computed directly" "through bilinear interpolation." 
}]; - let arguments = (ins AnyTensor:$X, AnyTensor:$rois, AnyTensor:$batch_indices); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rois, AnyTypeOf<[AnyMemRef, AnyTensor]>:$batch_indices); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXRoundOp:ONNX_Op<"Round", @@ -2297,8 +2297,8 @@ def ONNXRoundOp:ONNX_Op<"Round", "round([-4.5]) = [-4.0]" "```" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXScanOp:ONNX_Op<"Scan", @@ -2427,8 +2427,8 @@ def ONNXScanOp:ONNX_Op<"Scan", " }" "" }]; - let arguments = (ins AnyTensor:$initial_state_and_scan_inputs); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$initial_state_and_scan_inputs); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXScatterOp:ONNX_Op<"Scatter", @@ -2489,8 +2489,8 @@ def ONNXScatterOp:ONNX_Op<"Scatter", " output = [[1.0, 1.1, 3.0, 2.1, 5.0]]" "```" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices, AnyTensor:$updates); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", @@ -2549,8 +2549,8 @@ def ONNXScatterElementsOp:ONNX_Op<"ScatterElements", " output = [[1.0, 1.1, 3.0, 2.1, 5.0]]" "```" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices, AnyTensor:$updates); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXScatterNDOp:ONNX_Op<"ScatterND", @@ -2614,8 +2614,8 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND", " [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]]" "```" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$indices, AnyTensor:$updates); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$indices, AnyTypeOf<[AnyMemRef, AnyTensor]>:$updates); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSeluOp:ONNX_Op<"Selu", @@ -2627,8 +2627,8 @@ def ONNXSeluOp:ONNX_Op<"Selu", "`y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`," "is applied to the tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", @@ -2639,8 +2639,8 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." "Negative value means counting positions from the back." 
   }];
-  let arguments = (ins AnyTensor:$input_sequence, AnyTensor:$position);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
@@ -2650,8 +2650,8 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
     "Construct a tensor sequence containing 'inputs' tensors."
     "All tensors in 'inputs' must have the same data type."
   }];
-  let arguments = (ins Variadic<AnyTensor>:$inputs);
-  let results = (outs AnyTensor);
+  let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$inputs);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
@@ -2661,7 +2661,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
     "Construct an empty tensor sequence, with given data type."
   }];
   let arguments = (ins );
-  let results = (outs AnyTensor);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
@@ -2673,8 +2673,8 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
     "Negative value means counting positions from the back."
     "'position' is optional, by default it erases the last tensor from 'input_sequence'."
   }];
-  let arguments = (ins AnyTensor:$input_sequence, AnyTensor:$position);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
@@ -2687,8 +2687,8 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
     "Negative value means counting positions from the back."
     "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
   }];
-  let arguments = (ins AnyTensor:$input_sequence, AnyTensor:$tensor, AnyTensor:$position);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence, AnyTypeOf<[AnyMemRef, AnyTensor]>:$tensor, AnyTypeOf<[AnyMemRef, AnyTensor]>:$position);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
@@ -2697,8 +2697,8 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
   let description = [{
     "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
   }];
-  let arguments = (ins AnyTensor:$input_sequence);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input_sequence);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXShapeOp:ONNX_Op<"Shape",
@@ -2707,8 +2707,8 @@ def ONNXShapeOp:ONNX_Op<"Shape",
   let description = [{
     "Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor."
   }];
-  let arguments = (ins AnyTensor:$data);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXShrinkOp:ONNX_Op<"Shrink",
@@ -2720,8 +2720,8 @@ def ONNXShrinkOp:ONNX_Op<"Shrink",
     "bias. The formula of this operator is: If x < -lambd, y = x + bias;"
     "If x > lambd, y = x - bias; Otherwise, y = 0."
}]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSigmoidOp:ONNX_Op<"Sigmoid", @@ -2732,8 +2732,8 @@ def ONNXSigmoidOp:ONNX_Op<"Sigmoid", "(Tensor) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the" "tensor elementwise." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSignOp:ONNX_Op<"Sign", @@ -2743,8 +2743,8 @@ def ONNXSignOp:ONNX_Op<"Sign", "Calculate the sign of the given input tensor element-wise." "If input > 0, output 1. if input < 0, output -1. if input == 0, output 0." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSinOp:ONNX_Op<"Sin", @@ -2753,8 +2753,8 @@ def ONNXSinOp:ONNX_Op<"Sin", let description = [{ "Calculates the sine of the given input tensor, element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSinhOp:ONNX_Op<"Sinh", @@ -2763,8 +2763,8 @@ def ONNXSinhOp:ONNX_Op<"Sinh", let description = [{ "Calculates the hyperbolic sine of the given input tensor element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSizeOp:ONNX_Op<"Size", @@ -2773,8 +2773,8 @@ def ONNXSizeOp:ONNX_Op<"Size", let description = [{ "Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSliceOp:ONNX_Op<"Slice", @@ -2818,8 +2818,8 @@ def ONNXSliceOp:ONNX_Op<"Slice", " [2, 3, 4]," " ]" }]; - let arguments = (ins AnyTensor:$data, AnyTensor:$starts, AnyTensor:$ends, AnyTensor:$axes, AnyTensor:$steps); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, AnyTypeOf<[AnyMemRef, AnyTensor]>:$starts, AnyTypeOf<[AnyMemRef, AnyTensor]>:$ends, AnyTypeOf<[AnyMemRef, AnyTensor]>:$axes, AnyTypeOf<[AnyMemRef, AnyTensor]>:$steps); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSoftmaxOp:ONNX_Op<"Softmax", @@ -2841,8 +2841,8 @@ def ONNXSoftmaxOp:ONNX_Op<"Softmax", "will throw errors. The output tensor has the same shape" "and contains the softmax values of the corresponding input." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSoftplusOp:ONNX_Op<"Softplus", @@ -2853,8 +2853,8 @@ def ONNXSoftplusOp:ONNX_Op<"Softplus", "(Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to" "the tensor elementwise." 
}]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSoftsignOp:ONNX_Op<"Softsign", @@ -2863,8 +2863,8 @@ def ONNXSoftsignOp:ONNX_Op<"Softsign", let description = [{ "Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", @@ -2875,8 +2875,8 @@ def ONNXSpaceToDepthOp:ONNX_Op<"SpaceToDepth", "this op outputs a copy of the input tensor where values from the height and width dimensions" "are moved to the depth dimension." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSplitOp:ONNX_Op<"Split", @@ -2887,8 +2887,8 @@ def ONNXSplitOp:ONNX_Op<"Split", "'axis'. Lengths of the parts can be specified using argument 'split'." "Otherwise, the tensor is split to equal sized parts." }]; - let arguments = (ins AnyTensor:$input); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", @@ -2906,8 +2906,8 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", "specified in 'split'. In this scenario, the sum of entries in 'split' must be equal to the" "dimension size of input tensor on 'axis'." }]; - let arguments = (ins AnyTensor:$input, AnyTensor:$split); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$split); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSqrtOp:ONNX_Op<"Sqrt", @@ -2918,8 +2918,8 @@ def ONNXSqrtOp:ONNX_Op<"Sqrt", "(Tensor) where the square root is, y = x^0.5, is applied to" "the tensor elementwise. If x is negative, then it will return NaN." }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXSqueezeOp:ONNX_Op<"Squeeze", @@ -2931,8 +2931,8 @@ def ONNXSqueezeOp:ONNX_Op<"Squeeze", "If `axes` is not provided, all the single dimensions will be removed from" "the shape. If an axis is selected with shape entry not equal to one, an error is raised." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", @@ -2949,8 +2949,8 @@ def ONNXStringNormalizerOp:ONNX_Op<"StringNormalizer", "If all elements in X are dropped, the output will be the empty value of string tensor with shape [1]" "if input shape is [C] and shape [1, 1] if input shape is [1, C]." 
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSubOp:ONNX_Op<"Sub",
@@ -2961,8 +2961,8 @@ def ONNXSubOp:ONNX_Op<"Sub",
     "Performs element-wise binary subtraction (with Numpy-style broadcasting support)."
     ""
     "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)."
   }];
-  let arguments = (ins AnyTensor:$A, AnyTensor:$B);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXSumOp:ONNX_Op<"Sum",
@@ -2973,8 +2973,8 @@ def ONNXSumOp:ONNX_Op<"Sum",
     "All inputs and outputs must have the same data type."
     "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)."
   }];
-  let arguments = (ins Variadic<AnyTensor>:$data_0);
-  let results = (outs AnyTensor);
+  let arguments = (ins Variadic<AnyTypeOf<[AnyMemRef, AnyTensor]>>:$data_0);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXTanOp:ONNX_Op<"Tan",
@@ -2983,8 +2983,8 @@ def ONNXTanOp:ONNX_Op<"Tan",
   let description = [{
     "Calculates the tangent of the given input tensor, element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXTanhOp:ONNX_Op<"Tanh",
@@ -2993,8 +2993,8 @@ def ONNXTanhOp:ONNX_Op<"Tanh",
   let description = [{
     "Calculates the hyperbolic tangent of the given input tensor element-wise."
   }];
-  let arguments = (ins AnyTensor:$input);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer",
@@ -3029,8 +3029,8 @@ def ONNXTfIdfVectorizerOp:ONNX_Op<"TfIdfVectorizer",
     "Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor."
     "If pool_strings is set, the input must be a string tensor."
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",
@@ -3041,8 +3041,8 @@ def ONNXThresholdedReluOp:ONNX_Op<"ThresholdedRelu",
     "(Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise,"
     "is applied to the tensor elementwise."
   }];
-  let arguments = (ins AnyTensor:$X);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXTileOp:ONNX_Op<"Tile",
@@ -3053,8 +3053,8 @@ def ONNXTileOp:ONNX_Op<"Tile",
     "This is the same as function `tile` in Numpy, but no broadcast."
     "For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]]"
   }];
-  let arguments = (ins AnyTensor:$input, AnyTensor:$repeats);
-  let results = (outs AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$input, AnyTypeOf<[AnyMemRef, AnyTensor]>:$repeats);
+  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
 }
 
 def ONNXTopKOp:ONNX_Op<"TopK",
@@ -3076,8 +3076,8 @@ def ONNXTopKOp:ONNX_Op<"TopK",
     "Given two equivalent values, this operator uses the indices along the axis as"
     " a tiebreaker. 
That is, the element with the lower index will appear first." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$K); - let results = (outs AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$K); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXTransposeOp:ONNX_Op<"Transpose", @@ -3088,8 +3088,8 @@ def ONNXTransposeOp:ONNX_Op<"Transpose", "perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape" "will be (2, 1, 3)." }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXUniqueOp:ONNX_Op<"Unique", @@ -3172,8 +3172,8 @@ def ONNXUniqueOp:ONNX_Op<"Unique", "" " output_counts = [2 1 1]" }]; - let arguments = (ins AnyTensor:$X); - let results = (outs AnyTensor, AnyTensor, AnyTensor, AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>, AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", @@ -3193,8 +3193,8 @@ def ONNXUnsqueezeOp:ONNX_Op<"Unsqueeze", "The order of values in `axes` does not matter and can come in any order. " "" }]; - let arguments = (ins AnyTensor:$data); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXUpsampleOp:ONNX_Op<"Upsample", @@ -3205,8 +3205,8 @@ def ONNXUpsampleOp:ONNX_Op<"Upsample", "Each dimension value of the output tensor is:" " output_dimension = floor(input_dimension * scale)." }]; - let arguments = (ins AnyTensor:$X, AnyTensor:$scales); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$scales); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXWhereOp:ONNX_Op<"Where", @@ -3218,8 +3218,8 @@ def ONNXWhereOp:ONNX_Op<"Where", " Where behaves like numpy.where with three parameters:" " https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html" }]; - let arguments = (ins AnyTensor:$condition, AnyTensor:$X, AnyTensor:$Y); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$condition, AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$Y); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } def ONNXXorOp:ONNX_Op<"Xor", @@ -3231,7 +3231,7 @@ def ONNXXorOp:ONNX_Op<"Xor", "" "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md)." }]; - let arguments = (ins AnyTensor:$A, AnyTensor:$B); - let results = (outs AnyTensor); + let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B); + let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>); } diff --git a/src/compiler/pass/lower_frontend_to_krnl.cpp b/src/compiler/pass/lower_frontend_to_krnl.cpp new file mode 100644 index 0000000..e92eabb --- /dev/null +++ b/src/compiler/pass/lower_frontend_to_krnl.cpp @@ -0,0 +1,282 @@ +//====- lower_frontend_to_krnl.cpp - Frontend dialects to Krnl lowering ---===// +// +// Copyright 2019 The DLC Authors. 
+//
+// =============================================================================
+//
+// This file implements the lowering of frontend operations to a combination of
+// Krnl IR and standard operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Sequence.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/compiler/dialect/krnl/krnl_ops.hpp"
+#include "src/compiler/dialect/onnx/onnx_ops.hpp"
+
+#include "src/compiler/pass/passes.hpp"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// FrontendToKrnl RewritePatterns
+//===----------------------------------------------------------------------===//
+
+/// Check whether all dimensions are known at compile time.
+static bool hasAllConstantDimensions(MemRefType type) {
+  auto memRefShape = type.getShape();
+  for (int i = 0; i < memRefShape.size(); ++i)
+    if (memRefShape[i] < 0)
+      return false;
+  return true;
+}
+
+/// Convert the given TensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(TensorType type) {
+  assert(type.hasRank() && "expected only ranked shapes");
+  return MemRefType::get(type.getShape(), type.getElementType());
+}
+
+/// Insert an allocation for the given MemRefType (deallocation of locally
+/// allocated tensors is temporarily omitted).
+static Value* insertAllocAndDealloc(
+    MemRefType type, Location loc, PatternRewriter& rewriter,
+    Value* oldMemRef = nullptr) {
+  // Put together alloc operands for any dynamic dimensions of the memref.
+  AllocOp alloc;
+  if (oldMemRef) {
+    SmallVector<Value*, 4> allocOperands;
+    auto memRefShape = type.getShape();
+    for (int i = 0; i < memRefShape.size(); ++i)
+      if (memRefShape[i] < 0)
+        allocOperands.push_back(rewriter.create<DimOp>(loc, oldMemRef, i));
+
+    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
+  } else {
+    alloc = rewriter.create<AllocOp>(loc, type);
+  }
+
+  // Make sure to allocate at the beginning of the block if
+  // all dimensions are known.
+  auto* parentBlock = alloc.getOperation()->getBlock();
+  if (hasAllConstantDimensions(type))
+    alloc.getOperation()->moveBefore(&parentBlock->front());
+
+  return alloc;
+}
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// AddOp lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+struct ONNXAddOpLowering : public ConversionPattern {
+  ONNXAddOpLowering(MLIRContext* ctx)
+      : ConversionPattern(mlir::ONNXAddOp::getOperationName(), 1, ctx) {}
+
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    // TODO: Check that the types are valid.
+    // Add is an operation whose operands and result must all have the same
+    // type. This should have been verified by the verifier.
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+
+    // Insert an allocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+
+    // If the output has a dynamic dimension, pass the operands required for
+    // each dynamic dimension to the AllocOp. The first operand of the Add
+    // operation is used. The operands of the Add need to match in terms of
+    // dimensions with the result at this pre-optimization phase.
+    // TODO: verify that dimensions match.
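+    //
+    // As an illustration only: for a result of type memref<?x10xf32>, the
+    // dynamic branch below emits roughly
+    //   %d0 = dim %operand0, 0 : memref<?x10xf32>
+    //   %res = alloc(%d0) : memref<?x10xf32>
+    // i.e. the alloc receives one operand per dynamic dimension, taken from
+    // the matching dimension of the first operand (%operand0 stands in for
+    // operands[0] here).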
+    // TODO: can the dimension of the result differ after optimizations?
+    Value* alloc;
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
+    else
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, operands[0]);
+
+    // Number of loops (one per dimension of the result).
+    auto memRefShape = memRefType.getShape();
+    int64_t rank = memRefShape.size();
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
+    std::vector<Value*> originalLoops;
+    originalLoops.reserve(rank);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    // Define loop optimization.
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
+    std::vector<Value*> optimizedLoops;
+    optimizedLoops.reserve(rank);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block& optimizationBlock = optimizedLoopsOp.region().front();
+
+    // Iterate over the loop nest.
+    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
+    // to KrnlIterateOp instead.
+    SmallVector<Value*, 4> operandBounds;
+    SmallVector<int64_t, 4> constBounds;
+    SmallVector<int, 4> boundTypes;
+    for (int i = 0; i < rank; ++i) {
+      if (memRefShape[i] < 0) {
+        // This is a dynamic value, hence use operands.
+        // Lower bound
+        constBounds.push_back(0);
+        boundTypes.push_back(0);
+        // Upper bound
+        operandBounds.push_back(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+        boundTypes.push_back(1);
+      } else {
+        // Lower bound
+        constBounds.push_back(0);
+        boundTypes.push_back(0);
+        // Upper bound
+        constBounds.push_back(memRefShape[i]);
+        boundTypes.push_back(0);
+      }
+    }
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, originalLoops,
+        optimizedLoops, operandBounds, constBounds, boundTypes);
+    Block& iterationBlock = iterateOp.bodyRegion().front();
+
+    // Now insert code into the bodies of the just-generated operations:
+
+    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
+    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    // Return from KrnlOptimizeLoopsOp body.
+    // When no optimizations are present we just return the loops
+    // unchanged.
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.setInsertionPoint(optimizedLoopsOp);
+
+    // 2. Insert instructions inside the KrnlIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle AddOp:
+    SmallVector<Value*, 4> loopIVs;
+    for (auto arg : iterationBlock.getArguments())
+      loopIVs.push_back(arg);
+    auto loadedFirstVal =
+        rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    auto loadedSecondVal =
+        rewriter.create<LoadOp>(loc, operands[1], loopIVs);
+
+    // TODO: choose the Add variant based on the element type; for now use
+    // the float AddFOp.
+    auto addOpResult = rewriter.create<AddFOp>(
+        loc, loadedFirstVal, loadedSecondVal);
+
+    // Store result in the resulting array.
+    rewriter.create<StoreOp>(loc, addOpResult, alloc, loopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Conversion from Tensor type to the Standard dialect MemRef type.
+//===----------------------------------------------------------------------===//
+
+struct TensorTypeConverter : public TypeConverter {
+  using TypeConverter::TypeConverter;
+
+  LogicalResult convertType(Type t, SmallVectorImpl<Type>& results) override {
+    if (auto tensor_type = t.dyn_cast<TensorType>()) {
+      results.push_back(convertTensorToMemRef(tensor_type));
+      return success();
+    }
+
+    results.push_back(t);
+    return success();
+  }
+
+  /// Return true if the inputs and outputs of the given function type are
+  /// legal. [Taken from MLIR and adapted to only check the legality of the
+  /// inputs. Once unranked results can be handled gracefully this
+  /// override needs to be removed in favour of the original MLIR one.]
+  bool isSignatureLegal(FunctionType funcType) {
+    return llvm::all_of(funcType.getInputs(),
+        [this](Type type) { return isLegal(type); });
+  }
+};
+
+} // end anonymous namespace.
+
+//===----------------------------------------------------------------------===//
+// Frontend to Krnl Dialect lowering pass
+//===----------------------------------------------------------------------===//
+
+/// This is a partial lowering to Krnl loops of the ONNX operations.
+namespace {
+struct FrontendToKrnlLoweringPass
+    : public ModulePass<FrontendToKrnlLoweringPass> {
+  void runOnModule() final;
+};
+} // end anonymous namespace.
+
+void FrontendToKrnlLoweringPass::runOnModule() {
+  auto module = getModule();
+
+  // The first thing to define is the conversion target. This will define the
+  // final target for this lowering.
+  ConversionTarget target(getContext());
+
+  // We define the specific operations, or dialects, that are legal targets for
+  // this lowering.
+  target
+      .addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();
+
+  // TODO: enable this once more ops are supported.
+  // We also define the ONNX dialect as illegal so that the conversion will
+  // fail if any of these operations are *not* converted.
+  // target.addIllegalDialect<mlir::ONNXOpsDialect>();
+
+  // TODO: add any other ops which are considered legal.
+  // Some operations can be marked as being still legal.
+  // Example: target.addLegalOp<OpName>();
+
+  // Now that the conversion target has been defined, we just need to provide
+  // the set of patterns that will lower the frontend operations.
+  OwningRewritePatternList patterns;
+
+  // Convert TensorType to MemRefType.
+  TensorTypeConverter tensor_to_memref_converter;
+  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
+    // FuncOp is legal only if types have been converted to Std types.
+    return tensor_to_memref_converter.isSignatureLegal(op.getType());
+  });
+
+  // Type conversion for function signatures.
+  // Call MLIR FuncOp signature conversion when result type is
+  // a ranked tensor.
+  populateFuncOpTypeConversionPattern(
+      patterns, &getContext(), tensor_to_memref_converter);
+
+  // Frontend operation lowering.
+  patterns.insert<ONNXAddOpLowering>(&getContext());
+
+  // With the target and rewrite patterns defined, we can now attempt the
+  // conversion. The conversion will signal failure if any of our `illegal`
+  // operations were not converted successfully.
+  if (failed(applyPartialConversion(
+          module, target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
+  return std::make_unique<FrontendToKrnlLoweringPass>();
+}
diff --git a/src/compiler/pass/passes.hpp b/src/compiler/pass/passes.hpp
index 995da61..ee10dd0 100644
--- a/src/compiler/pass/passes.hpp
+++ b/src/compiler/pass/passes.hpp
@@ -17,7 +17,8 @@ class Pass;
 
 std::unique_ptr<Pass> createShapeInferencePass();
 
-// TODO: Add pass for lowering to kernel IR.
+/// Create a pass for lowering to Krnl IR.
+std::unique_ptr<Pass> createLowerToKrnlPass();
 
 // TODO: Add pass for lowering to LLVM IR.
diff --git a/src/compiler/pass/shape_inference_pass.cpp b/src/compiler/pass/shape_inference_pass.cpp
index 1548369..94e5389 100644
--- a/src/compiler/pass/shape_inference_pass.cpp
+++ b/src/compiler/pass/shape_inference_pass.cpp
@@ -71,6 +71,12 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
              << op_worklist.size() << " operations couldn't be inferred\n";
       signalPassFailure();
     }
+
+    if (auto terminator_op = f.getBody().back().getTerminator()) {
+      auto results = terminator_op->getOperandTypes();
+      f.setType(FunctionType::get(f.getType().getInputs(),
+          std::vector<Type>(results.begin(), results.end()), f.getContext()));
+    }
   }
 
 /*!
diff --git a/src/compiler/tool/onnf_opt/CMakeLists.txt b/src/compiler/tool/onnf_opt/CMakeLists.txt
index 6aba684..3ec11da 100644
--- a/src/compiler/tool/onnf_opt/CMakeLists.txt
+++ b/src/compiler/tool/onnf_opt/CMakeLists.txt
@@ -3,14 +3,8 @@ add_executable(onnf-opt onnf_opt.cpp)
 target_include_directories(onnf-opt PRIVATE ${ONNF_SRC_ROOT})
 target_include_directories(onnf-opt PRIVATE ${ONNF_BIN_ROOT})
 
-set(LIB_LIST
-    MLIRStandardOps
-    MLIRAffineOps
-    MLIRLoopOps
-    MLIRTransformUtils
-    MLIREDSC
-    MLIRTransforms)
-whole_archive_link_mlir(onnf-opt ${LIB_LIST})
+target_link_libraries(onnf-opt compiler ${MLIRLibs})
+whole_archive_link_mlir(onnf-opt ${MLIRWholeArchiveLibs})
 
 # TODO: need to investigate how to whole-archive link compiler pass to onnf-opt.
 target_link_libraries(onnf-opt compiler)
diff --git a/src/main.cpp b/src/main.cpp
index abd3718..61fd846 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -124,6 +124,7 @@ int main(int ac, char* av[]) {
   mlir::PassManager pm(&context);
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
+  pm.addPass(mlir::createLowerToKrnlPass());
   pm.run(*module);
 
   return 0;
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index 4cb7b50..967b5b0 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,4 +1 @@
-add_subdirectory(models)
-add_subdirectory(nodes)
-
 add_subdirectory(mlir)
\ No newline at end of file
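Usage sketch (illustrative, not taken from the patch): the new pass is meant to
slot into a pass pipeline exactly as the src/main.cpp hunk above shows. A
minimal standalone driver, assuming only standard MLIR APIs of this vintage
plus the declarations in src/compiler/pass/passes.hpp; the helper name
runPipeline is hypothetical:

#include <string>

#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

#include "src/compiler/pass/passes.hpp"

// Parse an MLIR source file and run shape inference, canonicalization, and
// the new ONNX-to-Krnl lowering over it. Returns 0 on success.
int runPipeline(mlir::MLIRContext& context, const std::string& filename) {
  mlir::OwningModuleRef module = mlir::parseSourceFile(filename, &context);
  if (!module)
    return 1;

  mlir::PassManager pm(&context);
  pm.addPass(mlir::createShapeInferencePass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createLowerToKrnlPass());  // the pass added by this patch

  return mlir::failed(pm.run(*module)) ? 1 : 0;
}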