Constprop (#162)

* initial const prop attempt

* added support for broadcast ops

* added all binary broadcast ops into custom builders with precise type

* added test example

* working

* format

* fixed suggestion by Tung, started working on unary

* added subtraction and neg the right way, and added elementwise mul too

* formatting changes

* format

* format

* added instructions to add new optimizations
Alexandre Eichenberger 2020-06-08 15:45:32 -04:00 committed by GitHub
parent dedd5f4a12
commit e2af505746
10 changed files with 1111 additions and 129 deletions
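
In short, the new pass folds ONNX operations whose inputs are all constants into a single onnx.Constant. A minimal before/after sketch, distilled from the tests added below:

  %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
  %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
  %2 = "onnx.Add"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  // after --constprop-onnx, %2 is replaced by:
  %2 = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>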

.gitignore

@@ -116,6 +116,7 @@ docs/_build/
# PyBuilder
target/
utils/ONNXOps.td.inc
# Jupyter Notebook
.ipynb_checkpoints
@@ -175,3 +176,6 @@ dmypy.json
# pytype static type analyzer
.pytype/
#editor
*~

docs/ImportONNXDefs.md

@@ -1,9 +1,9 @@
# Import ONNX specifications into ONNX-MLIR
ONNX specifications are defined under the `onnx/defs` directory in the ONNX project repository.
-There is a python script onnx/defs/gen_doc.py that automatically generate documents about operations in ONNX (docs/Operations.md).
There is a python script onnx/defs/gen_onnx_mlir.py that automatically generates documents about operations in ONNX (docs/Operations.md).
ONNX-MLIR modified this script to import ONNX specifications into ONNX-MLIR.
-There are two files generated for ONNX MLIR with the modified gen_doc.py:
There are two files generated for ONNX-MLIR with the modified gen_onnx_mlir.py:
1. `src/Dialect/ONNX/ONNXOps.td.inc`: Operation definition for MLIR TableGen. `src/Dialect/ONNX/ONNXOps.td` includes this file.
2. `src/Builder/OpBuildTable.inc`: C++ code for the ONNX-MLIR frontend to import operation nodes from an ONNX model. `src/Builder/FrontendDialectTransformer.cpp` includes this file.
@@ -22,7 +22,7 @@ Even though we strive to support the latest version of ONNX specification as qui
Due to the possibility of such a delay, operator definitions within the ONNX project repository may describe features and schemas that we do not yet support.
## Customization
-In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_doc.py, provides some mechanism for you to customize the output.
In addition to following the ONNX specification, the script gen_onnx_mlir.py provides some mechanisms for you to customize the output.
Several tables are defined at the beginning of the script:
1. `special_attr_defaults`: gives attributes special default values.
2. `special_op_handler`: creates special import functions in frontend_dialect_transformer.cpp. Currently, a special handler is used for operations with operational arguments.

src/Dialect/ONNX/ONNXOps.td.inc

@@ -93,6 +93,32 @@ def ONNXAddOp:ONNX_Op<"Add",
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -118,6 +144,32 @@ def ONNXAndOp:ONNX_Op<"And",
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -933,6 +985,32 @@ def ONNXDivOp:ONNX_Op<"Div",
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -1055,6 +1133,32 @@ def ONNXEqualOp:ONNX_Op<"Equal",
let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -1697,6 +1801,32 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -2075,6 +2205,32 @@ def ONNXLessOp:ONNX_Op<"Less",
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -2646,13 +2802,27 @@ def ONNXMulOp:ONNX_Op<"Mul",
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
-auto elementType = A.getType().cast<TensorType>().getElementType();
-build(builder, state, UnrankedTensorType::get(elementType), A, B);
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
-auto elementType = operands[0].getType().cast<TensorType>().getElementType();
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
-outputTypes.emplace_back(UnrankedTensorType::get(elementType));
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
@@ -2848,6 +3018,32 @@ def ONNXOrOp:ONNX_Op<"Or",
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -3017,6 +3213,32 @@ def ONNXPowOp:ONNX_Op<"Pow",
let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X,
AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y);
let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
auto lhsTy = X.getType().cast<RankedTensorType>();
auto rhsTy = Y.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = X.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, X, Y);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -4938,6 +5160,32 @@ def ONNXSubOp:ONNX_Op<"Sub",
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
@@ -5379,6 +5627,32 @@ def ONNXXorOp:ONNX_Op<"Xor",
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
auto lhsTy = A.getType().cast<RankedTensorType>();
auto rhsTy = B.getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = A.getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
build(builder, state, elementType, A, B);
}]>,
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
auto elementType = getBroadcastedType(lhsTy, rhsTy);
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) {
elementType = operands[0].getType().cast<TensorType>().getElementType();
elementType = UnrankedTensorType::get(elementType);
}
std::vector<mlir::Type> outputTypes;
outputTypes.emplace_back(elementType);
build(builder, state, outputTypes, operands, attributes);
}]>
];
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;

src/MainUtils.cpp

@@ -88,6 +88,7 @@ void registerDialects() {
void addONNXToMLIRPasses(mlir::PassManager &pm) {
pm.addPass(mlir::createDecomposeONNXToONNXPass());
pm.addPass(mlir::createConstPropONNXToONNXPass());
pm.addPass(mlir::createShapeInferencePass());
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(mlir::createAttributePromotionPass());

src/Pass/Passes.hpp

@@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
std::unique_ptr<Pass> createShapeInferencePass();
std::unique_ptr<Pass> createConstPropONNXToONNXPass();
/// Pass for promoting constant operands to attributes.
std::unique_ptr<Pass> createAttributePromotionPass();

src/Transform/ONNX/CMakeLists.txt

@@ -33,17 +33,23 @@ set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
add_public_tablegen_target(OMONNXDecomposeIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
add_public_tablegen_target(OMONNXConstPropIncGen)
add_library(OMONNXRewrite
ONNXRewrite.cpp
ONNXCombine.cpp
-ONNXDecompose.cpp)
ONNXDecompose.cpp
ONNXConstProp.cpp)
target_include_directories(OMONNXRewrite
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNF_MLIR_SRC_ROOT})
add_dependencies(OMONNXRewrite
OMONNXRewriteIncGen
OMONNXDecomposeIncGen
-OMONNXCombineIncGen)
OMONNXCombineIncGen
OMONNXConstPropIncGen)
# Linking dependencies:
add_dependencies(OMONNXRewrite
OMONNXOps)

src/Transform/ONNX/ONNXConstProp.cpp

@@ -0,0 +1,335 @@
//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters to constprop an ONNX operation into
// composition of other ONNX operations.
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the const-propagated operations. Hence, it is
// expected that there is no knowledge about tensor shapes at this point.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"
using namespace mlir;
namespace {
// =============================================================================
// Instructions to add a constant operation. There is currently support for
// adding constant propagation for unary and binary arithmetic ops (binary ops
// support broadcast). To add an operation, you simply have to add a templated
// method on how to compute the result in terms of one or two inputs. Values
// come in as Attributes, and the return is also an Attribute. In that
// function, presumably you will need different methods to handle int / float /
// strings... Note that these methods cannot fail. It is your responsibility to
// test for which data types are supported in the rules directly. Specific type
// restrictions can be added in the DRR files.
// The methods are:
//
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
// and they need to be templated with an ONNX Operation (presumably).
//
// Then you need to add rules on how to transform the patterns; look into
// ONNXConstProp.td for example.
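//
// For example (a sketch only; a hypothetical Div is not handled by this
// pass), a new binary operation would get a specialization mirroring the
// Add / Sub / Mul ones below:
//
//   template <>
//   Attribute ComputeConstProppElementwiseBinary<ONNXDivOp>(
//       PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
//       Attribute &secondAttr) {
//     if (elementType.isa<FloatType>()) {
//       double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
//       double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
//       return rewriter.getFloatAttr(elementType, lhsVal / rhsVal);
//     }
//     if (elementType.isa<IntegerType>()) {
//       uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
//       uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
//       return rewriter.getIntegerAttr(elementType, lhsVal / rhsVal);
//     }
//     llvm_unreachable("constant propagation for DivOp: unknown data type");
//   }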
//
// =============================================================================
// =============================================================================
// Code to perform constant propagation for binary ops in the presence of broadcast.
// =============================================================================
// Template to generate binary operation results. It takes as input
// the element type as well as the two element attributes for the
// operation, and returns the result of the operation, also as an
// attribute.
template <typename OP>
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
llvm_unreachable("unkonwn operation");
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
double res = lhsVal + rhsVal;
// printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
// Could use the APFloat interface to emulate the results; it is ok to simply
// perform them in the highest possible precision.
return rewriter.getFloatAttr(elementType, res);
}
if (elementType.isa<IntegerType>()) {
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
uint64_t res = lhsVal + rhsVal;
// printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
double res = lhsVal - rhsVal;
return rewriter.getFloatAttr(elementType, res);
}
if (elementType.isa<IntegerType>()) {
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
uint64_t res = lhsVal - rhsVal;
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for SubOp: unkonwn data type");
}
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
double res = lhsVal * rhsVal;
return rewriter.getFloatAttr(elementType, res);
}
if (elementType.isa<IntegerType>()) {
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
uint64_t res = lhsVal * rhsVal;
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for MulOp: unkonwn data type");
}
// Recursively process one dimension in the rank of the two references. There
// can be one of 3 cases.
// 1) We have fully defined accesses for both operands; launch the computation.
// 2) One of the two has a higher number of unprocessed ranks, which is the case
// when we have to broadcast the whole lower-rank reference with respect to the
// other. Iterate over each value of the higher-ranked reference, keeping the
// indices of the lower-ranked reference constant.
// 3) Both references have the same rank; we still do broadcast if one of the
// dimension sizes is equal to 1.
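// For example, pairing a lhs of type tensor<3x1xi32> with a rhs of type
// tensor<3x2xi32> falls into case 3 for both dimensions: in dimension 1,
// lhsSize == 1 while rhsSize == 2, so lhsIndices[1] stays at 0 while
// rhsIndices[1] iterates over 0 and 1, pairing lhs[i][0] with both rhs[i][0]
// and rhs[i][1].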
template <typename ElementwiseBinaryOp>
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
// printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
if (lhsFreeRank == 0) {
// Fully defined ranks.
assert(
rhsFreeRank == 0 && "expect both to recurse to zero at the same time");
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
auto elementaryType = lhsAttr.getType().getElementType();
auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
resVector.emplace_back(res);
} else if (lhsFreeRank > rhsFreeRank) {
// Initial broadcast from lhs.
auto lhsShape = lhsAttr.getType().getShape();
int lhsRank = lhsShape.size();
int lhsIndex = lhsRank - lhsFreeRank;
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
for (int i = 0; i < lhsSize; ++i) {
lhsIndices[lhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
rhsFreeRank);
}
} else if (lhsFreeRank < rhsFreeRank) {
// Initial broadcast from rhs.
auto rhsShape = rhsAttr.getType().getShape();
int rhsRank = rhsShape.size();
int rhsIndex = rhsRank - rhsFreeRank;
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
for (int i = 0; i < rhsSize; ++i) {
rhsIndices[rhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
rhsFreeRank - 1);
}
} else {
// No initial broadcast, but if one element has size 1 and the other is
// greater than one, then we also have broadcast.
auto lhsShape = lhsAttr.getType().getShape();
int lhsRank = lhsShape.size();
int lhsIndex = lhsRank - lhsFreeRank;
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
auto rhsShape = rhsAttr.getType().getShape();
int rhsRank = rhsShape.size();
int rhsIndex = rhsRank - rhsFreeRank;
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
assert((lhsSize == 1 || rhsSize == 1 || lhsSize == rhsSize) &&
"incompatible sizes");
int size = std::max(lhsSize, rhsSize);
lhsIndices[lhsIndex] = rhsIndices[rhsIndex] = 0;
for (int i = 0; i < size; ++i) {
if (lhsSize > 1)
lhsIndices[lhsIndex] = i;
if (rhsSize > 1)
rhsIndices[rhsIndex] = i;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
rhsFreeRank - 1);
}
}
}
// Process the constant operands, perform the operation with broadcast, and
// generate the new constant operation.
template <typename ElementwiseBinaryOp>
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
DenseElementsAttr lhsDenseAttr =
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
DenseElementsAttr rhsDenseAttr =
rhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
assert((lhsDenseAttr && rhsDenseAttr) && "expected dense attributes");
assert(
resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
auto lhsRank = lhsDenseAttr.getType().getShape().size();
auto rhsRank = rhsDenseAttr.getType().getShape().size();
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
std::vector<Attribute> resVector;
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
ArrayRef<Attribute> resRef(resVector);
return DenseElementsAttr::get(resType, resRef);
}
// =============================================================================
// Code to perform constant propagation for unary operation.
// =============================================================================
template <typename OP>
Attribute ComputeConstProppElementwiseUnary(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
llvm_unreachable("unkonwn operation");
}
template <>
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = -val;
return rewriter.getFloatAttr(elementType, res);
}
if (elementType.isa<IntegerType>()) {
uint64_t val = attr.cast<IntegerAttr>().getInt();
uint64_t res = -val;
return rewriter.getIntegerAttr(elementType, res);
}
llvm_unreachable("constant propagation for NegOp: unkonwn data type");
}
template <typename ElementwiseUnaryOp>
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
SmallVector<uint64_t, 4> &indices, int freeRank) {
// printf("recurse with free %d\n", freeRank);
if (freeRank == 0) {
// Fully defined ranks.
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
auto elementaryType = attr.getType().getElementType();
auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
rewriter, elementaryType, elementAttr);
resVector.emplace_back(res);
} else {
// Recurse.
auto shape = attr.getType().getShape();
int rank = shape.size();
int index = rank - freeRank;
int size = attr.getType().getShape()[index];
for (int i = 0; i < size; ++i) {
indices[index] = i;
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
rewriter, resVector, attr, indices, freeRank - 1);
}
}
}
// Process the constant operands, perform the operation with broadcast, and
// generate the new constant operation.
template <typename ElementwiseUnaryOp>
DenseElementsAttr ConstPropElementwiseUnary(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
assert(denseAttr && "expected dense attribute");
assert(
resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
auto rank = denseAttr.getType().getShape().size();
SmallVector<uint64_t, 4> indices(rank, 0);
std::vector<Attribute> resVector;
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
rewriter, resVector, denseAttr, indices, rank);
ArrayRef<Attribute> resRef(resVector);
return DenseElementsAttr::get(resType, resRef);
}
// =============================================================================
// Pattern definition.
// =============================================================================
#include "src/Transform/ONNX/ONNXConstProp.inc"
// =============================================================================
// Code to manage the pass.
// =============================================================================
struct ConstPropONNXToONNXPass
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
void runOnFunction() final;
};
} // end anonymous namespace.
void ConstPropONNXToONNXPass::runOnFunction() {
auto function = getFunction();
MLIRContext *context = &getContext();
ConversionTarget target(getContext());
target.addLegalDialect<ONNXOpsDialect>();
OwningRewritePatternList patterns;
populateWithGenerated(context, &patterns);
applyPatternsAndFoldGreedily(function, patterns);
} // end runOnFunction
/*!
* Create a ConstPropONNX pass.
*/
std::unique_ptr<mlir::Pass> mlir::createConstPropONNXToONNXPass() {
return std::make_unique<ConstPropONNXToONNXPass>();
}
static PassRegistration<ConstPropONNXToONNXPass> pass("constprop-onnx",
"ConstProp ONNX operations into composition of other ONNX operations.");

src/Transform/ONNX/ONNXConstProp.td

@@ -0,0 +1,161 @@
//===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match rewritings for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//
#ifndef ONNX_CONSTPROP
#define ONNX_CONSTPROP
#ifndef OP_BASE
include "src/Dialect/ONNX/ONNXOps.td"
#endif // OP_BASE
// =============================================================================
// Instructions to add new constant operation rules. Minimally, you will have
// added the operation in ONNXConstProp.cpp to perform the element-wise single
// value handling of the new operator that you are dealing with. You will need
// to generate a call to the method that handles the tensor constant prop. Here
// are the calls for a unary and a binary operation. Adapt them to your new
// operator:
//
// def CreateAddOfTwoConst :
//    NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
//
// def CreateNegOfConst :
//    NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
//
// where you will mostly have to substitute your new operator as well as use
// a new def name.
//
// Then you will need to add substitution rules; see the examples below.
// =============================================================================
// Useful test definitions.
def IsNotAConstant :
Constraint<CPred<"! dyn_cast_or_null<ONNXConstantOp>(($_self).getDefiningOp())">,
"operation is not a constant">;
def AttributeIsNull :
Constraint<CPred<"! ($_self)">,
"Attribute is null">;
// Useful code generation invocations.
def GetNullAttr : NativeCodeCall<"Attribute()">;
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
def CreateSubOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXSubOp>($_builder, $0, $1, $2)">;
def CreateNegOfConst :
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
// =============================================================================
// Patterns to enable opportunities with elementwise ADD operations.
// Use commutativity to normalize constants in the second position of Add.
def AddConstCommutative1 : Pat<
// From add(c, x).
(ONNXAddOp (ONNXConstantOp:$c $_, $_), $x),
// To add(x, c).
(ONNXAddOp $x, $c),
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
[(IsNotAConstant:$x)]>;
// Use associativity to add constants together.
def AddConstAssociative1 : Pat<
// From add(add(x, c1), c2).
(ONNXAddOp
(ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_)),
(ONNXConstantOp:$c2 $_, $_)),
// To add(x, add(c1, c2)).
(ONNXAddOp
$x,
(ONNXAddOp $c1, $c2))>;
// Constant Propagation for Add
def AddConstProp : Pat<
// From add(c1, c2).
(ONNXAddOp:$addOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
// To c1+c2
(ONNXConstantOp (GetNullAttr), (CreateAddOfTwoConst $addOp, $v1, $v2)),
// Additional constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
// =============================================================================
// Patterns to enable opportunities with elementwise SUB / NEG operations.
// Constant Propagation for Sub
def SubConstProp : Pat<
// From sub(c1, c2).
(ONNXSubOp:$subOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
// To c1-c2
(ONNXConstantOp (GetNullAttr), (CreateSubOfTwoConst $subOp, $v1, $v2)),
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
// Neg of a constant is simply -const
def NegofConst : Pat<
// From - (c)
(ONNXNegOp (ONNXConstantOp:$constOp $s, $v)),
// To (-c)
(ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v)),
[(AttributeIsNull:$s)]>;
// Change a subtraction of a constant c into an addition of -c. Helpful to combine
// with other add optimizations.
def SubConstToNeg : Pat<
// From x - c.
(ONNXSubOp:$subOp $x, (ONNXConstantOp:$constOp $s, $v)),
// To x + (-c).
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
// =============================================================================
// Patterns to enable opportunities with elementwise MUL operations.
// Exactly the same pattern as for the elementwise ADD operations.
// Use commutativity to normalize constants in the second position of Mul.
def MulConstCommutative1 : Pat<
// From mul(c, x).
(ONNXMulOp (ONNXConstantOp:$c $_, $_), $x),
// To mul(x, c).
(ONNXMulOp $x, $c),
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
[(IsNotAConstant:$x)]>;
// Use associativity to mul constants together.
def MulConstAssociative1 : Pat<
// From mul(mul(x, c1), c2).
(ONNXMulOp
(ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_)),
(ONNXConstantOp:$c2 $_, $_)),
// To mul(x, mul(c1, c2)).
(ONNXMulOp
$x,
(ONNXMulOp $c1, $c2))>;
// Constant Propagation for Mul
def MulConstProp : Pat<
// From mul(c1, c2).
(ONNXMulOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
// To c1*c2
(ONNXConstantOp (GetNullAttr), (CreateMulOfTwoConst $mulOp, $v1, $v2)),
// Additional constraints (no sparse)
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
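// Following the instructions at the top of this file, a constant-propagation
// rule for a hypothetical new operator (say Div, assuming a matching
// CreateDivOfTwoConst NativeCodeCall and C++ compute method) would read:
//
// def DivConstProp : Pat<
//     (ONNXDivOp:$divOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
//     (ONNXConstantOp (GetNullAttr), (CreateDivOfTwoConst $divOp, $v1, $v2)),
//     [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;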
#endif // ONNX_CONSTPROP

test/mlir/onnx/onnx_constprop.mlir

@@ -0,0 +1,158 @@
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
// =============================================================================
/// ADD tests.
/// Test ConstantOp commutativity for add
// CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32>
func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
%1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
"std.return"(%1) : (tensor<3xf32>) -> ()
// CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
}
/// Test ConstantOp commutativity for add (constant already in second position)
// CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32>
func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
%1 = "onnx.Add"(%arg0, %0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
"std.return"(%1) : (tensor<3xf32>) -> ()
// CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
}
/// Change (x+c1)+c2 to x+(c1+c2)
// CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%3) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
}
/// Same test as above, but with a use of an intermediary result
/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1)
// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%4) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
// CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
}
/// Test broadcast 1d -> 2d
// CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
"std.return"(%3) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
/// Test broadcast 2d (size one) -> 2d
// CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "onnx.Add"(%0, %1) : (tensor<1x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
"std.return"(%3) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
/// Test broadcast 2d (size-one dimension) -> 2d
// CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32>
%1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%2 = "onnx.Add"(%0, %1) : (tensor<3x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
"std.return"(%3) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[11, 12], [23, 24], [34, 35]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
// =============================================================================
/// MUL tests (same as add, so we have only one).
/// Change (x*c1)*c2 to x*(c1*c2)
// CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
%2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
%3 = "onnx.Mul"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
"std.return"(%3) : (tensor<3xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32>
// CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
}
// =============================================================================
/// SUB and NEG tests.
/// Test sub of two constants
// CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = "onnx.Constant"() {value = dense<[[2]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "onnx.Sub"(%0, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32>
"std.return"(%2) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
}
/// Test rewriting sub of a constant into add of the negated constant
// CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
"std.return"(%1) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[-2, -3], [-4, -5], [-6, -7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
// CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
%5 = "onnx.Add"(%2, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32>
"std.return"(%5) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
// CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
%1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "onnx.Neg"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
%3 = "onnx.Add"(%arg0, %2) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
%4 = "onnx.Add"(%3, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32>
"std.return"(%4) : (tensor<3x2xi32>) -> ()
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}

utils/gen_onnx_mlir.py

@@ -295,8 +295,15 @@ OpsWithResultTypeInference = {
# Currently, there are only two build methods generated:
# - one with operands and attributes having a separate parameter, and
# - one with operands and attributes having aggregated parameters.
-custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
# Custom builder op list for operations with broadcast; we can deduce the right
# output type, no need to leave it undef as in the above list.
# Ops must have two operands, not one, not three... And there shall be two.
# TODO: handle variadic ops omitted here: Max, Min, Sum.
custom_builder_broadcast_ops_list = ['Add', 'And', 'Div', 'Equal', 'Greater',
'Less', 'Mul', 'Or', 'Pow', 'Sub', 'Xor']
# union of both
custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broadcast_ops_list
#a dictionary to add any special definition for an operation
custom_definition_misc = dict([ ('Constant',
@@ -716,6 +723,8 @@ def gen_op_def(schema):
s += indent + 'let results = (outs {});\n'.format(
(',\n' + inc_indent(indent)).join(outs_strs))
# custom_builder_broadcast_ops_list
# add custom builders
# use element type of the first operand to construct an UnrankedTensorType for the output.
if schema.name in custom_builder_ops_list:
@@ -726,7 +735,8 @@
else:
s += indent + 'let builders = [\n'
# Custom builders with operands and attributes having a separate parameter.
-# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X,
# Value, Y, Attribute A", [{}]>
indent = inc_indent(indent)
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
operands_dict = get_operands_or_results(schema, is_input=True)
@@ -740,9 +750,26 @@
# Get output type from first operand's type.
first_operand_name = list(ins.items())[0][0]
-s += indent + 'auto elementType = {}.getType().cast<TensorType>().getElementType();\n'.format(
-first_operand_name)
-s += indent + 'build(builder, state, UnrankedTensorType::get(elementType)'
build_type_name = ''
if schema.name in custom_builder_broadcast_ops_list:
second_operand_name = list(ins.items())[1][0]
s += indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'. \
format(first_operand_name)
s += indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'. \
format(second_operand_name)
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
s += indent + indent + 'elementType = {}'.format(first_operand_name) + \
'.getType().cast<TensorType>().getElementType();\n';
s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
s += indent + '}\n';
build_type_name = 'elementType'
else:
s += indent + 'auto elementType = {}'.format(first_operand_name) + \
'.getType().cast<TensorType>().getElementType();\n'
build_type_name = 'UnrankedTensorType::get(elementType)'
s += indent + 'build(builder, state, {}'.format(build_type_name)
for name, _ in ins.items():
s += ', ' + name
s += ');\n'
@@ -750,12 +777,26 @@
s += indent + '}]>,\n'
# Custom builders with all operands and attributes having aggregate parameters.
-# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>'
-s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands,
# ArrayRef<NamedAttribute> attributes", [{}]>'
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ' + \
'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
indent = inc_indent(indent)
-s += indent + 'auto elementType = operands[0].getType().cast<TensorType>().getElementType();\n'
if schema.name in custom_builder_broadcast_ops_list:
s += indent + 'auto lhsTy = operands[0].getType().cast<RankedTensorType>();\n'
s += indent + 'auto rhsTy = operands[1].getType().cast<RankedTensorType>();\n'
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
s += indent + indent + 'elementType = operands[0]' + \
'.getType().cast<TensorType>().getElementType();\n';
s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
s += indent + '}\n';
else:
s += indent + 'auto elementType = operands[0].getType().' + \
'cast<TensorType>().getElementType();\n'
s += indent + 'std::vector<mlir::Type> outputTypes;\n'
-s += indent + 'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
s += indent + 'outputTypes.emplace_back({});\n'.format(build_type_name)
s += indent + 'build(builder, state, outputTypes, operands, attributes);\n'
indent = dec_indent(indent)
s += indent + '}]>'