Constprop (#162)
* initial const prop attempt * added support for broadcast ops * adde all binary broadcast ops into custom builders with precise type * added test example * working * format * fixed suggestion by Tung, start woring on unary * added subtraction and neg the right way, and added elementwise mul too * formatting changes * format * format * added instructions to add new optimizations
This commit is contained in:
parent
dedd5f4a12
commit
e2af505746
|
@ -116,6 +116,7 @@ docs/_build/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
target/
|
target/
|
||||||
|
utils/ONNXOps.td.inc
|
||||||
|
|
||||||
# Jupyter Notebook
|
# Jupyter Notebook
|
||||||
.ipynb_checkpoints
|
.ipynb_checkpoints
|
||||||
|
@ -175,3 +176,6 @@ dmypy.json
|
||||||
|
|
||||||
# pytype static type analyzer
|
# pytype static type analyzer
|
||||||
.pytype/
|
.pytype/
|
||||||
|
|
||||||
|
#editor
|
||||||
|
*~
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
# Import ONNX specifications into ONNX-MLIR
|
# Import ONNX specifications into ONNX-MLIR
|
||||||
|
|
||||||
ONNX specifications are defined under `onnx/defs` directory in the ONNX project repository.
|
ONNX specifications are defined under `onnx/defs` directory in the ONNX project repository.
|
||||||
There is a python script onnx/defs/gen_doc.py that automatically generate documents about operations in ONNX (docs/Operations.md).
|
There is a python script onnx/defs/gen_onnx_mlir.py that automatically generate documents about operations in ONNX (docs/Operations.md).
|
||||||
ONNX-MLIR modified this script to import ONNX specifications into ONNX-MLIR.
|
ONNX-MLIR modified this script to import ONNX specifications into ONNX-MLIR.
|
||||||
There are two files generated for ONNX MLIR with the modified gen_doc.py:
|
There are two files generated for ONNX MLIR with the modified gen_onnx_mlir.py:
|
||||||
|
|
||||||
1. `src/Dialect/ONNX/ONNXOps.td.inc`: Operation definition for MLIR TableGen. `src/Dialect/ONNX/ONNXOps.td` includes this file.
|
1. `src/Dialect/ONNX/ONNXOps.td.inc`: Operation definition for MLIR TableGen. `src/Dialect/ONNX/ONNXOps.td` includes this file.
|
||||||
2. `src/Builder/OpBuildTable.inc`: C++ code for ONNX-MLIR frontend to import operation nodes from ONNX model. `src/Builder/FrontendDialectTransformer.cpp` includes this file.
|
2. `src/Builder/OpBuildTable.inc`: C++ code for ONNX-MLIR frontend to import operation nodes from ONNX model. `src/Builder/FrontendDialectTransformer.cpp` includes this file.
|
||||||
|
@ -22,7 +22,7 @@ Even though we strive to support the latest version of ONNX specification as qui
|
||||||
Due to the possibility of such a delay, operator definition within the ONNX project repository may describe features and schemas that we do not yet support.
|
Due to the possibility of such a delay, operator definition within the ONNX project repository may describe features and schemas that we do not yet support.
|
||||||
|
|
||||||
## Customization
|
## Customization
|
||||||
In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_doc.py, provides some mechanism for you to customize the output.
|
In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_onnx_mlir.py, provides some mechanism for you to customize the output.
|
||||||
Several tables are defined at the beginning of the script:
|
Several tables are defined at the beginning of the script:
|
||||||
1. `special_attr_defaults`: gives attribute special default value.
|
1. `special_attr_defaults`: gives attribute special default value.
|
||||||
2. `special_op_handler`: creates special import function in frontend_dialect_transformer.cpp. Currently, a special handler is used for operations with operational arguments
|
2. `special_op_handler`: creates special import function in frontend_dialect_transformer.cpp. Currently, a special handler is used for operations with operational arguments
|
||||||
|
|
|
@ -93,6 +93,32 @@ def ONNXAddOp:ONNX_Op<"Add",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -118,6 +144,32 @@ def ONNXAndOp:ONNX_Op<"And",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -933,6 +985,32 @@ def ONNXDivOp:ONNX_Op<"Div",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -1055,6 +1133,32 @@ def ONNXEqualOp:ONNX_Op<"Equal",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -1697,6 +1801,32 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -2075,6 +2205,32 @@ def ONNXLessOp:ONNX_Op<"Less",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -2646,13 +2802,27 @@ def ONNXMulOp:ONNX_Op<"Mul",
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
auto elementType = A.getType().cast<TensorType>().getElementType();
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
build(builder, state, UnrankedTensorType::get(elementType), A, B);
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
}]>,
|
}]>,
|
||||||
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
auto elementType = operands[0].getType().cast<TensorType>().getElementType();
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
std::vector<mlir::Type> outputTypes;
|
std::vector<mlir::Type> outputTypes;
|
||||||
outputTypes.emplace_back(UnrankedTensorType::get(elementType));
|
outputTypes.emplace_back(elementType);
|
||||||
build(builder, state, outputTypes, operands, attributes);
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
}]>
|
}]>
|
||||||
];
|
];
|
||||||
|
@ -2848,6 +3018,32 @@ def ONNXOrOp:ONNX_Op<"Or",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -3017,6 +3213,32 @@ def ONNXPowOp:ONNX_Op<"Pow",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X,
|
let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X,
|
||||||
AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y);
|
AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z);
|
let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
|
||||||
|
auto lhsTy = X.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = Y.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = X.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, X, Y);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -4938,6 +5160,32 @@ def ONNXSubOp:ONNX_Op<"Sub",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
@ -5379,6 +5627,32 @@ def ONNXXorOp:ONNX_Op<"Xor",
|
||||||
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
|
||||||
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
|
||||||
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
|
||||||
|
auto lhsTy = A.getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = B.getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = A.getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
build(builder, state, elementType, A, B);
|
||||||
|
}]>,
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
|
||||||
|
auto lhsTy = operands[0].getType().cast<RankedTensorType>();
|
||||||
|
auto rhsTy = operands[1].getType().cast<RankedTensorType>();
|
||||||
|
auto elementType = getBroadcastedType(lhsTy, rhsTy);
|
||||||
|
auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
|
||||||
|
if (!shapedType || !shapedType.hasStaticShape()) {
|
||||||
|
elementType = operands[0].getType().cast<TensorType>().getElementType();
|
||||||
|
elementType = UnrankedTensorType::get(elementType);
|
||||||
|
}
|
||||||
|
std::vector<mlir::Type> outputTypes;
|
||||||
|
outputTypes.emplace_back(elementType);
|
||||||
|
build(builder, state, outputTypes, operands, attributes);
|
||||||
|
}]>
|
||||||
|
];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
static int getNumberOfOperands() {
|
static int getNumberOfOperands() {
|
||||||
return 2;
|
return 2;
|
||||||
|
|
|
@ -88,6 +88,7 @@ void registerDialects() {
|
||||||
|
|
||||||
void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
void addONNXToMLIRPasses(mlir::PassManager &pm) {
|
||||||
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
pm.addPass(mlir::createDecomposeONNXToONNXPass());
|
||||||
|
pm.addPass(mlir::createConstPropONNXToONNXPass());
|
||||||
pm.addPass(mlir::createShapeInferencePass());
|
pm.addPass(mlir::createShapeInferencePass());
|
||||||
pm.addPass(mlir::createCanonicalizerPass());
|
pm.addPass(mlir::createCanonicalizerPass());
|
||||||
pm.addPass(mlir::createAttributePromotionPass());
|
pm.addPass(mlir::createAttributePromotionPass());
|
||||||
|
|
|
@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
|
||||||
|
|
||||||
std::unique_ptr<Pass> createShapeInferencePass();
|
std::unique_ptr<Pass> createShapeInferencePass();
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createConstPropONNXToONNXPass();
|
||||||
|
|
||||||
/// Pass for promoting constant operands to attributes.
|
/// Pass for promoting constant operands to attributes.
|
||||||
std::unique_ptr<Pass> createAttributePromotionPass();
|
std::unique_ptr<Pass> createAttributePromotionPass();
|
||||||
|
|
||||||
|
|
|
@ -33,17 +33,23 @@ set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
|
||||||
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXDecomposeIncGen)
|
add_public_tablegen_target(OMONNXDecomposeIncGen)
|
||||||
|
|
||||||
|
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
|
||||||
|
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
|
||||||
|
add_public_tablegen_target(OMONNXConstPropIncGen)
|
||||||
|
|
||||||
add_library(OMONNXRewrite
|
add_library(OMONNXRewrite
|
||||||
ONNXRewrite.cpp
|
ONNXRewrite.cpp
|
||||||
ONNXCombine.cpp
|
ONNXCombine.cpp
|
||||||
ONNXDecompose.cpp)
|
ONNXDecompose.cpp
|
||||||
|
ONNXConstProp.cpp)
|
||||||
target_include_directories(OMONNXRewrite
|
target_include_directories(OMONNXRewrite
|
||||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNF_MLIR_SRC_ROOT})
|
${ONNF_MLIR_SRC_ROOT})
|
||||||
add_dependencies(OMONNXRewrite
|
add_dependencies(OMONNXRewrite
|
||||||
OMONNXRewriteIncGen
|
OMONNXRewriteIncGen
|
||||||
OMONNXDecomposeIncGen
|
OMONNXDecomposeIncGen
|
||||||
OMONNXCombineIncGen)
|
OMONNXCombineIncGen
|
||||||
|
OMONNXConstPropIncGen)
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMONNXRewrite
|
add_dependencies(OMONNXRewrite
|
||||||
OMONNXOps)
|
OMONNXOps)
|
||||||
|
|
|
@ -0,0 +1,335 @@
|
||||||
|
//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This file implements a set of rewriters to constprop an ONNX operation into
|
||||||
|
// composition of other ONNX operations.
|
||||||
|
//
|
||||||
|
// This pass is applied before any other pass so that there is no need to
|
||||||
|
// implement shape inference for the constpropd operation. Hence, it is expected
|
||||||
|
// that there is no knowledge about tensor shape at this point
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Pass/Pass.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
#include "src/Pass/Passes.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Instructions to add a constant operation. There is currently support for
|
||||||
|
// adding constant propagation for unary and binary athythmetic ops (binary ops
|
||||||
|
// support broadcast). To add an operation, you simply have to add a templated
|
||||||
|
// method on how to compute the result in terms of one or two inputs. Values
|
||||||
|
// comes as Attribtues, and return is also an Attribute. In that function,
|
||||||
|
// presumably you will need different methods to handle int / float /
|
||||||
|
// strings... Note that these methods cannot fail. It is your responsablitity to
|
||||||
|
// tests for which data type are supported in the rules directly. Specific type
|
||||||
|
// restrictions can be added in the DRR files.
|
||||||
|
|
||||||
|
// The methods are:
|
||||||
|
//
|
||||||
|
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary
|
||||||
|
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
||||||
|
//
|
||||||
|
// Then you need to add rules on how to transform the patterns; look into
|
||||||
|
// ONNXConstProp.td for example.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Code to perform constant propagation for binary in presence of broadcast.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// Template to generate binary operation results. It takes as inupt
|
||||||
|
// the element type as well as the two element attributes for the
|
||||||
|
// operation, and return the result of the operation, also as an
|
||||||
|
// attribute.
|
||||||
|
|
||||||
|
template <typename OP>
|
||||||
|
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
|
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
|
||||||
|
llvm_unreachable("unkonwn operation");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
|
Attribute &secondAttr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double res = lhsVal + rhsVal;
|
||||||
|
// printf(" %f + %f -> %f\n", lhsVal, rhsVal, res);
|
||||||
|
// Could use the APFloat interface to emulate the results, are ok to simply
|
||||||
|
// perform them in the highest possible precision.
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
if (elementType.isa<IntegerType>()) {
|
||||||
|
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t res = lhsVal + rhsVal;
|
||||||
|
// printf(" %llu + %llu -> %llu\n", lhsVal, rhsVal, res);
|
||||||
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for AddOp: unkonwn data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
|
Attribute &secondAttr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double res = lhsVal - rhsVal;
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
if (elementType.isa<IntegerType>()) {
|
||||||
|
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t res = lhsVal - rhsVal;
|
||||||
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for SubOp: unkonwn data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
|
||||||
|
Attribute &secondAttr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double res = lhsVal * rhsVal;
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
if (elementType.isa<IntegerType>()) {
|
||||||
|
uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t res = lhsVal * rhsVal;
|
||||||
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for MulOp: unkonwn data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process one dimension in the rank of the two references. There
|
||||||
|
// can be one of 3 cases.
|
||||||
|
// 1) We have fully defined accesses for both operands, launch the computations.
|
||||||
|
// 2) One of the two has a higher number of unprocessed ranks, which is hte case
|
||||||
|
// when we have to broadcast the whole lower-dim reference with respect to the
|
||||||
|
// other. Iterate over each value of the higher ranked reference, keeping the
|
||||||
|
// reference of the lower ranked reference constant.
|
||||||
|
// 3) Both references have the same rank, we still do broadcast if one of the
|
||||||
|
// dimension size is equal to 1.
|
||||||
|
|
||||||
|
template <typename ElementwiseBinaryOp>
|
||||||
|
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
|
||||||
|
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
|
||||||
|
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
|
||||||
|
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
|
||||||
|
// printf("recurse with free %d/%d\n", lhsFreeRank, rhsFreeRank);
|
||||||
|
if (lhsFreeRank == 0) {
|
||||||
|
// Fully defined ranks.
|
||||||
|
assert(
|
||||||
|
rhsFreeRank == 0 && "expect both to recurse to zero at the same time");
|
||||||
|
auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
|
||||||
|
auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
|
||||||
|
auto elementaryType = lhsAttr.getType().getElementType();
|
||||||
|
auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
|
||||||
|
rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
|
||||||
|
resVector.emplace_back(res);
|
||||||
|
} else if (lhsFreeRank > rhsFreeRank) {
|
||||||
|
// Initial broadcast from lhs.
|
||||||
|
auto lhsShape = lhsAttr.getType().getShape();
|
||||||
|
int lhsRank = lhsShape.size();
|
||||||
|
int lhsIndex = lhsRank - lhsFreeRank;
|
||||||
|
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
|
||||||
|
for (int i = 0; i < lhsSize; ++i) {
|
||||||
|
lhsIndices[lhsIndex] = i;
|
||||||
|
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||||
|
rhsFreeRank);
|
||||||
|
}
|
||||||
|
} else if (lhsFreeRank < rhsFreeRank) {
|
||||||
|
// Initial broadcast from rhs.
|
||||||
|
auto rhsShape = rhsAttr.getType().getShape();
|
||||||
|
int rhsRank = rhsShape.size();
|
||||||
|
int rhsIndex = rhsRank - rhsFreeRank;
|
||||||
|
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
|
||||||
|
for (int i = 0; i < rhsSize; ++i) {
|
||||||
|
rhsIndices[rhsIndex] = i;
|
||||||
|
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
|
||||||
|
rhsFreeRank - 1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No initial broadcast, but if one element has size 1 and the other is
|
||||||
|
// greater than one, then we also have broadcast.
|
||||||
|
auto lhsShape = lhsAttr.getType().getShape();
|
||||||
|
int lhsRank = lhsShape.size();
|
||||||
|
int lhsIndex = lhsRank - lhsFreeRank;
|
||||||
|
int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
|
||||||
|
auto rhsShape = rhsAttr.getType().getShape();
|
||||||
|
int rhsRank = rhsShape.size();
|
||||||
|
int rhsIndex = rhsRank - rhsFreeRank;
|
||||||
|
int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
|
||||||
|
assert((lhsSize == 1 || rhsSize == 1 || lhsSize == rhsSize) &&
|
||||||
|
"incompatible sizes");
|
||||||
|
int size = std::max(lhsSize, rhsSize);
|
||||||
|
lhsIndices[lhsIndex] = rhsIndices[rhsIndex] = 0;
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
if (lhsSize > 1)
|
||||||
|
lhsIndices[lhsIndex] = i;
|
||||||
|
if (rhsSize > 1)
|
||||||
|
rhsIndices[rhsIndex] = i;
|
||||||
|
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
|
||||||
|
resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
|
||||||
|
rhsFreeRank - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the constant operands, perform the operation with broadcast, and
|
||||||
|
// generate the new constant operation.
|
||||||
|
template <typename ElementwiseBinaryOp>
|
||||||
|
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
|
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
|
||||||
|
DenseElementsAttr lhsDenseAttr =
|
||||||
|
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
DenseElementsAttr rhsDenseAttr =
|
||||||
|
rhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
assert((lhsDenseAttr && lhsDenseAttr) && "expected dense attributes");
|
||||||
|
assert(
|
||||||
|
resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
|
||||||
|
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
|
||||||
|
auto lhsRank = lhsDenseAttr.getType().getShape().size();
|
||||||
|
auto rhsRank = rhsDenseAttr.getType().getShape().size();
|
||||||
|
SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
|
||||||
|
SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
|
||||||
|
std::vector<Attribute> resVector;
|
||||||
|
RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
|
||||||
|
lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
|
||||||
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Code to perform constant propagation for unary operation.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
template <typename OP>
|
||||||
|
Attribute ComputeConstProppElementwiseUnary(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||||
|
llvm_unreachable("unkonwn operation");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
|
||||||
|
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
|
||||||
|
if (elementType.isa<FloatType>()) {
|
||||||
|
double val = attr.cast<FloatAttr>().getValueAsDouble();
|
||||||
|
double res = -val;
|
||||||
|
return rewriter.getFloatAttr(elementType, res);
|
||||||
|
}
|
||||||
|
if (elementType.isa<IntegerType>()) {
|
||||||
|
uint64_t val = attr.cast<IntegerAttr>().getInt();
|
||||||
|
uint64_t res = -val;
|
||||||
|
return rewriter.getIntegerAttr(elementType, res);
|
||||||
|
}
|
||||||
|
llvm_unreachable("constant propagation for NegOp: unkonwn data type");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename ElementwiseUnaryOp>
|
||||||
|
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
|
||||||
|
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
|
||||||
|
SmallVector<uint64_t, 4> &indices, int freeRank) {
|
||||||
|
// printf("recurse with free %d\n", freeRank);
|
||||||
|
if (freeRank == 0) {
|
||||||
|
// Fully defined ranks.
|
||||||
|
auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
|
||||||
|
auto elementaryType = attr.getType().getElementType();
|
||||||
|
auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
|
rewriter, elementaryType, elementAttr);
|
||||||
|
resVector.emplace_back(res);
|
||||||
|
} else {
|
||||||
|
// Recurse.
|
||||||
|
auto shape = attr.getType().getShape();
|
||||||
|
int rank = shape.size();
|
||||||
|
int index = rank - freeRank;
|
||||||
|
int size = attr.getType().getShape()[index];
|
||||||
|
for (int i = 0; i < size; ++i) {
|
||||||
|
indices[index] = i;
|
||||||
|
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
|
rewriter, resVector, attr, indices, freeRank - 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the constant operands, perform the operation with broadcast, and
|
||||||
|
// generate the new constant operation.
|
||||||
|
template <typename ElementwiseUnaryOp>
|
||||||
|
DenseElementsAttr ConstPropElementwiseUnary(
|
||||||
|
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
|
||||||
|
DenseElementsAttr denseAttr =
|
||||||
|
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
|
||||||
|
assert(denseAttr && "expected dense attribute");
|
||||||
|
assert(
|
||||||
|
resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
|
||||||
|
ShapedType resType = resOperand.getType().cast<RankedTensorType>();
|
||||||
|
auto rank = denseAttr.getType().getShape().size();
|
||||||
|
SmallVector<uint64_t, 4> indices(rank, 0);
|
||||||
|
std::vector<Attribute> resVector;
|
||||||
|
RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
|
||||||
|
rewriter, resVector, denseAttr, indices, rank);
|
||||||
|
ArrayRef<Attribute> resRef(resVector);
|
||||||
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Pattern definition.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "src/Transform/ONNX/ONNXConstProp.inc"
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Code to manage the pass.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
struct ConstPropONNXToONNXPass
|
||||||
|
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
|
||||||
|
void runOnFunction() final;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace.
|
||||||
|
|
||||||
|
void ConstPropONNXToONNXPass::runOnFunction() {
|
||||||
|
auto function = getFunction();
|
||||||
|
MLIRContext *context = &getContext();
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<ONNXOpsDialect>();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
populateWithGenerated(context, &patterns);
|
||||||
|
|
||||||
|
applyPatternsAndFoldGreedily(function, patterns);
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* Create a ConstPropONNX pass.
|
||||||
|
*/
|
||||||
|
std::unique_ptr<mlir::Pass> mlir::createConstPropONNXToONNXPass() {
|
||||||
|
return std::make_unique<ConstPropONNXToONNXPass>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<ConstPropONNXToONNXPass> pass("constprop-onnx",
|
||||||
|
"ConstProp ONNX operations into composition of other ONNX operations.");
|
|
@ -0,0 +1,161 @@
|
||||||
|
//===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===//
|
||||||
|
//
|
||||||
|
// Copyright 2019-2020 The IBM Research Authors.
|
||||||
|
//
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// Defines language-specific pattern match rewritings for ONNX using
|
||||||
|
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef ONNX_CONSTPROP
|
||||||
|
#define ONNX_CONSTPROP
|
||||||
|
|
||||||
|
#ifndef OP_BASE
|
||||||
|
include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Instruction to add new constant operation rules. Minimally, you will have added
|
||||||
|
// operation in the ONNXConstProp.cpp to perform the element-wise single value
|
||||||
|
// handling of the new operator that you are dealing with. You will need to
|
||||||
|
// generate a call to the method that handle the tensor constant prop. Here
|
||||||
|
// is the call for a unary and binary operation. Adapt to your new operator:
|
||||||
|
//
|
||||||
|
// def CreateAddOfTwoConst :
|
||||||
|
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
|
||||||
|
//
|
||||||
|
// def CreateNegOfConst :
|
||||||
|
// NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
|
||||||
|
//
|
||||||
|
// where you will have mostly to substitute your new operator as well as using
|
||||||
|
// a new def name.
|
||||||
|
//
|
||||||
|
// Then you will need to add substitution rules, see examples below.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
// Useful test definitions.
|
||||||
|
|
||||||
|
def IsNotAConstant :
|
||||||
|
Constraint<CPred<"! dyn_cast_or_null<ONNXConstantOp>(($_self).getDefiningOp())">,
|
||||||
|
"operation is not a constant">;
|
||||||
|
|
||||||
|
def AttributeIsNull :
|
||||||
|
Constraint<CPred<"! ($_self)">,
|
||||||
|
"Attribute is null">;
|
||||||
|
|
||||||
|
|
||||||
|
// Usefult code generation invokation.
|
||||||
|
def GetNullAttr : NativeCodeCall<"Attribute()">;
|
||||||
|
|
||||||
|
def CreateAddOfTwoConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
def CreateSubOfTwoConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXSubOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
def CreateNegOfConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
|
||||||
|
|
||||||
|
def CreateMulOfTwoConst :
|
||||||
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Patterns to enable opportunities with elementwise ADD operations.
|
||||||
|
|
||||||
|
// Use commutativity to normalize constants in the second position of Add.
|
||||||
|
def AddConstCommutative1 : Pat<
|
||||||
|
// From add(c, x).
|
||||||
|
(ONNXAddOp (ONNXConstantOp:$c $_, $_), $x),
|
||||||
|
// To add(x, c).
|
||||||
|
(ONNXAddOp $x, $c),
|
||||||
|
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
|
||||||
|
[(IsNotAConstant:$x)]>;
|
||||||
|
|
||||||
|
// Use associativity to add constants together.
|
||||||
|
def AddConstAssociative1 : Pat<
|
||||||
|
// From add(add(x, c1), c2).
|
||||||
|
(ONNXAddOp
|
||||||
|
(ONNXAddOp $x,(ONNXConstantOp:$c1 $_, $_)),
|
||||||
|
(ONNXConstantOp:$c2 $_, $_)),
|
||||||
|
// To add(x, add(c1, c2)).
|
||||||
|
(ONNXAddOp
|
||||||
|
$x,
|
||||||
|
(ONNXAddOp $c1, $c2))>;
|
||||||
|
|
||||||
|
// Constant Propagation for Add
|
||||||
|
def AddConstProp : Pat<
|
||||||
|
// From add(c1, c2).
|
||||||
|
(ONNXAddOp:$addOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
|
||||||
|
// To c1+c2
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateAddOfTwoConst $addOp, $v1, $v2)),
|
||||||
|
// Additional constraints (no sparse)
|
||||||
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Patterns to enable opportunities with elementwise SUB / NEG operations.
|
||||||
|
|
||||||
|
// Constant Propagation for Sub
|
||||||
|
def SubConstProp : Pat<
|
||||||
|
// From sub(c1, c2).
|
||||||
|
(ONNXSubOp:$subOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
|
||||||
|
// To c1-c2
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateSubOfTwoConst $subOp, $v1, $v2)),
|
||||||
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
// Neg of constant is simly -const
|
||||||
|
def NegofConst : Pat<
|
||||||
|
// From - (c)
|
||||||
|
(ONNXNegOp (ONNXConstantOp:$constOp $s, $v)),
|
||||||
|
// To (-c)
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v)),
|
||||||
|
[(AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
// Change a subtraction of a constant c by an addition of -c. Helpfull to combine
|
||||||
|
// with other add optimizations.
|
||||||
|
def SubConstToNeg : Pat<
|
||||||
|
// From x - c.
|
||||||
|
(ONNXSubOp:$subOp $x, (ONNXConstantOp:$constOp $s, $v)),
|
||||||
|
// To x + (-c).
|
||||||
|
(ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
|
||||||
|
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Patterns to enable opportunities with elementwise MUL operations.
|
||||||
|
// Exactly the same pattern as for the elementwise ADD operations.
|
||||||
|
|
||||||
|
// Use commutativity to normalize constants in the second position of Mul.
|
||||||
|
def MulConstCommutative1 : Pat<
|
||||||
|
// From mul(c, x).
|
||||||
|
(ONNXMulOp (ONNXConstantOp:$c $_, $_), $x),
|
||||||
|
// To mul(x, c).
|
||||||
|
(ONNXMulOp $x, $c),
|
||||||
|
// To avoid infinite loop, constrain the first arguments to be anything but a constant.
|
||||||
|
[(IsNotAConstant:$x)]>;
|
||||||
|
|
||||||
|
// Use associativity to mul constants together.
|
||||||
|
def MulConstAssociative1 : Pat<
|
||||||
|
// From mul(mul(x, c1), c2).
|
||||||
|
(ONNXMulOp
|
||||||
|
(ONNXMulOp $x,(ONNXConstantOp:$c1 $_, $_)),
|
||||||
|
(ONNXConstantOp:$c2 $_, $_)),
|
||||||
|
// To mul(x, mul(c1, c2)).
|
||||||
|
(ONNXMulOp
|
||||||
|
$x,
|
||||||
|
(ONNXMulOp $c1, $c2))>;
|
||||||
|
|
||||||
|
// Constant Propagation for Mul
|
||||||
|
def MulConstProp : Pat<
|
||||||
|
// From mul(c1, c2).
|
||||||
|
(ONNXMulOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
|
||||||
|
// To c1+c2
|
||||||
|
(ONNXConstantOp (GetNullAttr), (CreateMulOfTwoConst $mulOp, $v1, $v2)),
|
||||||
|
// Mulitional constraints (no sparse)
|
||||||
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
#endif // ONNX_CONSTPROP
|
|
@ -0,0 +1,158 @@
|
||||||
|
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
/// MUL tests (same as add, so have only one).
|
||||||
|
|
||||||
|
/// Test ConstantOp assoc for add
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
%1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
"std.return"(%1) : (tensor<3xf32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test ConstantOp assoc for add
|
||||||
|
// CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
%1 = "onnx.Add"(%arg0, %0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
"std.return"(%1) : (tensor<3xf32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||||
|
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Change (x+c1)+c2 to x+(c1+c2)
|
||||||
|
// CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
"std.return"(%3) : (tensor<3xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Same test as above, but with a use of an intermediary result
|
||||||
|
/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1)
|
||||||
|
// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
%4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
"std.return"(%4) : (tensor<3xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test broadcast 1 -> 2d
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%2 = "onnx.Add"(%0, %1) : (tensor<1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%3) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Test broadcast 2d (size one) -> 2d
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%2 = "onnx.Add"(%0, %1) : (tensor<1x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%3) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// check 1d -> 2d
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%2 = "onnx.Add"(%0, %1) : (tensor<3x1xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%3) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[11, 12], [23, 24], [34, 35]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
/// MUL tests (same as add, so have only one).
|
||||||
|
|
||||||
|
/// Change (x*c1)*c2 to x*(c1*c2)
|
||||||
|
// CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
%2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
%3 = "onnx.Mul"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
"std.return"(%3) : (tensor<3xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||||
|
// CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
/// SUB and NEG tests.
|
||||||
|
|
||||||
|
|
||||||
|
// check of sub two constants
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[2]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
|
||||||
|
%2 = "onnx.Sub"(%0, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%2) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
/// check sub to add of negative
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%1 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%1) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[-2, -3], [-4, -5], [-6, -7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
|
||||||
|
%2 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%5 = "onnx.Add"(%2, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%5) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
|
||||||
|
%0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
%1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
|
||||||
|
%2 = "onnx.Neg"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%3 = "onnx.Add"(%arg0, %2) : (tensor<3x2xi32> , tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
%4 = "onnx.Add"(%3, %1) : (tensor<3x2xi32> , tensor<1x1xi32>) -> tensor<3x2xi32>
|
||||||
|
"std.return"(%4) : (tensor<3x2xi32>) -> ()
|
||||||
|
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
|
||||||
|
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
|
||||||
|
}
|
||||||
|
|
|
@ -295,8 +295,15 @@ OpsWithResultTypeInference = {
|
||||||
# Currenlty, there are only two build methods generated:
|
# Currenlty, there are only two build methods generated:
|
||||||
# - one with operands and attributes having a separate parameter, and
|
# - one with operands and attributes having a separate parameter, and
|
||||||
# - one with operands and attributes having aggregated parameters.
|
# - one with operands and attributes having aggregated parameters.
|
||||||
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
|
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
|
||||||
|
# Custom builder op list for operations with broadcast; we can deduce the right
|
||||||
|
# output type, no need to leave it undef as in the above list.
|
||||||
|
# Ops must have two operands, not one, not three... And there shall be two.
|
||||||
|
# TODO: handle variadic ops omitted here: Max, Min, Min, Sum.
|
||||||
|
custom_builder_broadcast_ops_list = ['Add', 'And', 'Div', 'Equal', 'Greater',
|
||||||
|
'Less', 'Mul', 'Or', 'Pow', 'Sub', 'Xor']
|
||||||
|
# union of both
|
||||||
|
custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broadcast_ops_list
|
||||||
|
|
||||||
#a dictionary to add any special definition for an operation
|
#a dictionary to add any special definition for an operation
|
||||||
custom_definition_misc = dict([ ('Constant',
|
custom_definition_misc = dict([ ('Constant',
|
||||||
|
@ -716,6 +723,8 @@ def gen_op_def(schema):
|
||||||
s += indent + 'let results = (outs {});\n'.format(
|
s += indent + 'let results = (outs {});\n'.format(
|
||||||
(',\n' + inc_indent(indent)).join(outs_strs))
|
(',\n' + inc_indent(indent)).join(outs_strs))
|
||||||
|
|
||||||
|
# custom_builder_broadcast_ops_list
|
||||||
|
|
||||||
# add custom builders
|
# add custom builders
|
||||||
# use element type of the first operand to construct an UnrankedTensorType for the output.
|
# use element type of the first operand to construct an UnrankedTensorType for the output.
|
||||||
if schema.name in custom_builder_ops_list:
|
if schema.name in custom_builder_ops_list:
|
||||||
|
@ -726,7 +735,8 @@ def gen_op_def(schema):
|
||||||
else:
|
else:
|
||||||
s += indent + 'let builders = [\n'
|
s += indent + 'let builders = [\n'
|
||||||
# Custom builders with operands and attributes having a seperate parameter.
|
# Custom builders with operands and attributes having a seperate parameter.
|
||||||
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value, Y, Attribute A", [{}]>
|
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X,
|
||||||
|
# Value, Y, Attribute A", [{}]>
|
||||||
indent = inc_indent(indent)
|
indent = inc_indent(indent)
|
||||||
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
|
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
|
||||||
operands_dict = get_operands_or_results(schema, is_input=True)
|
operands_dict = get_operands_or_results(schema, is_input=True)
|
||||||
|
@ -740,9 +750,26 @@ def gen_op_def(schema):
|
||||||
|
|
||||||
# Get output type from first operand's type.
|
# Get output type from first operand's type.
|
||||||
first_operand_name = list(ins.items())[0][0]
|
first_operand_name = list(ins.items())[0][0]
|
||||||
s += indent + 'auto elementType = {}.getType().cast<TensorType>().getElementType();\n'.format(
|
build_type_name = ''
|
||||||
first_operand_name)
|
if schema.name in custom_builder_broadcast_ops_list:
|
||||||
s += indent + 'build(builder, state, UnrankedTensorType::get(elementType)'
|
second_operand_name = list(ins.items())[1][0]
|
||||||
|
s += indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'. \
|
||||||
|
format(first_operand_name)
|
||||||
|
s += indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'. \
|
||||||
|
format(second_operand_name)
|
||||||
|
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
||||||
|
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
||||||
|
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
|
||||||
|
s += indent + indent + 'elementType = {}'.format(first_operand_name) + \
|
||||||
|
'.getType().cast<TensorType>().getElementType();\n';
|
||||||
|
s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
|
||||||
|
s += indent + '}\n';
|
||||||
|
build_type_name = 'elementType'
|
||||||
|
else:
|
||||||
|
s += indent + 'auto elementType = {}'.format(first_operand_name) + \
|
||||||
|
'.getType().cast<TensorType>().getElementType();\n'
|
||||||
|
build_type_name = 'UnrankedTensorType::get(elementType)'
|
||||||
|
s += indent + 'build(builder, state, {}'.format(build_type_name)
|
||||||
for name, _ in ins.items():
|
for name, _ in ins.items():
|
||||||
s += ', ' + name
|
s += ', ' + name
|
||||||
s += ');\n'
|
s += ');\n'
|
||||||
|
@ -750,12 +777,26 @@ def gen_op_def(schema):
|
||||||
s += indent + '}]>,\n'
|
s += indent + '}]>,\n'
|
||||||
|
|
||||||
# Custom builders with all operands and attributes having aggregate parameters.
|
# Custom builders with all operands and attributes having aggregate parameters.
|
||||||
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>'
|
# E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands,
|
||||||
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
# ArrayRef<NamedAttribute> attributes", [{}]>'
|
||||||
|
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ' + \
|
||||||
|
'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
|
||||||
indent = inc_indent(indent)
|
indent = inc_indent(indent)
|
||||||
s += indent + 'auto elementType = operands[0].getType().cast<TensorType>().getElementType();\n'
|
if schema.name in custom_builder_broadcast_ops_list:
|
||||||
|
s += indent + 'auto lhsTy = operands[0].getType().cast<RankedTensorType>();\n'
|
||||||
|
s += indent + 'auto rhsTy = operands[1].getType().cast<RankedTensorType>();\n'
|
||||||
|
s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
|
||||||
|
s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n';
|
||||||
|
s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n';
|
||||||
|
s += indent + indent + 'elementType = operands[0]' + \
|
||||||
|
'.getType().cast<TensorType>().getElementType();\n';
|
||||||
|
s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
|
||||||
|
s += indent + '}\n';
|
||||||
|
else:
|
||||||
|
s += indent + 'auto elementType = operands[0].getType().' + \
|
||||||
|
'cast<TensorType>().getElementType();\n'
|
||||||
s += indent + 'std::vector<mlir::Type> outputTypes;\n'
|
s += indent + 'std::vector<mlir::Type> outputTypes;\n'
|
||||||
s += indent + 'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
|
s += indent + 'outputTypes.emplace_back({});\n'.format(build_type_name)
|
||||||
s += indent + 'build(builder, state, outputTypes, operands, attributes);\n'
|
s += indent + 'build(builder, state, outputTypes, operands, attributes);\n'
|
||||||
indent = dec_indent(indent)
|
indent = dec_indent(indent)
|
||||||
s += indent + '}]>'
|
s += indent + '}]>'
|
||||||
|
|
Loading…
Reference in New Issue