From e2af50574600fc0a2fb2050c2685965a7b65a841 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Mon, 8 Jun 2020 15:45:32 -0400 Subject: [PATCH] Constprop (#162) * initial const prop attempt * added support for broadcast ops * adde all binary broadcast ops into custom builders with precise type * added test example * working * format * fixed suggestion by Tung, start woring on unary * added subtraction and neg the right way, and added elementwise mul too * formatting changes * format * format * added instructions to add new optimizations --- .gitignore | 4 + docs/ImportONNXDefs.md | 6 +- src/Dialect/ONNX/ONNXOps.td.inc | 502 +++++++++++++++++++++------ src/MainUtils.cpp | 1 + src/Pass/Passes.hpp | 2 + src/Transform/ONNX/CMakeLists.txt | 10 +- src/Transform/ONNX/ONNXConstProp.cpp | 335 ++++++++++++++++++ src/Transform/ONNX/ONNXConstProp.td | 161 +++++++++ test/mlir/onnx/onnx_constprop.mlir | 158 +++++++++ utils/gen_onnx_mlir.py | 61 +++- 10 files changed, 1111 insertions(+), 129 deletions(-) create mode 100644 src/Transform/ONNX/ONNXConstProp.cpp create mode 100644 src/Transform/ONNX/ONNXConstProp.td create mode 100644 test/mlir/onnx/onnx_constprop.mlir diff --git a/.gitignore b/.gitignore index 6c866a7..dec7560 100644 --- a/.gitignore +++ b/.gitignore @@ -116,6 +116,7 @@ docs/_build/ # PyBuilder target/ +utils/ONNXOps.td.inc # Jupyter Notebook .ipynb_checkpoints @@ -175,3 +176,6 @@ dmypy.json # pytype static type analyzer .pytype/ + +#editor +*~ diff --git a/docs/ImportONNXDefs.md b/docs/ImportONNXDefs.md index 9432dc0..e8e62cd 100644 --- a/docs/ImportONNXDefs.md +++ b/docs/ImportONNXDefs.md @@ -1,9 +1,9 @@ # Import ONNX specifications into ONNX-MLIR ONNX specifications are defined under `onnx/defs` directory in the ONNX project repository. -There is a python script onnx/defs/gen_doc.py that automatically generate documents about operations in ONNX (docs/Operations.md). 
+There is a python script onnx/defs/gen_onnx_mlir.py that automatically generate documents about operations in ONNX (docs/Operations.md). ONNX-MLIR modified this script to import ONNX specifications into ONNX-MLIR. -There are two files generated for ONNX MLIR with the modified gen_doc.py: +There are two files generated for ONNX MLIR with the modified gen_onnx_mlir.py: 1. `src/Dialect/ONNX/ONNXOps.td.inc`: Operation definition for MLIR TableGen. `src/Dialect/ONNX/ONNXOps.td` includes this file. 2. `src/Builder/OpBuildTable.inc`: C++ code for ONNX-MLIR frontend to import operation nodes from ONNX model. `src/Builder/FrontendDialectTransformer.cpp` includes this file. @@ -22,7 +22,7 @@ Even though we strive to support the latest version of ONNX specification as qui Due to the possibility of such a delay, operator definition within the ONNX project repository may describe features and schemas that we do not yet support. ## Customization -In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_doc.py, provides some mechanism for you to customize the output. +In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_onnx_mlir.py, provides some mechanism for you to customize the output. Several tables are defined at the beginning of the script: 1. `special_attr_defaults`: gives attribute special default value. 2. `special_op_handler`: creates special import function in frontend_dialect_transformer.cpp. 
Currently, a special handler is used for operations with operational arguments diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 03823de..5ac237b 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -93,17 +93,43 @@ def ONNXAddOp:ONNX_Op<"Add", let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {20}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static 
int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXAndOp:ONNX_Op<"And", @@ -118,17 +144,43 @@ def ONNXAndOp:ONNX_Op<"And", let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXArgMaxOp:ONNX_Op<"ArgMax", @@ -933,17 +985,43 @@ def ONNXDivOp:ONNX_Op<"Div", let arguments = (ins 
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {20}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXDropoutOp:ONNX_Op<"Dropout", @@ -1055,17 +1133,43 @@ def ONNXEqualOp:ONNX_Op<"Equal", let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A, 
AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXErfOp:ONNX_Op<"Erf", @@ -1697,17 +1801,43 @@ def ONNXGreaterOp:ONNX_Op<"Greater", let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A, AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B); let results = (outs 
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid", @@ -2075,17 +2205,43 @@ def ONNXLessOp:ONNX_Op<"Less", let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A, AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() 
{ - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXLogOp:ONNX_Op<"Log", @@ -2646,13 +2802,27 @@ def ONNXMulOp:ONNX_Op<"Mul", let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ - auto elementType = A.getType().cast().getElementType(); - build(builder, state, UnrankedTensorType::get(elementType), A, B); + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = 
elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); }]>, OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ - auto elementType = operands[0].getType().cast().getElementType(); + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } std::vector outputTypes; - outputTypes.emplace_back(UnrankedTensorType::get(elementType)); + outputTypes.emplace_back(elementType); build(builder, state, outputTypes, operands, attributes); }]> ]; @@ -2848,17 +3018,43 @@ def ONNXOrOp:ONNX_Op<"Or", let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef 
attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } def ONNXPReluOp:ONNX_Op<"PRelu", @@ -3017,17 +3213,43 @@ def ONNXPowOp:ONNX_Op<"Pow", let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X, AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y); let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {20}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{ + auto lhsTy = X.getType().cast(); + auto rhsTy = Y.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = X.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, X, Y); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto 
shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXQLinearConvOp:ONNX_Op<"QLinearConv", @@ -4938,17 +5160,43 @@ def ONNXSubOp:ONNX_Op<"Sub", let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A, AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {20}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType 
= operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {20}; + } + }]; } def ONNXSumOp:ONNX_Op<"Sum", @@ -5379,16 +5627,42 @@ def ONNXXorOp:ONNX_Op<"Xor", let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A, AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B); let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C); - let extraClassDeclaration = [{ - static int getNumberOfOperands() { - return 2; - } - static int getNumberOfResults() { - return 1; - } - static std::vector getTypeMap() { - return {0}; - } - }]; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{ + auto lhsTy = A.getType().cast(); + auto rhsTy = B.getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = A.getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + build(builder, state, elementType, A, B); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{ + auto lhsTy = operands[0].getType().cast(); + auto rhsTy = operands[1].getType().cast(); + auto elementType = getBroadcastedType(lhsTy, rhsTy); + auto shapedType = elementType.dyn_cast_or_null(); + if (!shapedType || !shapedType.hasStaticShape()) { + elementType = operands[0].getType().cast().getElementType(); + elementType = UnrankedTensorType::get(elementType); + } + std::vector outputTypes; + outputTypes.emplace_back(elementType); + build(builder, state, outputTypes, operands, 
attributes); + }]> + ]; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return 2; + } + static int getNumberOfResults() { + return 1; + } + static std::vector getTypeMap() { + return {0}; + } + }]; } diff --git a/src/MainUtils.cpp b/src/MainUtils.cpp index 47ce0d0..175ff1f 100644 --- a/src/MainUtils.cpp +++ b/src/MainUtils.cpp @@ -88,6 +88,7 @@ void registerDialects() { void addONNXToMLIRPasses(mlir::PassManager &pm) { pm.addPass(mlir::createDecomposeONNXToONNXPass()); + pm.addPass(mlir::createConstPropONNXToONNXPass()); pm.addPass(mlir::createShapeInferencePass()); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createAttributePromotionPass()); diff --git a/src/Pass/Passes.hpp b/src/Pass/Passes.hpp index 8ccb537..7a80240 100644 --- a/src/Pass/Passes.hpp +++ b/src/Pass/Passes.hpp @@ -20,6 +20,8 @@ std::unique_ptr createDecomposeONNXToONNXPass(); std::unique_ptr createShapeInferencePass(); +std::unique_ptr createConstPropONNXToONNXPass(); + /// Pass for promoting constant operands to attributes. 
std::unique_ptr createAttributePromotionPass(); diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Transform/ONNX/CMakeLists.txt index e2243ff..1f7f896 100644 --- a/src/Transform/ONNX/CMakeLists.txt +++ b/src/Transform/ONNX/CMakeLists.txt @@ -33,17 +33,23 @@ set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td) onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters) add_public_tablegen_target(OMONNXDecomposeIncGen) +set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td) +onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters) +add_public_tablegen_target(OMONNXConstPropIncGen) + add_library(OMONNXRewrite ONNXRewrite.cpp ONNXCombine.cpp - ONNXDecompose.cpp) + ONNXDecompose.cpp + ONNXConstProp.cpp) target_include_directories(OMONNXRewrite PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNF_MLIR_SRC_ROOT}) add_dependencies(OMONNXRewrite OMONNXRewriteIncGen OMONNXDecomposeIncGen - OMONNXCombineIncGen) + OMONNXCombineIncGen + OMONNXConstPropIncGen) # Linking dependencies: add_dependencies(OMONNXRewrite OMONNXOps) diff --git a/src/Transform/ONNX/ONNXConstProp.cpp b/src/Transform/ONNX/ONNXConstProp.cpp new file mode 100644 index 0000000..2ff12bc --- /dev/null +++ b/src/Transform/ONNX/ONNXConstProp.cpp @@ -0,0 +1,335 @@ +//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// This file implements a set of rewriters to constprop an ONNX operation into +// composition of other ONNX operations. +// +// This pass is applied before any other pass so that there is no need to +// implement shape inference for the constpropd operation. 
Hence, it is expected
+// that there is no knowledge about tensor shape at this point
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+#include "src/Pass/Passes.hpp"
+
+using namespace mlir;
+
+namespace {
+
+// =============================================================================
+// Instructions to add a constant operation. There is currently support for
+// adding constant propagation for unary and binary arithmetic ops (binary ops
+// support broadcast). To add an operation, you simply have to add a templated
+// method on how to compute the result in terms of one or two inputs. Values
+// come as Attributes, and the return is also an Attribute. In that function,
+// presumably you will need different methods to handle int / float /
+// strings... Note that these methods cannot fail. It is your responsibility to
+// test for which data types are supported in the rules directly. Specific type
+// restrictions can be added in the DRR files.
+
+// The methods are:
+//
+// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
+// and they need to be templated with an ONNX Operation (presumably).
+//
+// Then you need to add rules on how to transform the patterns; look into
+// ONNXConstProp.td for example.
+//
+// =============================================================================
+
+// =============================================================================
+// Code to perform constant propagation for binary in presence of broadcast.
+// =============================================================================
+
+// Template to generate binary operation results.
It takes as input
+// the element type as well as the two element attributes for the
+// operation, and return the result of the operation, also as an
+// attribute.
+
+template
+Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
+    Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
+  llvm_unreachable("unknown operation");
+}
+
+template <>
+Attribute ComputeConstProppElementwiseBinary(
+    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
+    Attribute &secondAttr) {
+  if (elementType.isa()) {
+    double lhsVal = lhsAttr.cast().getValueAsDouble();
+    double rhsVal = secondAttr.cast().getValueAsDouble();
+    double res = lhsVal + rhsVal;
+    // printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
+    // Could use the APFloat interface to emulate the results, are ok to simply
+    // perform them in the highest possible precision.
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  if (elementType.isa()) {
+    uint64_t lhsVal = lhsAttr.cast().getInt();
+    uint64_t rhsVal = secondAttr.cast().getInt();
+    uint64_t res = lhsVal + rhsVal;
+    // printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
+    return rewriter.getIntegerAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for AddOp: unknown data type");
+}
+
+template <>
+Attribute ComputeConstProppElementwiseBinary(
+    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
+    Attribute &secondAttr) {
+  if (elementType.isa()) {
+    double lhsVal = lhsAttr.cast().getValueAsDouble();
+    double rhsVal = secondAttr.cast().getValueAsDouble();
+    double res = lhsVal - rhsVal;
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  if (elementType.isa()) {
+    uint64_t lhsVal = lhsAttr.cast().getInt();
+    uint64_t rhsVal = secondAttr.cast().getInt();
+    uint64_t res = lhsVal - rhsVal;
+    return rewriter.getIntegerAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for SubOp: unknown data type");
+}
+
+template <>
+Attribute ComputeConstProppElementwiseBinary(
+    
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
+    Attribute &secondAttr) {
+  if (elementType.isa()) {
+    double lhsVal = lhsAttr.cast().getValueAsDouble();
+    double rhsVal = secondAttr.cast().getValueAsDouble();
+    double res = lhsVal * rhsVal;
+    return rewriter.getFloatAttr(elementType, res);
+  }
+  if (elementType.isa()) {
+    uint64_t lhsVal = lhsAttr.cast().getInt();
+    uint64_t rhsVal = secondAttr.cast().getInt();
+    uint64_t res = lhsVal * rhsVal;
+    return rewriter.getIntegerAttr(elementType, res);
+  }
+  llvm_unreachable("constant propagation for MulOp: unknown data type");
+}
+
+// Recursively process one dimension in the rank of the two references. There
+// can be one of 3 cases.
+// 1) We have fully defined accesses for both operands, launch the computations.
+// 2) One of the two has a higher number of unprocessed ranks, which is the case
+// when we have to broadcast the whole lower-dim reference with respect to the
+// other. Iterate over each value of the higher ranked reference, keeping the
+// reference of the lower ranked reference constant.
+// 3) Both references have the same rank, we still do broadcast if one of the
+// dimension size is equal to 1.
+
+template
+void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
+    std::vector &resVector, DenseElementsAttr &lhsAttr,
+    DenseElementsAttr &rhsAttr, SmallVector &lhsIndices,
+    SmallVector &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
+  // printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
+  if (lhsFreeRank == 0) {
+    // Fully defined ranks.
+ assert( + rhsFreeRank == 0 && "expect both to recurse to zero at the same time"); + auto lhsElementAttr = lhsAttr.getValue(ArrayRef(lhsIndices)); + auto rhsElementAttr = rhsAttr.getValue(ArrayRef(rhsIndices)); + auto elementaryType = lhsAttr.getType().getElementType(); + auto res = ComputeConstProppElementwiseBinary( + rewriter, elementaryType, lhsElementAttr, rhsElementAttr); + resVector.emplace_back(res); + } else if (lhsFreeRank > rhsFreeRank) { + // Initial broadcast from lhs. + auto lhsShape = lhsAttr.getType().getShape(); + int lhsRank = lhsShape.size(); + int lhsIndex = lhsRank - lhsFreeRank; + int lhsSize = lhsAttr.getType().getShape()[lhsIndex]; + for (int i = 0; i < lhsSize; ++i) { + lhsIndices[lhsIndex] = i; + RecurseConstProppElementwiseBinary(rewriter, + resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1, + rhsFreeRank); + } + } else if (lhsFreeRank < rhsFreeRank) { + // Initial broadcast from rhs. + auto rhsShape = rhsAttr.getType().getShape(); + int rhsRank = rhsShape.size(); + int rhsIndex = rhsRank - rhsFreeRank; + int rhsSize = rhsAttr.getType().getShape()[rhsIndex]; + for (int i = 0; i < rhsSize; ++i) { + rhsIndices[rhsIndex] = i; + RecurseConstProppElementwiseBinary(rewriter, + resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank, + rhsFreeRank - 1); + } + } else { + // No initial broadcast, but if one element has size 1 and the other is + // greater than one, then we also have broadcast. 
+ auto lhsShape = lhsAttr.getType().getShape(); + int lhsRank = lhsShape.size(); + int lhsIndex = lhsRank - lhsFreeRank; + int lhsSize = lhsAttr.getType().getShape()[lhsIndex]; + auto rhsShape = rhsAttr.getType().getShape(); + int rhsRank = rhsShape.size(); + int rhsIndex = rhsRank - rhsFreeRank; + int rhsSize = rhsAttr.getType().getShape()[rhsIndex]; + assert((lhsSize == 1 || rhsSize == 1 || lhsSize == rhsSize) && + "incompatible sizes"); + int size = std::max(lhsSize, rhsSize); + lhsIndices[lhsIndex] = rhsIndices[rhsIndex] = 0; + for (int i = 0; i < size; ++i) { + if (lhsSize > 1) + lhsIndices[lhsIndex] = i; + if (rhsSize > 1) + rhsIndices[rhsIndex] = i; + RecurseConstProppElementwiseBinary(rewriter, + resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1, + rhsFreeRank - 1); + } + } +} + +// Process the constant operands, perform the operation with broadcast, and +// generate the new constant operation. +template +DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter, + Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) { + DenseElementsAttr lhsDenseAttr = + lhsAttr.dyn_cast_or_null(); + DenseElementsAttr rhsDenseAttr = + rhsAttr.dyn_cast_or_null(); + assert((lhsDenseAttr && rhsDenseAttr) && "expected dense attributes"); + assert( + resOperand.getType().isa() && "expected ranked tensor"); + ShapedType resType = resOperand.getType().cast(); + auto lhsRank = lhsDenseAttr.getType().getShape().size(); + auto rhsRank = rhsDenseAttr.getType().getShape().size(); + SmallVector lhsIndices(lhsRank, 0); + SmallVector rhsIndices(rhsRank, 0); + std::vector resVector; + RecurseConstProppElementwiseBinary(rewriter, resVector, + lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank); + ArrayRef resRef(resVector); + return DenseElementsAttr::get(resType, resRef); +} + +// ============================================================================= +// Code to perform constant propagation for unary operation. 
+// ============================================================================= + +template +Attribute ComputeConstProppElementwiseUnary( + PatternRewriter &rewriter, Type elementType, Attribute &attr) { + llvm_unreachable("unknown operation"); +} + +template <> +Attribute ComputeConstProppElementwiseUnary( + PatternRewriter &rewriter, Type elementType, Attribute &attr) { + if (elementType.isa()) { + double val = attr.cast().getValueAsDouble(); + double res = -val; + return rewriter.getFloatAttr(elementType, res); + } + if (elementType.isa()) { + uint64_t val = attr.cast().getInt(); + uint64_t res = -val; + return rewriter.getIntegerAttr(elementType, res); + } + llvm_unreachable("constant propagation for NegOp: unknown data type"); +} + +template +void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter, + std::vector &resVector, DenseElementsAttr &attr, + SmallVector &indices, int freeRank) { + // printf("recurse with free %d\n", freeRank); + if (freeRank == 0) { + // Fully defined ranks. + auto elementAttr = attr.getValue(ArrayRef(indices)); + auto elementaryType = attr.getType().getElementType(); + auto res = ComputeConstProppElementwiseUnary( + rewriter, elementaryType, elementAttr); + resVector.emplace_back(res); + } else { + // Recurse. + auto shape = attr.getType().getShape(); + int rank = shape.size(); + int index = rank - freeRank; + int size = attr.getType().getShape()[index]; + for (int i = 0; i < size; ++i) { + indices[index] = i; + RecurseConstProppElementwiseUnary( + rewriter, resVector, attr, indices, freeRank - 1); + } + } +} + +// Process the constant operand, perform the operation, and +// generate the new constant operation. 
+template +DenseElementsAttr ConstPropElementwiseUnary( + PatternRewriter &rewriter, Value resOperand, Attribute &attr) { + DenseElementsAttr denseAttr = + attr.dyn_cast_or_null(); + assert(denseAttr && "expected dense attribute"); + assert( + resOperand.getType().isa() && "expected ranked tensor"); + ShapedType resType = resOperand.getType().cast(); + auto rank = denseAttr.getType().getShape().size(); + SmallVector indices(rank, 0); + std::vector resVector; + RecurseConstProppElementwiseUnary( + rewriter, resVector, denseAttr, indices, rank); + ArrayRef resRef(resVector); + return DenseElementsAttr::get(resType, resRef); +} + +// ============================================================================= +// Pattern definition. +// ============================================================================= + +#include "src/Transform/ONNX/ONNXConstProp.inc" + +// ============================================================================= +// Code to manage the pass. +// ============================================================================= + +struct ConstPropONNXToONNXPass + : public PassWrapper { + void runOnFunction() final; +}; +} // end anonymous namespace. + +void ConstPropONNXToONNXPass::runOnFunction() { + auto function = getFunction(); + MLIRContext *context = &getContext(); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + + OwningRewritePatternList patterns; + populateWithGenerated(context, &patterns); + + applyPatternsAndFoldGreedily(function, patterns); +} // end runOnFunction + +/*! + * Create a ConstPropONNX pass. 
+ */ +std::unique_ptr mlir::createConstPropONNXToONNXPass() { + return std::make_unique(); +} + +static PassRegistration pass("constprop-onnx", + "ConstProp ONNX operations into composition of other ONNX operations."); diff --git a/src/Transform/ONNX/ONNXConstProp.td b/src/Transform/ONNX/ONNXConstProp.td new file mode 100644 index 0000000..a628d32 --- /dev/null +++ b/src/Transform/ONNX/ONNXConstProp.td @@ -0,0 +1,161 @@ +//===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===// +// +// Copyright 2019-2020 The IBM Research Authors. +// +// ============================================================================= +// +// Defines language-specific pattern match rewritings for ONNX using +// Declarative Rewrite Rules (DRR) specified using TableGen records. +// +//===----------------------------------------------------------------------===// + +#ifndef ONNX_CONSTPROP +#define ONNX_CONSTPROP + +#ifndef OP_BASE +include "src/Dialect/ONNX/ONNXOps.td" +#endif // OP_BASE + + +// ============================================================================= +// Instruction to add new constant operation rules. Minimally, you will have added +// operation in the ONNXConstProp.cpp to perform the element-wise single value +// handling of the new operator that you are dealing with. You will need to +// generate a call to the method that handle the tensor constant prop. Here +// is the call for a unary and binary operation. Adapt to your new operator: +// +// def CreateAddOfTwoConst : +// NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; +// +// def CreateNegOfConst : +// NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; +// +// where you will have mostly to substitute your new operator as well as using +// a new def name. +// +// Then you will need to add substitution rules, see examples below. +// ============================================================================= + + +// Useful test definitions. 
+ +def IsNotAConstant : + Constraint(($_self).getDefiningOp())">, + "operation is not a constant">; + +def AttributeIsNull : + Constraint, + "Attribute is null">; + + +// Useful code generation invocation. +def GetNullAttr : NativeCodeCall<"Attribute()">; + +def CreateAddOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateSubOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +def CreateNegOfConst : + NativeCodeCall<"ConstPropElementwiseUnary($_builder, $0, $1)">; + +def CreateMulOfTwoConst : + NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; + +// ============================================================================= +// Patterns to enable opportunities with elementwise ADD operations. + +// Use commutativity to normalize constants in the second position of Add. +def AddConstCommutative1 : Pat< + // From add(c, x). + (ONNXAddOp (ONNXConstantOp:$c $_, $_), $x), + // To add(x, c). + (ONNXAddOp $x, $c), + // To avoid infinite loop, constrain the first arguments to be anything but a constant. + [(IsNotAConstant:$x)]>; + +// Use associativity to add constants together. +def AddConstAssociative1 : Pat< + // From add(add(x, c1), c2). + (ONNXAddOp + (ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_)), + (ONNXConstantOp:$c2 $_, $_)), + // To add(x, add(c1, c2)). + (ONNXAddOp + $x, + (ONNXAddOp $c1, $c2))>; + +// Constant Propagation for Add +def AddConstProp : Pat< + // From add(c1, c2). + (ONNXAddOp:$addOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)), + // To c1+c2 + (ONNXConstantOp (GetNullAttr), (CreateAddOfTwoConst $addOp, $v1, $v2)), + // Additional constraints (no sparse) + [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; + + +// ============================================================================= +// Patterns to enable opportunities with elementwise SUB / NEG operations. 
+ +// Constant Propagation for Sub +def SubConstProp : Pat< + // From sub(c1, c2). + (ONNXSubOp:$subOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)), + // To c1-c2 + (ONNXConstantOp (GetNullAttr), (CreateSubOfTwoConst $subOp, $v1, $v2)), + [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; + +// Neg of constant is simply -const +def NegofConst : Pat< + // From - (c) + (ONNXNegOp (ONNXConstantOp:$constOp $s, $v)), + // To (-c) + (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v)), + [(AttributeIsNull:$s)]>; + +// Change a subtraction of a constant c into an addition of -c. Helpful to combine +// with other add optimizations. +def SubConstToNeg : Pat< + // From x - c. + (ONNXSubOp:$subOp $x, (ONNXConstantOp:$constOp $s, $v)), + // To x + (-c). + (ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))), + [(IsNotAConstant:$x), (AttributeIsNull:$s)]>; + + +// ============================================================================= +// Patterns to enable opportunities with elementwise MUL operations. +// Exactly the same pattern as for the elementwise ADD operations. + +// Use commutativity to normalize constants in the second position of Mul. +def MulConstCommutative1 : Pat< + // From mul(c, x). + (ONNXMulOp (ONNXConstantOp:$c $_, $_), $x), + // To mul(x, c). + (ONNXMulOp $x, $c), + // To avoid infinite loop, constrain the first arguments to be anything but a constant. + [(IsNotAConstant:$x)]>; + +// Use associativity to mul constants together. +def MulConstAssociative1 : Pat< + // From mul(mul(x, c1), c2). + (ONNXMulOp + (ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_)), + (ONNXConstantOp:$c2 $_, $_)), + // To mul(x, mul(c1, c2)). + (ONNXMulOp + $x, + (ONNXMulOp $c1, $c2))>; + +// Constant Propagation for Mul +def MulConstProp : Pat< + // From mul(c1, c2). 
+ (ONNXMulOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)), + // To c1*c2 + (ONNXConstantOp (GetNullAttr), (CreateMulOfTwoConst $mulOp, $v1, $v2)), + // Additional constraints (no sparse) + [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; + +#endif // ONNX_CONSTPROP diff --git a/test/mlir/onnx/onnx_constprop.mlir b/test/mlir/onnx/onnx_constprop.mlir new file mode 100644 index 0000000..6bbfdb8 --- /dev/null +++ b/test/mlir/onnx/onnx_constprop.mlir @@ -0,0 +1,158 @@ +// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s + +// ============================================================================= +/// ADD tests. + +/// Test ConstantOp commutativity for add + +// CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32> +func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> + "std.return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> +} + +/// Test ConstantOp commutativity for add +// CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32> +func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> { + %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32> + %1 = "onnx.Add"(%arg0, %0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32> + "std.return"(%1) : (tensor<3xf32>) -> () + // CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32> + // CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, 
tensor<3xf32>) -> tensor<3xf32> +} + +/// Change (x+c1)+c2 to x+(c1+c2) +// CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> +func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + "std.return"(%3) : (tensor<3xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +} + +/// Same test as above, but with a use of an intermediary result +/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1) +// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32> +func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + "std.return"(%4) : (tensor<3xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, 
tensor<3xi32>) -> tensor<3xi32> + // CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +} + +/// Test broadcast 1 -> 2d + +// CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32> + %1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %2 = "onnx.Add"(%0, %1) : (tensor<1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + "std.return"(%3) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + +/// Test broadcast 2d (size one) -> 2d + +// CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %2 = "onnx.Add"(%0, %1) : (tensor<1x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + "std.return"(%3) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + +/// check 1d -> 2d + + // CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_broadcast_3(%arg0 : 
tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32> + %1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %2 = "onnx.Add"(%0, %1) : (tensor<3x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + "std.return"(%3) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[11, 12], [23, 24], [34, 35]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + +// ============================================================================= +/// MUL tests (same as add, so have only one). + +/// Change (x*c1)*c2 to x*(c1*c2) +// CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32> +func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> { + %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32> + %2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + %3 = "onnx.Mul"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32> + "std.return"(%3) : (tensor<3xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32> + // CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> +} + +// ============================================================================= +/// SUB and NEG tests. 
+ + +// check of sub two constants + +// CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "onnx.Constant"() {value = dense<[[2]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %2 = "onnx.Sub"(%0, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32> + "std.return"(%2) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> +} + +/// check sub to add of negative + +// CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + "std.return"(%1) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[-2, -3], [-4, -5], [-6, -7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + +// CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %2 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + %5 = "onnx.Add"(%2, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32> + "std.return"(%5) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> 
tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + +// CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> +func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> { + %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32> + %2 = "onnx.Neg"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32> + %3 = "onnx.Add"(%arg0, %2) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32> + %4 = "onnx.Add"(%3, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32> + "std.return"(%4) : (tensor<3x2xi32>) -> () + // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> +} + diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index d393cd0..6c3b810 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -295,8 +295,15 @@ OpsWithResultTypeInference = { # Currenlty, there are only two build methods generated: # - one with operands and attributes having a separate parameter, and # - one with operands and attributes having aggregated parameters. -custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad'] - +custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad'] +# Custom builder op list for operations with broadcast; we can deduce the right +# output type, no need to leave it undef as in the above list. +# Ops must have two operands, not one, not three... And there shall be two. +# TODO: handle variadic ops omitted here: Max, Min, Min, Sum. 
+custom_builder_broadcast_ops_list = ['Add', 'And', 'Div', 'Equal', 'Greater', + 'Less', 'Mul', 'Or', 'Pow', 'Sub', 'Xor'] +# union of both +custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broadcast_ops_list #a dictionary to add any special definition for an operation custom_definition_misc = dict([ ('Constant', @@ -716,6 +723,8 @@ def gen_op_def(schema): s += indent + 'let results = (outs {});\n'.format( (',\n' + inc_indent(indent)).join(outs_strs)) + # custom_builder_broadcast_ops_list + # add custom builders # use element type of the first operand to construct an UnrankedTensorType for the output. if schema.name in custom_builder_ops_list: @@ -726,7 +735,8 @@ def gen_op_def(schema): else: s += indent + 'let builders = [\n' # Custom builders with operands and attributes having a seperate parameter. - # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]> + # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X, + # Value, Y, Attribute A", [{}]> indent = inc_indent(indent) s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state' operands_dict = get_operands_or_results(schema, is_input=True) @@ -740,9 +750,26 @@ def gen_op_def(schema): # Get output type from first operand's type. first_operand_name = list(ins.items())[0][0] - s += indent + 'auto elementType = {}.getType().cast().getElementType();\n'.format( - first_operand_name) - s += indent + 'build(builder, state, UnrankedTensorType::get(elementType)' + build_type_name = '' + if schema.name in custom_builder_broadcast_ops_list: + second_operand_name = list(ins.items())[1][0] + s += indent + 'auto lhsTy = {}.getType().cast();\n'. \ + format(first_operand_name) + s += indent + 'auto rhsTy = {}.getType().cast();\n'. 
\ + format(second_operand_name) + s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' + s += indent + 'auto shapedType = elementType.dyn_cast_or_null();\n'; + s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n'; + s += indent + indent + 'elementType = {}'.format(first_operand_name) + \ + '.getType().cast().getElementType();\n'; + s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n' + s += indent + '}\n'; + build_type_name = 'elementType' + else: + s += indent + 'auto elementType = {}'.format(first_operand_name) + \ + '.getType().cast().getElementType();\n' + build_type_name = 'UnrankedTensorType::get(elementType)' + s += indent + 'build(builder, state, {}'.format(build_type_name) for name, _ in ins.items(): s += ', ' + name s += ');\n' @@ -750,12 +777,26 @@ def gen_op_def(schema): s += indent + '}]>,\n' # Custom builders with all operands and attributes having aggregate parameters. - # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{}]>' - s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef attributes", [{\n' + # E.g. 
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, + # ArrayRef attributes", [{}]>' + s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ' + \ + 'ValueRange operands, ArrayRef attributes", [{\n' indent = inc_indent(indent) - s += indent + 'auto elementType = operands[0].getType().cast().getElementType();\n' + if schema.name in custom_builder_broadcast_ops_list: + s += indent + 'auto lhsTy = operands[0].getType().cast();\n' + s += indent + 'auto rhsTy = operands[1].getType().cast();\n' + s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n' + s += indent + 'auto shapedType = elementType.dyn_cast_or_null();\n'; + s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n'; + s += indent + indent + 'elementType = operands[0]' + \ + '.getType().cast().getElementType();\n'; + s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n' + s += indent + '}\n'; + else: + s += indent + 'auto elementType = operands[0].getType().' + \ + 'cast().getElementType();\n' s += indent + 'std::vector outputTypes;\n' - s += indent + 'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n' + s += indent + 'outputTypes.emplace_back({});\n'.format(build_type_name) s += indent + 'build(builder, state, outputTypes, operands, attributes);\n' indent = dec_indent(indent) s += indent + '}]>'