diff --git a/doc/Dialects/onnx.md b/doc/Dialects/onnx.md index e4ca150..b7cb820 100644 --- a/doc/Dialects/onnx.md +++ b/doc/Dialects/onnx.md @@ -3749,7 +3749,7 @@ ONNX Reshape operation #### Operands: 1. `data`: memref of any type values or tensor of any type values -1. `shape`: memref of any type values or tensor of any type values +1. `shape`: memref of any type values or tensor of any type values or none type #### Attributes: diff --git a/doc/gen_doc.py b/doc/gen_doc.py index 0ad5db7..4f921a9 100644 --- a/doc/gen_doc.py +++ b/doc/gen_doc.py @@ -47,13 +47,20 @@ OpsWithShapeInference = [ 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', - 'Sign', 'Constant', 'ONNXAveragePoolOp', 'Abs' + 'Sign', 'Constant', 'AveragePool', 'Abs' ] # Operations supporting canonicalization. -OpsWithCanonicalizer = [ - 'Add', 'Identity', 'Gemm' -] +OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm'] + +# Operations who have operands that, if produced by constant operations, should +# be promoted to become an attribute (via attribute promotion). +# +# For each operation, a key/value pair is used to specify how attribute promotion +# should proceed. The key is the operation's name and the value is a list of +# tuples, whose first item is the attribute/operand name, and the second item is +# the index at which such operand occurs in the list of the operation's inputs. +OpsWithPromotableConstOperands = {"Reshape": [("shape", 1)]} # Add an Op in this list if the Op needs result type deduction which is required # when writing declarative rewriting rules. Deduced type is always @@ -224,7 +231,7 @@ def get_operands_or_results(schema, is_input): return "AnyTypeOf<[{}]>".format(", ".join(types)) name_to_types = OrderedDict() - for value in value_list: + for i, value in enumerate(value_list): elem_types = get_allowed_elem_types(schema, value) if elem_types is None: @@ -233,6 +240,13 @@ def get_operands_or_results(schema, is_input): types = ["TensorOf<[{}]>", "MemRefOf<[{}]>"] types = list(map(lambda x: x.format(elem_types), types)) + # If operand is promotable to an attribute, then it must be + # nullable in case it migrates to be an attribute. + if schema.name in OpsWithPromotableConstOperands: + idxs = dict(OpsWithPromotableConstOperands[schema.name]).values() + if i in idxs: + types.append("NoneType") + if OpSchema.FormalParameterOption.Optional == value.option: types.append("NoneType") elif OpSchema.FormalParameterOption.Variadic == value.option: @@ -313,6 +327,25 @@ def get_attrs(schema): return name_to_type +def get_promotable_const_operands_func(s, indent, const_operands_name_to_idx): + cpp_name_to_idx_literal = "{" + ", ".join([ + "{{\"{}\", {}}}".format(*name_to_idx) + for name_to_idx in const_operands_name_to_idx + ]) + "}" + + s += indent + "let extraClassDeclaration = [{\n" + indent = inc_indent(indent) + s += indent + "std::map promotableConstOperands() {\n" + indent = inc_indent(indent) + s += indent + "return {};\n".format(cpp_name_to_idx_literal) + indent = dec_indent(indent) + s += indent + "}\n" + indent = dec_indent(indent) + s += indent + "}];\n" + + return s + + def gen_op_def(schema): indent = inc_indent() s = 'def ONNX{0}Op:ONNX_Op<"{0}",\n'.format(schema.name) @@ -321,6 +354,8 @@ def gen_op_def(schema): traits = ["NoSideEffect"] if schema.name in OpsWithShapeInference: traits.append("DeclareOpInterfaceMethods") + if schema.name in OpsWithPromotableConstOperands.keys(): + traits.append("OpInterface<\"PromotableConstOperandsOpInterface\">") s += inc_indent(indent) + '[{}]> {{\n'.format(join_args(traits)) # Generate decl for canonicalizer. @@ -400,6 +435,9 @@ def gen_op_def(schema): s += '\n' + indent + '];\n' + if schema.name in OpsWithPromotableConstOperands: + s = get_promotable_const_operands_func( + s, indent, OpsWithPromotableConstOperands[schema.name]) s += '}\n\n' return s @@ -506,7 +544,7 @@ if __name__ == '__main__': curr_dir = os.path.dirname(os.path.realpath(__file__)) class Args(object): - op_def_file = os.path.join(curr_dir, 'onnxop.inc') + op_def_file = os.path.join(curr_dir, 'onnx_ops.td.inc') op_importer_file = os.path.join(curr_dir, 'op_build_table.inc') main(Args) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8196b2d..d6b8ac3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -8,7 +8,6 @@ add_library(compiler dialect/krnl/krnl_helper.cpp dialect/krnl/krnl_helper.hpp pass/shape_inference_interface.hpp - dialect/onnx/onnxop.inc pass/onnx_combine.cpp pass/onnx_rewrite.cpp pass/onnx_decompose.cpp @@ -47,12 +46,22 @@ onnx_mlir_tablegen(onnx_rewrite.inc -gen-rewriters) add_public_tablegen_target(gen_onnx_rewrite) add_dependencies(compiler gen_onnx_rewrite) +add_subdirectory(interface) + set(LLVM_TARGET_DEFINITIONS dialect/onnx/onnx.td) onnx_mlir_tablegen(onnx.hpp.inc -gen-op-decls "-I${CMAKE_SOURCE_DIR}/compiler/pass") onnx_mlir_tablegen(onnx.cpp.inc -gen-op-defs "-I${CMAKE_SOURCE_DIR}/compiler/pass") set(GEN_DOC_FILE ${CMAKE_BINARY_DIR}/docs/Dialects/onnx.md) add_public_tablegen_target(gen_onnx) + +add_dependencies(gen_onnx gen_shape_inference) add_dependencies(compiler gen_onnx) + +# TODO: onnx_mlir_gen_promotable_const_operands_op_interface is really a +# dependency of the onnx dialect library, which is currently part of `compiler`. +add_dependencies(compiler onnx_mlir_gen_promotable_const_operands_op_interface) + + add_onnx_mlir_dialect_doc(onnx dialect/onnx/onnx.td) set(LLVM_TARGET_DEFINITIONS dialect/krnl/krnl_ops.td) @@ -66,14 +75,14 @@ target_include_directories(onnx_mlir_onnx_decompose PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) target_link_libraries(onnx_mlir_onnx_decompose ${MLIRLibs}) -add_dependencies(onnx_mlir_onnx_decompose gen_krnl_ops) +add_dependencies(onnx_mlir_onnx_decompose gen_onnx) add_library(onnx_mlir_shape_inference pass/shape_inference_pass.cpp) target_include_directories(onnx_mlir_shape_inference PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) target_link_libraries(onnx_mlir_shape_inference ${MLIRLibs}) -add_dependencies(onnx_mlir_shape_inference gen_krnl_ops) +add_dependencies(onnx_mlir_shape_inference gen_onnx) add_library(onnx_mlir_lower_frontend conversion/onnx_to_krnl/onnx_to_krnl_common.cpp @@ -106,8 +115,17 @@ add_subdirectory(runtime) add_executable(onnx-mlir main.cpp) -target_link_libraries(onnx-mlir builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_onnx_decompose onnx_mlir_shape_inference onnx_mlir_lower_frontend) -whole_archive_link_mlir(onnx-mlir ${MLIRWholeArchiveLibs}) +target_link_libraries(onnx-mlir + builder + ${MLIRLibs} + onnx_mlir_transform + onnx_mlir_onnx_decompose + onnx_mlir_shape_inference + onnx_mlir_lower_frontend + onnx_mlir_attribute_promotion) +whole_archive_link_mlir(onnx-mlir + ${MLIRWholeArchiveLibs}) + find_package(ZLIB REQUIRED) target_link_libraries(onnx-mlir ${ZLIB_LIBRARIES}) diff --git a/src/dialect/onnx/CMakeLists.txt b/src/dialect/onnx/CMakeLists.txt new file mode 100644 index 0000000..e69de29 diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td index 218c573..c632dc3 100644 --- a/src/dialect/onnx/onnx.td +++ b/src/dialect/onnx/onnx.td @@ -22,6 +22,11 @@ include "mlir/IR/OpBase.td" include "pass/shape_inference_interface.td" #endif // SHAPE_INFERENCE_INTERFACE +#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#else +include "interface/promotable_const_operands_op_interface.td" +#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE + def ONNX_Dialect : Dialect { let name = "onnx"; let cppNamespace = ""; @@ -48,7 +53,7 @@ class ONNX_Op traits = []> : // CC=gcc CXX=g++ pip install -e . //run the script // python onnx/defs/gen_doc.py -//result is in docs/onnxop.inc +//result is in docs/onnx_ops.td.inc //current limitations: // 1. Attributes are not processed // 2. output type inference not implemented except Add @@ -56,7 +61,7 @@ class ONNX_Op traits = []> : // 4. type of string, complex64 and complex128 for input/output are ignored // 5. unsigned int are treated as signed one -include "dialect/onnx/onnxop.inc" +include "dialect/onnx/onnx_ops.td.inc" // Indicate entry point functions of ONNX graph. def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { diff --git a/src/dialect/onnx/onnx_ops.hpp b/src/dialect/onnx/onnx_ops.hpp index b981aed..0aa107d 100644 --- a/src/dialect/onnx/onnx_ops.hpp +++ b/src/dialect/onnx/onnx_ops.hpp @@ -10,6 +10,9 @@ #pragma once +#include +#include + #include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -17,6 +20,7 @@ #include "mlir/IR/StandardTypes.h" #include "src/pass/shape_inference_interface.hpp" +#include "src/interface/promotable_const_operands_op_interface.hpp" namespace mlir { diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnx_ops.td.inc similarity index 99% rename from src/dialect/onnx/onnxop.inc rename to src/dialect/onnx/onnx_ops.td.inc index d70e1f9..99b1f5f 100644 --- a/src/dialect/onnx/onnxop.inc +++ b/src/dialect/onnx/onnx_ops.td.inc @@ -1,5 +1,5 @@ //******************************************************** -// This file is generated on UTC-02/24/2020, 06:44:13. +// This file is generated on UTC-03/18/2020, 03:36:58. // Do not modify this file directly. // This file is automatically generated via script. // Details can be found in doc/readonnxdefs.md . @@ -2501,7 +2501,7 @@ def ONNXReluOp:ONNX_Op<"Relu", } def ONNXReshapeOp:ONNX_Op<"Reshape", - [NoSideEffect, DeclareOpInterfaceMethods]> { + [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"PromotableConstOperandsOpInterface">]> { let summary = "ONNX Reshape operation"; let description = [{ "Reshape the input tensor similar to numpy.reshape." @@ -2512,8 +2512,13 @@ def ONNXReshapeOp:ONNX_Op<"Reshape", "from the input tensor)." }]; let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$data, - AnyTypeOf<[AnyMemRef, AnyTensor]>:$shape); + AnyTypeOf<[AnyMemRef, AnyTensor, NoneType]>:$shape); let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$reshaped); + let extraClassDeclaration = [{ + std::map promotableConstOperands() { + return {{"shape", 1}}; + } + }]; } def ONNXResizeOp:ONNX_Op<"Resize", diff --git a/src/interface/CMakeLists.txt b/src/interface/CMakeLists.txt new file mode 100644 index 0000000..d9780c7 --- /dev/null +++ b/src/interface/CMakeLists.txt @@ -0,0 +1,14 @@ +set(LLVM_TARGET_DEFINITIONS promotable_const_operands_op_interface.td) +onnx_mlir_tablegen(promotable_const_operands_op_interface.hpp.inc -gen-op-interface-decls) +onnx_mlir_tablegen(promotable_const_operands_op_interface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(onnx_mlir_gen_promotable_const_operands_op_interface) + +add_library(onnx_mlir_promotable_const_operands_op_interface + promotable_const_operands_op_interface.hpp + promotable_const_operands_op_interface.cpp) +target_include_directories(onnx_mlir_promotable_const_operands_op_interface + PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} + ${ONNX_MLIR_SRC_ROOT}) + +add_dependencies(onnx_mlir_promotable_const_operands_op_interface + onnx_mlir_gen_promotable_const_operands_op_interface) \ No newline at end of file diff --git a/src/interface/promotable_const_operands_op_interface.cpp b/src/interface/promotable_const_operands_op_interface.cpp new file mode 100644 index 0000000..8dd1402 --- /dev/null +++ b/src/interface/promotable_const_operands_op_interface.cpp @@ -0,0 +1,23 @@ +//===------------ promotable_const_operands_op_interface.cpp --------------===// +//===-------- Promotable Const Operands Op Interface Definition -----------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the definition of the promotable const operands op +// interface. +// +//===----------------------------------------------------------------------===// + + +#include "src/interface/promotable_const_operands_op_interface.hpp" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Promotable Const Operands Op Interface +//===----------------------------------------------------------------------===// + +#include "src/interface/promotable_const_operands_op_interface.cpp.inc" + diff --git a/src/interface/promotable_const_operands_op_interface.hpp b/src/interface/promotable_const_operands_op_interface.hpp new file mode 100644 index 0000000..13a864e --- /dev/null +++ b/src/interface/promotable_const_operands_op_interface.hpp @@ -0,0 +1,25 @@ +//===------------ promotable_const_operands_op_interface.cpp --------------===// +//===-------- Promotable Const Operands Op Interface Definition -----------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the declaration of the promotable const operands op +// interface. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include "mlir/IR/OpDefinition.h" + +namespace mlir { + +/// Include the auto-generated declarations. +#include "src/interface/promotable_const_operands_op_interface.hpp.inc" + +} // end namespace mlir \ No newline at end of file diff --git a/src/interface/promotable_const_operands_op_interface.td b/src/interface/promotable_const_operands_op_interface.td new file mode 100644 index 0000000..ccecb09 --- /dev/null +++ b/src/interface/promotable_const_operands_op_interface.td @@ -0,0 +1,34 @@ +//===------------- promotable_const_operands_op_interface.td --------------===// +//===---- Promotable Const Operands Op Interface TableGen Definition ------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file contains the TableGen definition of the promotable const operands +// op interface. +// +//===----------------------------------------------------------------------===// + +#ifdef PROMOTABLE_CONST_OPERANDS_OP_INTERFACE +#else +#define PROMOTABLE_CONST_OPERANDS_OP_INTERFACE + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +def PromotableConstOperandsOpInterface : OpInterface<"PromotableConstOperandsOpInterface"> { + let description = [{ + Interface to access a registered method to infer the return types for an + operation that can be used during type inference. + }]; + + let methods = [ + InterfaceMethod<"Infer and set the output shape for the current operation.", + "std::map", "promotableConstOperands"> + ]; +} + +#endif // PROMOTABLE_CONST_OPERANDS_OP_INTERFACE diff --git a/src/main.cpp b/src/main.cpp index 0926f99..6bc102a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -126,6 +126,7 @@ int main(int argc, char *argv[]) { pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createShapeInferencePass()); + pm.addPass(mlir::createAttributePromotionPass()); if (emissionTarget >= EmitMLIR) { pm.addPass(mlir::createLowerToKrnlPass()); diff --git a/src/pass/passes.hpp b/src/pass/passes.hpp index de8fe1d..07a2a95 100644 --- a/src/pass/passes.hpp +++ b/src/pass/passes.hpp @@ -20,6 +20,9 @@ std::unique_ptr createDecomposeONNXToONNXPass(); std::unique_ptr createShapeInferencePass(); +/// Pass for promoting constant operands to attributes. +std::unique_ptr createAttributePromotionPass(); + /// Add pass for lowering to Krnl IR. std::unique_ptr createLowerToKrnlPass(); diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 899bfb4..967b85e 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -12,9 +12,9 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/StandardTypes.h" #include "shape_inference_interface.hpp" -#include "src/dialect/onnx/onnx_ops.hpp" #include "passes.hpp" diff --git a/src/tool/onnx_mlir_opt/CMakeLists.txt b/src/tool/onnx_mlir_opt/CMakeLists.txt index 3c97cc1..9ed3774 100644 --- a/src/tool/onnx_mlir_opt/CMakeLists.txt +++ b/src/tool/onnx_mlir_opt/CMakeLists.txt @@ -4,6 +4,19 @@ add_dependencies(onnx-mlir-opt gen_krnl_ops) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_SRC_ROOT}) target_include_directories(onnx-mlir-opt PRIVATE ${ONNX_MLIR_BIN_ROOT}) -target_link_libraries(onnx-mlir-opt builder ${MLIRLibs} onnx_mlir_transform onnx_mlir_shape_inference onnx_mlir_lower_frontend curses) -whole_archive_link_mlir(onnx-mlir-opt ${MLIRWholeArchiveLibs}) -whole_archive_link_onnx_mlir(onnx-mlir-opt compiler onnx_mlir_transform onnx_mlir_lower_frontend onnx_mlir_shape_inference) +target_link_libraries(onnx-mlir-opt + builder + ${MLIRLibs} + onnx_mlir_transform + onnx_mlir_shape_inference + onnx_mlir_lower_frontend + onnx_mlir_promotable_const_operands_op_interface + curses) +whole_archive_link_mlir(onnx-mlir-opt + ${MLIRWholeArchiveLibs}) +whole_archive_link_onnx_mlir(onnx-mlir-opt + compiler + onnx_mlir_transform + onnx_mlir_lower_frontend + onnx_mlir_shape_inference + onnx_mlir_attribute_promotion) \ No newline at end of file diff --git a/src/transform/CMakeLists.txt b/src/transform/CMakeLists.txt index 1e89333..26c23fa 100644 --- a/src/transform/CMakeLists.txt +++ b/src/transform/CMakeLists.txt @@ -7,3 +7,5 @@ target_include_directories(onnx_mlir_transform ${ONNX_MLIR_SRC_ROOT}) target_link_libraries(onnx_mlir_transform ${MLIRLibs}) add_dependencies(onnx_mlir_transform gen_krnl_ops) + +add_subdirectory(onnx) \ No newline at end of file diff --git a/src/transform/onnx/CMakeLists.txt b/src/transform/onnx/CMakeLists.txt new file mode 100644 index 0000000..3874c63 --- /dev/null +++ b/src/transform/onnx/CMakeLists.txt @@ -0,0 +1,7 @@ +add_library(onnx_mlir_attribute_promotion + attribute_promotion.cpp) +target_include_directories(onnx_mlir_attribute_promotion + PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} + ${ONNF_MLIR_SRC_ROOT}) +target_link_libraries(onnx_mlir_attribute_promotion + onnx_mlir_promotable_const_operands_op_interface) \ No newline at end of file diff --git a/src/transform/onnx/attribute_promotion.cpp b/src/transform/onnx/attribute_promotion.cpp new file mode 100644 index 0000000..da8dc69 --- /dev/null +++ b/src/transform/onnx/attribute_promotion.cpp @@ -0,0 +1,92 @@ +//===----- attribute_promotion.cpp - Attribute Promotion +//-------------------===// +// +// Copyright 2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a function level pass to move an operand to become +// an attribute if desirable and legal. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +#include "src/interface/promotable_const_operands_op_interface.hpp" +#include "src/pass/passes.hpp" + +using namespace mlir; + +namespace { + +/*! + * Helper function to create a NoneTyped constant value if `none` is empty. + */ +void getOrCreateNoneValue(llvm::Optional &none, FuncOp f) { + if (none.hasValue()) + return; + + OpBuilder builder(f.getContext()); + builder.setInsertionPointToStart(&f.front()); + none = builder.create(f.getLoc(), builder.getUnitAttr()); +} + +/*! + * FunctionPass that performs attribute promotion by iterating over a list of + * candidate operations and moves constant operands to attributes whenever + * desirable (as instructed by the PromotableConstOperandsOpInterface). + */ +class AttributePromotionPass + : public mlir::FunctionPass { +public: + void runOnFunction() override { + auto f = getFunction(); + + // A function-scope shared none value used to indicate an missing operand. + llvm::Optional none; + + // Iterate on the operations that may need attribute promotion. + f.walk([&](mlir::Operation *op) { + if (PromotableConstOperandsOpInterface opWithConstOperand = + dyn_cast(op)) { + auto promotableOperands = opWithConstOperand.promotableConstOperands(); + for (const auto &operandNameToIdx : promotableOperands) { + auto name = operandNameToIdx.first; + auto i = operandNameToIdx.second; + + // If the i-th operand is defined by an constant operation, then + // move it to an attribute, and use None to indicate the absence + // of the original operand value. + auto operandToPromote = op->getOperand(i); + if (auto constantOp = dyn_cast_or_null( + operandToPromote.getDefiningOp())) { + op->setAttr(name, constantOp.value()); + getOrCreateNoneValue(none, f); + op->setOperand(i, *none); + } + } + } + }); + + // Dispatch canonicalization pattern rewriters to eliminate redundant + // constant operaions. + OwningRewritePatternList patterns; + auto *context = &getContext(); + ConstantOp::getCanonicalizationPatterns(patterns, context); + applyPatternsGreedily(f, patterns); + } +}; +} // end anonymous namespace + +/*! + * Create a Attribute Promotion pass. + */ +std::unique_ptr mlir::createAttributePromotionPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "attribute-promotion", "Promote constant operands to attributes."); diff --git a/test/mlir/transform/attribute_promotion.mlir b/test/mlir/transform/attribute_promotion.mlir new file mode 100644 index 0000000..a7555fa --- /dev/null +++ b/test/mlir/transform/attribute_promotion.mlir @@ -0,0 +1,32 @@ +// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s + +func @test_should_promote_to_attribute(%arg0 : tensor) -> tensor<*xf32> { + %shape = constant dense<[6, 7, 42]> : tensor<3xi32> + %0 = "onnx.Reshape"(%arg0, %shape) : (tensor, tensor<3xi32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + // CHECK-LABEL: test_should_promote_to_attribute + // CHECK-NEXT: [[NONE:%.+]] = constant unit + // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor, none) -> tensor<*xf32> + // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> +} + +func @test_should_not_promote_to_attribute(%arg0 : tensor, %arg1 : tensor<*xi64>) -> tensor<*xf32> { + %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor, tensor<*xi64>) -> tensor<*xf32> + return %0 : tensor<*xf32> + // CHECK-LABEL: test_should_not_promote_to_attribute + // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, %{{.*}}) : (tensor, tensor<*xi64>) -> tensor<*xf32> + // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> +} + +func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor) -> (tensor<*xf32>, tensor<*xf32>) { + %shape = constant dense<[6, 7, 42]> : tensor<3xi32> + %0 = "onnx.Reshape"(%arg0, %shape) : (tensor, tensor<3xi32>) -> tensor<*xf32> + %1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32> + "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> () + // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op + // CHECK-NEXT: [[NONE:%.+]] = constant unit + // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32> + // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor, none) -> tensor<*xf32> + // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32> + // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32> +} \ No newline at end of file