Constprop (#162)

* initial const prop attempt
* added support for broadcast ops
* added all binary broadcast ops into custom builders with precise type
* added test example
* working
* format
* fixed suggestion by Tung, start working on unary
* added subtraction and neg the right way, and added elementwise mul too
* formatting changes
* format
* format
* added instructions to add new optimizations
This commit is contained in:
parent dedd5f4a12
commit e2af505746
@@ -116,6 +116,7 @@ docs/_build/
 
 # PyBuilder
 target/
+utils/ONNXOps.td.inc
 
 # Jupyter Notebook
 .ipynb_checkpoints

@@ -175,3 +176,6 @@ dmypy.json
 
 # pytype static type analyzer
 .pytype/
+
+#editor
+*~
@@ -1,9 +1,9 @@
 # Import ONNX specifications into ONNX-MLIR
 
 ONNX specifications are defined under `onnx/defs` directory in the ONNX project repository.
-There is a python script onnx/defs/gen_doc.py that automatically generate documents about operations in ONNX (docs/Operations.md).
+There is a python script onnx/defs/gen_onnx_mlir.py that automatically generate documents about operations in ONNX (docs/Operations.md).
 ONNX-MLIR modified this script to import ONNX specifications into ONNX-MLIR.
-There are two files generated for ONNX MLIR with the modified gen_doc.py:
+There are two files generated for ONNX MLIR with the modified gen_onnx_mlir.py:
 
 1. `src/Dialect/ONNX/ONNXOps.td.inc`: Operation definition for MLIR TableGen. `src/Dialect/ONNX/ONNXOps.td` includes this file.
 2. `src/Builder/OpBuildTable.inc`: C++ code for ONNX-MLIR frontend to import operation nodes from ONNX model. `src/Builder/FrontendDialectTransformer.cpp` includes this file.

@@ -22,7 +22,7 @@ Even though we strive to support the latest version of ONNX specification as qui
 Due to the possibility of such a delay, operator definition within the ONNX project repository may describe features and schemas that we do not yet support.
 
 ## Customization
-In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_doc.py, provides some mechanism for you to customize the output.
+In addition to following the ONNX specification, the script gen_onnx_mlir.py, modified gen_onnx_mlir.py, provides some mechanism for you to customize the output.
 Several tables are defined at the beginning of the script:
 1. `special_attr_defaults`: gives attribute special default value.
 2. `special_op_handler`: creates special import function in frontend_dialect_transformer.cpp. Currently, a special handler is used for operations with optional arguments.
@@ -93,17 +93,43 @@ def ONNXAddOp:ONNX_Op<"Add",
   let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXAndOp:ONNX_Op<"And",
@@ -118,17 +144,43 @@ def ONNXAndOp:ONNX_Op<"And",
   let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
                        AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
 def ONNXArgMaxOp:ONNX_Op<"ArgMax",
@@ -933,17 +985,43 @@ def ONNXDivOp:ONNX_Op<"Div",
   let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXDropoutOp:ONNX_Op<"Dropout",
@@ -1055,17 +1133,43 @@ def ONNXEqualOp:ONNX_Op<"Equal",
   let arguments = (ins AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I1,I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I1,I8,I16,I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
 def ONNXErfOp:ONNX_Op<"Erf",
@@ -1697,17 +1801,43 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
   let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
 def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid",
@@ -2075,17 +2205,43 @@ def ONNXLessOp:ONNX_Op<"Less",
   let arguments = (ins AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I8,I16,I32,I64,F16,F32,F64]>, MemRefOf<[I8,I16,I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
 def ONNXLogOp:ONNX_Op<"Log",
@@ -2646,13 +2802,27 @@ def ONNXMulOp:ONNX_Op<"Mul",
   let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
-      auto elementType = A.getType().cast<TensorType>().getElementType();
-      build(builder, state, UnrankedTensorType::get(elementType), A, B);
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
     }]>,
     OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
-      auto elementType = operands[0].getType().cast<TensorType>().getElementType();
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
       std::vector<mlir::Type> outputTypes;
-      outputTypes.emplace_back(UnrankedTensorType::get(elementType));
+      outputTypes.emplace_back(elementType);
       build(builder, state, outputTypes, operands, attributes);
     }]>
   ];
@@ -2848,17 +3018,43 @@ def ONNXOrOp:ONNX_Op<"Or",
   let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
                        AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
 def ONNXPReluOp:ONNX_Op<"PRelu",
@@ -3017,17 +3213,43 @@ def ONNXPowOp:ONNX_Op<"Pow",
   let arguments = (ins AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$X,
                        AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Y);
   let results = (outs AnyTypeOf<[TensorOf<[F16,F32,F64]>, MemRefOf<[F16,F32,F64]>]>:$Z);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y", [{
+      auto lhsTy = X.getType().cast<RankedTensorType>();
+      auto rhsTy = Y.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = X.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, X, Y);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXQLinearConvOp:ONNX_Op<"QLinearConv",
@@ -4938,17 +5160,43 @@ def ONNXSubOp:ONNX_Op<"Sub",
   let arguments = (ins AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$A,
                        AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I32,I64,F16,F32,F64]>, MemRefOf<[I32,I64,F16,F32,F64]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {20};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {20};
+    }
+  }];
 }
 
 def ONNXSumOp:ONNX_Op<"Sum",
@@ -5379,16 +5627,42 @@ def ONNXXorOp:ONNX_Op<"Xor",
   let arguments = (ins AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$A,
                        AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$B);
   let results = (outs AnyTypeOf<[TensorOf<[I1]>, MemRefOf<[I1]>]>:$C);
-  let extraClassDeclaration = [{
-    static int getNumberOfOperands() {
-      return 2;
-    }
-    static int getNumberOfResults() {
-      return 1;
-    }
-    static std::vector<int> getTypeMap() {
-      return {0};
-    }
-  }];
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &state, Value A, Value B", [{
+      auto lhsTy = A.getType().cast<RankedTensorType>();
+      auto rhsTy = B.getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = A.getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      build(builder, state, elementType, A, B);
+    }]>,
+    OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
+      auto lhsTy = operands[0].getType().cast<RankedTensorType>();
+      auto rhsTy = operands[1].getType().cast<RankedTensorType>();
+      auto elementType = getBroadcastedType(lhsTy, rhsTy);
+      auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
+      if (!shapedType || !shapedType.hasStaticShape()) {
+        elementType = operands[0].getType().cast<TensorType>().getElementType();
+        elementType = UnrankedTensorType::get(elementType);
+      }
+      std::vector<mlir::Type> outputTypes;
+      outputTypes.emplace_back(elementType);
+      build(builder, state, outputTypes, operands, attributes);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    static int getNumberOfOperands() {
+      return 2;
+    }
+    static int getNumberOfResults() {
+      return 1;
+    }
+    static std::vector<int> getTypeMap() {
+      return {0};
+    }
+  }];
 }
 
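Editor's note: every builder added above repeats the same result-type inference. As a standalone sketch (an illustration, not code from this commit; the helper name inferBroadcastResultType is hypothetical, and the headers are those of the MLIR version of that era):

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"

// Infer the result type of an elementwise broadcasting op from its operands.
static mlir::Type inferBroadcastResultType(mlir::Value a, mlir::Value b) {
  auto lhsTy = a.getType().cast<mlir::RankedTensorType>();
  auto rhsTy = b.getType().cast<mlir::RankedTensorType>();
  // E.g. tensor<3x1xf32> and tensor<3x2xf32> broadcast to tensor<3x2xf32>;
  // getBroadcastedType is MLIR's utility in mlir::OpTrait::util.
  mlir::Type resTy = mlir::OpTrait::util::getBroadcastedType(lhsTy, rhsTy);
  auto shaped = resTy.dyn_cast_or_null<mlir::ShapedType>();
  // Fall back to an unranked tensor when the broadcasted shape is unknown
  // or not fully static, exactly as the builders above do.
  if (!shaped || !shaped.hasStaticShape())
    resTy = mlir::UnrankedTensorType::get(lhsTy.getElementType());
  return resTy;
}

This is what lets a folded op such as tensor<1xi32> + tensor<3x2xi32> be given the precise type tensor<3x2xi32> instead of an unranked tensor.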
@@ -88,6 +88,7 @@ void registerDialects() {
 
 void addONNXToMLIRPasses(mlir::PassManager &pm) {
   pm.addPass(mlir::createDecomposeONNXToONNXPass());
+  pm.addPass(mlir::createConstPropONNXToONNXPass());
   pm.addPass(mlir::createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createAttributePromotionPass());
@@ -20,6 +20,8 @@ std::unique_ptr<Pass> createDecomposeONNXToONNXPass();
 
 std::unique_ptr<Pass> createShapeInferencePass();
 
+std::unique_ptr<Pass> createConstPropONNXToONNXPass();
+
 /// Pass for promoting constant operands to attributes.
 std::unique_ptr<Pass> createAttributePromotionPass();
 
@@ -33,17 +33,23 @@ set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
 onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
 add_public_tablegen_target(OMONNXDecomposeIncGen)
 
+set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
+onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
+add_public_tablegen_target(OMONNXConstPropIncGen)
+
 add_library(OMONNXRewrite
   ONNXRewrite.cpp
   ONNXCombine.cpp
-  ONNXDecompose.cpp)
+  ONNXDecompose.cpp
+  ONNXConstProp.cpp)
 target_include_directories(OMONNXRewrite
   PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
   ${ONNF_MLIR_SRC_ROOT})
 add_dependencies(OMONNXRewrite
   OMONNXRewriteIncGen
   OMONNXDecomposeIncGen
-  OMONNXCombineIncGen)
+  OMONNXCombineIncGen
+  OMONNXConstPropIncGen)
 # Linking dependencies:
 add_dependencies(OMONNXRewrite
   OMONNXOps)
@@ -0,0 +1,335 @@
//===----------- ONNXConstProp.cpp - ONNX High Level Rewriting ------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file implements a set of rewriters to constprop an ONNX operation into
// a composition of other ONNX operations.
//
// This pass is applied before any other pass so that there is no need to
// implement shape inference for the constant-propagated operations. Hence, it
// is expected that there is no knowledge about tensor shapes at this point.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Pass/Passes.hpp"

using namespace mlir;

namespace {

// =============================================================================
// Instructions to add a constant operation. There is currently support for
// adding constant propagation for unary and binary arithmetic ops (binary ops
// support broadcast). To add an operation, you simply have to add a templated
// method on how to compute the result in terms of one or two inputs. Values
// come as Attributes, and the return is also an Attribute. In that function,
// presumably you will need different methods to handle int / float /
// strings... Note that these methods cannot fail. It is your responsibility to
// test for which data types are supported in the rules directly. Specific type
// restrictions can be added in the DRR files.
//
// The methods are:
//
// ComputeConstProppElementwiseBinary and ComputeConstProppElementwiseUnary,
// and they need to be templated with an ONNX Operation (presumably).
//
// Then you need to add rules on how to transform the patterns; look into
// ONNXConstProp.td for examples.
//
// =============================================================================

// =============================================================================
// Code to perform constant propagation for binary ops in the presence of
// broadcast.
// =============================================================================

// Template to generate binary operation results. It takes as input
// the element type as well as the two element attributes for the
// operation, and returns the result of the operation, also as an
// attribute.

template <typename OP>
Attribute ComputeConstProppElementwiseBinary(PatternRewriter &rewriter,
    Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
  llvm_unreachable("unknown operation");
}

template <>
Attribute ComputeConstProppElementwiseBinary<ONNXAddOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    double res = lhsVal + rhsVal;
    // Could use the APFloat interface to emulate the results; it is ok to
    // simply perform them in the highest possible precision.
    return rewriter.getFloatAttr(elementType, res);
  }
  if (elementType.isa<IntegerType>()) {
    uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
    uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
    uint64_t res = lhsVal + rhsVal;
    return rewriter.getIntegerAttr(elementType, res);
  }
  llvm_unreachable("constant propagation for AddOp: unknown data type");
}

template <>
Attribute ComputeConstProppElementwiseBinary<ONNXSubOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    double res = lhsVal - rhsVal;
    return rewriter.getFloatAttr(elementType, res);
  }
  if (elementType.isa<IntegerType>()) {
    uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
    uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
    uint64_t res = lhsVal - rhsVal;
    return rewriter.getIntegerAttr(elementType, res);
  }
  llvm_unreachable("constant propagation for SubOp: unknown data type");
}

template <>
Attribute ComputeConstProppElementwiseBinary<ONNXMulOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    double res = lhsVal * rhsVal;
    return rewriter.getFloatAttr(elementType, res);
  }
  if (elementType.isa<IntegerType>()) {
    uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
    uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
    uint64_t res = lhsVal * rhsVal;
    return rewriter.getIntegerAttr(elementType, res);
  }
  llvm_unreachable("constant propagation for MulOp: unknown data type");
}

// Recursively process one dimension in the rank of the two references. There
// can be one of 3 cases.
// 1) We have fully defined accesses for both operands; launch the computation.
// 2) One of the two has a higher number of unprocessed ranks, which is the case
// when we have to broadcast the whole lower-dim reference with respect to the
// other. Iterate over each value of the higher-ranked reference, keeping the
// indices of the lower-ranked reference constant.
// 3) Both references have the same rank; we still do broadcast if one of the
// dimension sizes is equal to 1.

template <typename ElementwiseBinaryOp>
void RecurseConstProppElementwiseBinary(PatternRewriter &rewriter,
    std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
    DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
    SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
  if (lhsFreeRank == 0) {
    // Fully defined ranks.
    assert(
        rhsFreeRank == 0 && "expect both to recurse to zero at the same time");
    auto lhsElementAttr = lhsAttr.getValue(ArrayRef<uint64_t>(lhsIndices));
    auto rhsElementAttr = rhsAttr.getValue(ArrayRef<uint64_t>(rhsIndices));
    auto elementaryType = lhsAttr.getType().getElementType();
    auto res = ComputeConstProppElementwiseBinary<ElementwiseBinaryOp>(
        rewriter, elementaryType, lhsElementAttr, rhsElementAttr);
    resVector.emplace_back(res);
  } else if (lhsFreeRank > rhsFreeRank) {
    // Initial broadcast from lhs.
    auto lhsShape = lhsAttr.getType().getShape();
    int lhsRank = lhsShape.size();
    int lhsIndex = lhsRank - lhsFreeRank;
    int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
    for (int i = 0; i < lhsSize; ++i) {
      lhsIndices[lhsIndex] = i;
      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
          resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
          rhsFreeRank);
    }
  } else if (lhsFreeRank < rhsFreeRank) {
    // Initial broadcast from rhs.
    auto rhsShape = rhsAttr.getType().getShape();
    int rhsRank = rhsShape.size();
    int rhsIndex = rhsRank - rhsFreeRank;
    int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
    for (int i = 0; i < rhsSize; ++i) {
      rhsIndices[rhsIndex] = i;
      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
          resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank,
          rhsFreeRank - 1);
    }
  } else {
    // No initial broadcast, but if one element has size 1 and the other is
    // greater than one, then we also have broadcast.
    auto lhsShape = lhsAttr.getType().getShape();
    int lhsRank = lhsShape.size();
    int lhsIndex = lhsRank - lhsFreeRank;
    int lhsSize = lhsAttr.getType().getShape()[lhsIndex];
    auto rhsShape = rhsAttr.getType().getShape();
    int rhsRank = rhsShape.size();
    int rhsIndex = rhsRank - rhsFreeRank;
    int rhsSize = rhsAttr.getType().getShape()[rhsIndex];
    assert((lhsSize == 1 || rhsSize == 1 || lhsSize == rhsSize) &&
           "incompatible sizes");
    int size = std::max(lhsSize, rhsSize);
    lhsIndices[lhsIndex] = rhsIndices[rhsIndex] = 0;
    for (int i = 0; i < size; ++i) {
      if (lhsSize > 1)
        lhsIndices[lhsIndex] = i;
      if (rhsSize > 1)
        rhsIndices[rhsIndex] = i;
      RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter,
          resVector, lhsAttr, rhsAttr, lhsIndices, rhsIndices, lhsFreeRank - 1,
          rhsFreeRank - 1);
    }
  }
}

// Process the constant operands, perform the operation with broadcast, and
// generate the new constant operation.
template <typename ElementwiseBinaryOp>
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
    Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
  DenseElementsAttr lhsDenseAttr =
      lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
  DenseElementsAttr rhsDenseAttr =
      rhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
  // Bug fix: the original asserted lhsDenseAttr twice and never checked rhs.
  assert((lhsDenseAttr && rhsDenseAttr) && "expected dense attributes");
  assert(
      resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
  ShapedType resType = resOperand.getType().cast<RankedTensorType>();
  auto lhsRank = lhsDenseAttr.getType().getShape().size();
  auto rhsRank = rhsDenseAttr.getType().getShape().size();
  SmallVector<uint64_t, 4> lhsIndices(lhsRank, 0);
  SmallVector<uint64_t, 4> rhsIndices(rhsRank, 0);
  std::vector<Attribute> resVector;
  RecurseConstProppElementwiseBinary<ElementwiseBinaryOp>(rewriter, resVector,
      lhsDenseAttr, rhsDenseAttr, lhsIndices, rhsIndices, lhsRank, rhsRank);
  ArrayRef<Attribute> resRef(resVector);
  return DenseElementsAttr::get(resType, resRef);
}

// =============================================================================
// Code to perform constant propagation for unary operations.
// =============================================================================

template <typename OP>
Attribute ComputeConstProppElementwiseUnary(
    PatternRewriter &rewriter, Type elementType, Attribute &attr) {
  llvm_unreachable("unknown operation");
}

template <>
Attribute ComputeConstProppElementwiseUnary<ONNXNegOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &attr) {
  if (elementType.isa<FloatType>()) {
    double val = attr.cast<FloatAttr>().getValueAsDouble();
    double res = -val;
    return rewriter.getFloatAttr(elementType, res);
  }
  if (elementType.isa<IntegerType>()) {
    uint64_t val = attr.cast<IntegerAttr>().getInt();
    uint64_t res = -val;
    return rewriter.getIntegerAttr(elementType, res);
  }
  llvm_unreachable("constant propagation for NegOp: unknown data type");
}

template <typename ElementwiseUnaryOp>
void RecurseConstProppElementwiseUnary(PatternRewriter &rewriter,
    std::vector<Attribute> &resVector, DenseElementsAttr &attr,
    SmallVector<uint64_t, 4> &indices, int freeRank) {
  if (freeRank == 0) {
    // Fully defined ranks.
    auto elementAttr = attr.getValue(ArrayRef<uint64_t>(indices));
    auto elementaryType = attr.getType().getElementType();
    auto res = ComputeConstProppElementwiseUnary<ElementwiseUnaryOp>(
        rewriter, elementaryType, elementAttr);
    resVector.emplace_back(res);
  } else {
    // Recurse.
    auto shape = attr.getType().getShape();
    int rank = shape.size();
    int index = rank - freeRank;
    int size = attr.getType().getShape()[index];
    for (int i = 0; i < size; ++i) {
      indices[index] = i;
      RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
          rewriter, resVector, attr, indices, freeRank - 1);
    }
  }
}

// Process the constant operand, perform the operation, and generate the new
// constant operation.
template <typename ElementwiseUnaryOp>
DenseElementsAttr ConstPropElementwiseUnary(
    PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
  DenseElementsAttr denseAttr =
      attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
  assert(denseAttr && "expected dense attribute");
  assert(
      resOperand.getType().isa<RankedTensorType>() && "expected ranked tensor");
  ShapedType resType = resOperand.getType().cast<RankedTensorType>();
  auto rank = denseAttr.getType().getShape().size();
  SmallVector<uint64_t, 4> indices(rank, 0);
  std::vector<Attribute> resVector;
  RecurseConstProppElementwiseUnary<ElementwiseUnaryOp>(
      rewriter, resVector, denseAttr, indices, rank);
  ArrayRef<Attribute> resRef(resVector);
  return DenseElementsAttr::get(resType, resRef);
}

// =============================================================================
// Pattern definitions.
// =============================================================================

#include "src/Transform/ONNX/ONNXConstProp.inc"

// =============================================================================
// Code to manage the pass.
// =============================================================================

struct ConstPropONNXToONNXPass
    : public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
  void runOnFunction() final;
};
} // end anonymous namespace

void ConstPropONNXToONNXPass::runOnFunction() {
  auto function = getFunction();
  MLIRContext *context = &getContext();

  ConversionTarget target(getContext());
  target.addLegalDialect<ONNXOpsDialect>();

  OwningRewritePatternList patterns;
  populateWithGenerated(context, &patterns);

  applyPatternsAndFoldGreedily(function, patterns);
}

/*!
 * Create a ConstPropONNX pass.
 */
std::unique_ptr<mlir::Pass> mlir::createConstPropONNXToONNXPass() {
  return std::make_unique<ConstPropONNXToONNXPass>();
}

static PassRegistration<ConstPropONNXToONNXPass> pass("constprop-onnx",
    "ConstProp ONNX operations into composition of other ONNX operations.");
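Editor's note: to make the extension recipe in the comment block above concrete, here is a minimal sketch of a new binary specialization, using a hypothetical ONNXDivOp rule as the example (it mirrors the Add/Sub/Mul specializations above and is not part of this commit):

// Hypothetical: fold a division of two scalar element attributes.
template <>
Attribute ComputeConstProppElementwiseBinary<ONNXDivOp>(
    PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
    Attribute &secondAttr) {
  if (elementType.isa<FloatType>()) {
    double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
    double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
    return rewriter.getFloatAttr(elementType, lhsVal / rhsVal);
  }
  if (elementType.isa<IntegerType>()) {
    uint64_t lhsVal = lhsAttr.cast<IntegerAttr>().getInt();
    uint64_t rhsVal = secondAttr.cast<IntegerAttr>().getInt();
    return rewriter.getIntegerAttr(elementType, lhsVal / rhsVal);
  }
  llvm_unreachable("constant propagation for DivOp: unknown data type");
}

A matching NativeCodeCall (e.g. a hypothetical CreateDivOfTwoConst) and a DivConstProp pattern would then be added to ONNXConstProp.td, mirroring the Add rules shown next.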
@@ -0,0 +1,161 @@
//===- ONNXConstProp.td - Rewriting for Constant Propagation in ONNX Ops -*- tablegen -===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// Defines language-specific pattern match rewritings for ONNX using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//

#ifndef ONNX_CONSTPROP
#define ONNX_CONSTPROP

#ifndef OP_BASE
include "src/Dialect/ONNX/ONNXOps.td"
#endif // OP_BASE


// =============================================================================
// Instructions to add new constant operation rules. Minimally, you will have
// added an operation in ONNXConstProp.cpp to perform the elementwise
// single-value handling of the new operator that you are dealing with. You
// will need to generate a call to the method that handles the tensor constant
// propagation. Here are the calls for a unary and a binary operation. Adapt
// them to your new operator:
//
// def CreateAddOfTwoConst :
//    NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
//
// def CreateNegOfConst :
//    NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;
//
// where you will mostly have to substitute your new operator as well as use
// a new def name.
//
// Then you will need to add substitution rules; see the examples below.
// =============================================================================


// Useful test definitions.

def IsNotAConstant :
  Constraint<CPred<"! dyn_cast_or_null<ONNXConstantOp>(($_self).getDefiningOp())">,
    "operation is not a constant">;

def AttributeIsNull :
  Constraint<CPred<"! ($_self)">,
    "Attribute is null">;


// Useful code generation invocations.
def GetNullAttr : NativeCodeCall<"Attribute()">;
def CreateAddOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;

def CreateSubOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXSubOp>($_builder, $0, $1, $2)">;

def CreateNegOfConst :
   NativeCodeCall<"ConstPropElementwiseUnary<mlir::ONNXNegOp>($_builder, $0, $1)">;

def CreateMulOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;

// =============================================================================
// Patterns to enable opportunities with elementwise ADD operations.

// Use commutativity to normalize constants in the second position of Add.
def AddConstCommutative1 : Pat<
  // From add(c, x).
  (ONNXAddOp (ONNXConstantOp:$c $_, $_), $x),
  // To add(x, c).
  (ONNXAddOp $x, $c),
  // To avoid an infinite loop, constrain the first argument to be anything but a constant.
  [(IsNotAConstant:$x)]>;

// Use associativity to add constants together.
def AddConstAssociative1 : Pat<
  // From add(add(x, c1), c2).
  (ONNXAddOp
    (ONNXAddOp $x, (ONNXConstantOp:$c1 $_, $_)),
    (ONNXConstantOp:$c2 $_, $_)),
  // To add(x, add(c1, c2)).
  (ONNXAddOp
    $x,
    (ONNXAddOp $c1, $c2))>;

// Constant propagation for Add.
def AddConstProp : Pat<
  // From add(c1, c2).
  (ONNXAddOp:$addOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
  // To c1+c2.
  (ONNXConstantOp (GetNullAttr), (CreateAddOfTwoConst $addOp, $v1, $v2)),
  // Additional constraints (no sparse).
  [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;

// =============================================================================
// Patterns to enable opportunities with elementwise SUB / NEG operations.

// Constant propagation for Sub.
def SubConstProp : Pat<
  // From sub(c1, c2).
  (ONNXSubOp:$subOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
  // To c1-c2.
  (ONNXConstantOp (GetNullAttr), (CreateSubOfTwoConst $subOp, $v1, $v2)),
  [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;

// Neg of a constant is simply -const.
def NegofConst : Pat<
  // From -(c).
  (ONNXNegOp (ONNXConstantOp:$constOp $s, $v)),
  // To (-c).
  (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v)),
  [(AttributeIsNull:$s)]>;

// Change a subtraction of a constant c into an addition of -c. Helpful for
// combining with other add optimizations.
def SubConstToNeg : Pat<
  // From x - c.
  (ONNXSubOp:$subOp $x, (ONNXConstantOp:$constOp $s, $v)),
  // To x + (-c).
  (ONNXAddOp $x, (ONNXConstantOp (GetNullAttr), (CreateNegOfConst $constOp, $v))),
  [(IsNotAConstant:$x), (AttributeIsNull:$s)]>;


// =============================================================================
// Patterns to enable opportunities with elementwise MUL operations.
// Exactly the same patterns as for the elementwise ADD operations.

// Use commutativity to normalize constants in the second position of Mul.
def MulConstCommutative1 : Pat<
  // From mul(c, x).
  (ONNXMulOp (ONNXConstantOp:$c $_, $_), $x),
  // To mul(x, c).
  (ONNXMulOp $x, $c),
  // To avoid an infinite loop, constrain the first argument to be anything but a constant.
  [(IsNotAConstant:$x)]>;

// Use associativity to multiply constants together.
def MulConstAssociative1 : Pat<
  // From mul(mul(x, c1), c2).
  (ONNXMulOp
    (ONNXMulOp $x, (ONNXConstantOp:$c1 $_, $_)),
    (ONNXConstantOp:$c2 $_, $_)),
  // To mul(x, mul(c1, c2)).
  (ONNXMulOp
    $x,
    (ONNXMulOp $c1, $c2))>;

// Constant propagation for Mul.
def MulConstProp : Pat<
  // From mul(c1, c2).
  (ONNXMulOp:$mulOp (ONNXConstantOp $s1, $v1), (ONNXConstantOp $s2, $v2)),
  // To c1*c2.
  (ONNXConstantOp (GetNullAttr), (CreateMulOfTwoConst $mulOp, $v1, $v2)),
  // Additional constraints (no sparse).
  [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;

#endif // ONNX_CONSTPROP
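Editor's note: TableGen expands each Pat above into a C++ RewritePattern in the generated ONNXConstProp.inc. A rough hand-written equivalent of AddConstProp, sketched under the assumption that ONNXConstantOp exposes sparse_valueAttr()/valueAttr() attribute accessors and that ONNXAddOp has A()/B() operand getters (names may differ by version), would be:

// Hypothetical hand-written equivalent of the AddConstProp DRR rule above.
struct AddConstPropPattern : public OpRewritePattern<ONNXAddOp> {
  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(
      ONNXAddOp addOp, PatternRewriter &rewriter) const override {
    // Both operands must be produced by onnx.Constant ops.
    auto lhsConst = addOp.A().getDefiningOp<ONNXConstantOp>();
    auto rhsConst = addOp.B().getDefiningOp<ONNXConstantOp>();
    if (!lhsConst || !rhsConst)
      return failure();
    // The "no sparse" constraint: sparse_value attributes must be null.
    if (lhsConst.sparse_valueAttr() || rhsConst.sparse_valueAttr())
      return failure();
    Attribute lhsVal = lhsConst.valueAttr();
    Attribute rhsVal = rhsConst.valueAttr();
    // Fold the two dense constants, honoring broadcast.
    DenseElementsAttr res = ConstPropElementwiseBinary<ONNXAddOp>(
        rewriter, addOp.getResult(), lhsVal, rhsVal);
    // Replace add(c1, c2) with a single constant holding the folded values.
    rewriter.replaceOpWithNewOp<ONNXConstantOp>(
        addOp, addOp.getType(), Attribute(), res);
    return success();
  }
};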
@@ -0,0 +1,158 @@
// RUN: onnx-mlir-opt --constprop-onnx %s -split-input-file | FileCheck %s

// =============================================================================
/// ADD tests.

/// Test ConstantOp commutativity for add

// CHECK-LABEL: @test_add_constant_1(%arg0: tensor<3xf32>) -> tensor<3xf32>
func @test_add_constant_1(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
  %0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
  %1 = "onnx.Add"(%0, %arg0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
  "std.return"(%1) : (tensor<3xf32>) -> ()
  // CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
  // CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
}
/// Test ConstantOp assoc for add
|
||||
// CHECK-LABEL: @test_add_constant_2(%arg0: tensor<3xf32>) -> tensor<3xf32>
|
||||
func @test_add_constant_2(%arg0 : tensor<3xf32>) -> tensor<3xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[0.0, 1.0, 2.0]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
%1 = "onnx.Add"(%arg0, %0) : (tensor<3xf32> , tensor<3xf32>) -> tensor<3xf32>
|
||||
"std.return"(%1) : (tensor<3xf32>) -> ()
|
||||
// CHECK-NEXT: [[CONST:%.+]] = "onnx.Constant"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>} : () -> tensor<3xf32>
|
||||
// CHECK-NEXT: [[ADD:%.+]] = "onnx.Add"(%arg0, [[CONST]]) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
|
||||
}
|
||||
|
||||
/// Change (x+c1)+c2 to x+(c1+c2)
|
||||
// CHECK-LABEL: @test_add_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
|
||||
func @test_add_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||
"std.return"(%3) : (tensor<3xi32>) -> ()
|
||||
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||
}
|
||||
|
||||
/// Same test as above, but with a use of an intermediary result
|
||||
/// change (x+c1)+c2 + (x+c1) to x+(c1+c2) + (x+c1)
|
||||
// CHECK-LABEL: @test_add_constant_4(%arg0: tensor<3xi32>) -> tensor<3xi32>
|
||||
func @test_add_constant_4(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
%2 = "onnx.Add"(%0, %arg0) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||
%3 = "onnx.Add"(%1, %2) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||
%4 = "onnx.Add"(%2, %3) : (tensor<3xi32> , tensor<3xi32>) -> tensor<3xi32>
|
||||
"std.return"(%4) : (tensor<3xi32>) -> ()
|
||||
// CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||
// CHECK-NEXT: [[CONST2:%.+]] = "onnx.Constant"() {value = dense<[10, 12, 14]> : tensor<3xi32>} : () -> tensor<3xi32>
|
||||
// CHECK-NEXT: [[ADD2:%.+]] = "onnx.Add"(%arg0, [[CONST2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||
// CHECK-NEXT: [[ADD3:%.+]] = "onnx.Add"([[ADD1]], [[ADD2]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
|
||||
}

/// Test broadcast 1 -> 2d

// CHECK-LABEL: @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32>
  %1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %2 = "onnx.Add"(%0, %1) : (tensor<1xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  "std.return"(%3) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}

/// Test broadcast 2d (size one) -> 2d

// CHECK-LABEL: @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[1]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
  %1 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %2 = "onnx.Add"(%0, %1) : (tensor<1x1xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  "std.return"(%3) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[3, 4], [5, 6], [7, 8]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}

/// Test broadcast 2d (3x1 column) -> 2d

// CHECK-LABEL: @test_broadcast_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_broadcast_3(%arg0 : tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[1], [2], [3]]> : tensor<3x1xi32>} : () -> tensor<3x1xi32>
  %1 = "onnx.Constant"() {value = dense<[[10, 11], [21, 22], [31, 32]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %2 = "onnx.Add"(%0, %1) : (tensor<3x1xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %3 = "onnx.Add"(%2, %arg0) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  "std.return"(%3) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[11, 12], [23, 24], [34, 35]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
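ONNX elementwise ops follow numpy-style multidirectional broadcasting, which is exactly what the folding of the two constants above must reproduce. A small numpy illustration of the 3x1 + 3x2 case from test_broadcast_3 (illustration only):

import numpy as np

col = np.array([[1], [2], [3]], dtype=np.int32)                 # tensor<3x1xi32>
mat = np.array([[10, 11], [21, 22], [31, 32]], dtype=np.int32)  # tensor<3x2xi32>

# The column stretches across the second dimension before the add.
print(col + mat)  # [[11 12] [23 24] [34 35]], the folded CONST1 above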

// =============================================================================
/// MUL tests (same as add, so have only one).

/// Change (x*c1)*c2 to x*(c1*c2)
// CHECK-LABEL: @test_mul_constant_3(%arg0: tensor<3xi32>) -> tensor<3xi32>
func @test_mul_constant_3(%arg0 : tensor<3xi32>) -> tensor<3xi32> {
  %0 = "onnx.Constant"() {value = dense<[0, 1, 2]> : tensor<3xi32>} : () -> tensor<3xi32>
  %1 = "onnx.Constant"() {value = dense<[10, 11, 12]> : tensor<3xi32>} : () -> tensor<3xi32>
  %2 = "onnx.Mul"(%0, %arg0) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  %3 = "onnx.Mul"(%1, %2) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
  "std.return"(%3) : (tensor<3xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<[0, 11, 24]> : tensor<3xi32>} : () -> tensor<3xi32>
  // CHECK-NEXT: [[MUL1:%.+]] = "onnx.Mul"(%arg0, [[CONST1]]) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
}

// =============================================================================
/// SUB and NEG tests.

/// Check subtraction of two constants.

// CHECK-LABEL: @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_sub_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %1 = "onnx.Constant"() {value = dense<[[2]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
  %2 = "onnx.Sub"(%0, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32>
  "std.return"(%2) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[0, 1], [2, 3], [4, 5]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
}
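When both operands are constants, the whole onnx.Sub folds away and only a constant remains, which is why the test has a single CHECK line. A numpy illustration of the fold, including the 1x1 broadcast:

import numpy as np

a = np.array([[2, 3], [4, 5], [6, 7]], dtype=np.int32)  # tensor<3x2xi32>
b = np.array([[2]], dtype=np.int32)                     # tensor<1x1xi32>, broadcast
print(a - b)  # [[0 1] [2 3] [4 5]], the only thing left in the output IR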

/// Check rewriting sub into add of the negative.

// CHECK-LABEL: @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_1(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %1 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  "std.return"(%1) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[-2, -3], [-4, -5], [-6, -7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}

// CHECK-LABEL: @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_2(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
  %2 = "onnx.Sub"(%arg0, %0) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %5 = "onnx.Add"(%2, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32>
  "std.return"(%5) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}

// CHECK-LABEL: @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32>
func @test_neg_3(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "onnx.Constant"() {value = dense<[[2, 3], [4, 5], [6, 7]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %1 = "onnx.Constant"() {value = dense<[[10]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
  %2 = "onnx.Neg"(%0) : (tensor<3x2xi32>) -> tensor<3x2xi32>
  %3 = "onnx.Add"(%arg0, %2) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %4 = "onnx.Add"(%3, %1) : (tensor<3x2xi32>, tensor<1x1xi32>) -> tensor<3x2xi32>
  "std.return"(%4) : (tensor<3x2xi32>) -> ()
  // CHECK-NEXT: [[CONST1:%.+]] = "onnx.Constant"() {value = dense<{{.}}[8, 7], [6, 5], [4, 3]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  // CHECK-NEXT: [[ADD1:%.+]] = "onnx.Add"(%arg0, [[CONST1]]) : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
}
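test_neg_2 and test_neg_3 both depend on the identity (x - c0) + c1 = x + (c1 - c0): rewriting sub into add of the negative exposes every constant to reassociation and a single fold. A numpy check of the identity and of the folded constant (the sample x is arbitrary):

import numpy as np

c0 = np.array([[2, 3], [4, 5], [6, 7]], dtype=np.int32)
c1 = np.array([[10]], dtype=np.int32)
x = np.arange(6, dtype=np.int32).reshape(3, 2)  # arbitrary stand-in for %arg0

assert np.array_equal((x - c0) + c1, x + (c1 - c0))
print(c1 - c0)  # [[8 7] [6 5] [4 3]], the folded CONST1 in both tests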

@@ -295,8 +295,15 @@ OpsWithResultTypeInference = {
# Currently, there are only two build methods generated:
# - one with operands and attributes having a separate parameter, and
# - one with operands and attributes having aggregated parameters.
custom_builder_ops_list = ['Abs', 'Mul', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
custom_builder_unranked_ops_list = ['Abs', 'Exp', 'ReduceSum', 'ReduceSumSquare', 'Pad']
# Custom builder op list for operations with broadcast; for these we can deduce
# the right output type, so there is no need to leave it undefined as in the
# list above.
# Ops must have two operands, not one, not three... And there shall be two.
# TODO: handle the variadic ops omitted here: Max, Min, Sum.
custom_builder_broadcast_ops_list = ['Add', 'And', 'Div', 'Equal', 'Greater',
    'Less', 'Mul', 'Or', 'Pow', 'Sub', 'Xor']
# Union of both lists.
custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broadcast_ops_list
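Opting another operation into the precise broadcast builders should only require adding its schema name to the list above before regenerating; the op below is purely hypothetical and not part of this change:

# Hypothetical illustration: 'Mod' is not enabled by this commit; it stands in
# for any two-operand elementwise op with multidirectional broadcast.
custom_builder_broadcast_ops_list.append('Mod')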

# A dictionary to add any special definition for an operation.
custom_definition_misc = dict([ ('Constant',

@@ -716,6 +723,8 @@ def gen_op_def(schema):
    s += indent + 'let results = (outs {});\n'.format(
        (',\n' + inc_indent(indent)).join(outs_strs))

    # custom_builder_broadcast_ops_list

    # Add custom builders.
    # Use the element type of the first operand to construct an
    # UnrankedTensorType for the output.
    if schema.name in custom_builder_ops_list:

@@ -726,7 +735,8 @@ def gen_op_def(schema):
    else:
        s += indent + 'let builders = [\n'
    # Custom builders with operands and attributes having a separate parameter.
    # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X, Value Y, Attribute A", [{}]>
    # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, Value X,
    #   Value Y, Attribute A", [{}]>
    indent = inc_indent(indent)
    s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
    operands_dict = get_operands_or_results(schema, is_input=True)
@@ -740,9 +750,26 @@ def gen_op_def(schema):

    # Get output type from first operand's type.
    first_operand_name = list(ins.items())[0][0]
    s += indent + 'auto elementType = {}.getType().cast<TensorType>().getElementType();\n'.format(
        first_operand_name)
    s += indent + 'build(builder, state, UnrankedTensorType::get(elementType)'
    build_type_name = ''
    if schema.name in custom_builder_broadcast_ops_list:
        second_operand_name = list(ins.items())[1][0]
        s += indent + 'auto lhsTy = {}.getType().cast<RankedTensorType>();\n'. \
            format(first_operand_name)
        s += indent + 'auto rhsTy = {}.getType().cast<RankedTensorType>();\n'. \
            format(second_operand_name)
        s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
        s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n'
        s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n'
        s += indent + indent + 'elementType = {}'.format(first_operand_name) + \
            '.getType().cast<TensorType>().getElementType();\n'
        s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
        s += indent + '}\n'
        build_type_name = 'elementType'
    else:
        s += indent + 'auto elementType = {}'.format(first_operand_name) + \
            '.getType().cast<TensorType>().getElementType();\n'
        build_type_name = 'UnrankedTensorType::get(elementType)'
    s += indent + 'build(builder, state, {}'.format(build_type_name)
    for name, _ in ins.items():
        s += ', ' + name
    s += ');\n'
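The C++ emitted above leans on MLIR's getBroadcastedType and falls back to an unranked tensor when no static shape can be deduced. For intuition, here is a standalone Python sketch of the same multidirectional broadcast rule (an illustration of the shape logic only; the generated builder calls into MLIR, not into this):

def broadcast_shape(lhs, rhs):
    # Multidirectional (numpy-style) broadcast of two static shapes: walk the
    # dimensions right to left, padding the shorter shape with 1s.
    result = []
    for i in range(1, max(len(lhs), len(rhs)) + 1):
        l = lhs[-i] if i <= len(lhs) else 1
        r = rhs[-i] if i <= len(rhs) else 1
        if l != r and l != 1 and r != 1:
            return None  # incompatible: the builder falls back to unranked
        result.append(max(l, r))
    return list(reversed(result))

print(broadcast_shape([3, 1], [3, 2]))  # [3, 2]
print(broadcast_shape([1], [3, 2]))     # [3, 2]
print(broadcast_shape([2, 3], [3, 2]))  # None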

@@ -750,12 +777,26 @@ def gen_op_def(schema):
        s += indent + '}]>,\n'

    # Custom builders with all operands and attributes having aggregate parameters.
    # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{}]>
    s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
    # E.g. OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands,
    #   ArrayRef<NamedAttribute> attributes", [{}]>
    s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state, ' + \
        'ValueRange operands, ArrayRef<NamedAttribute> attributes", [{\n'
    indent = inc_indent(indent)
    s += indent + 'auto elementType = operands[0].getType().cast<TensorType>().getElementType();\n'
    if schema.name in custom_builder_broadcast_ops_list:
        s += indent + 'auto lhsTy = operands[0].getType().cast<RankedTensorType>();\n'
        s += indent + 'auto rhsTy = operands[1].getType().cast<RankedTensorType>();\n'
        s += indent + 'auto elementType = getBroadcastedType(lhsTy, rhsTy);\n'
        s += indent + 'auto shapedType = elementType.dyn_cast_or_null<ShapedType>();\n'
        s += indent + 'if (!shapedType || !shapedType.hasStaticShape()) {\n'
        s += indent + indent + 'elementType = operands[0]' + \
            '.getType().cast<TensorType>().getElementType();\n'
        s += indent + indent + 'elementType = UnrankedTensorType::get(elementType);\n'
        s += indent + '}\n'
    else:
        s += indent + 'auto elementType = operands[0].getType().' + \
            'cast<TensorType>().getElementType();\n'
    s += indent + 'std::vector<mlir::Type> outputTypes;\n'
    s += indent + 'outputTypes.emplace_back(UnrankedTensorType::get(elementType));\n'
    s += indent + 'outputTypes.emplace_back({});\n'.format(build_type_name)
    s += indent + 'build(builder, state, outputTypes, operands, attributes);\n'
    indent = dec_indent(indent)
    s += indent + '}]>'
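Putting the string concatenations together, the aggregate-parameter builder emitted for an op in the broadcast list looks roughly like the text below. This is hand-assembled from the statements above to show the shape of the output, not captured from an actual run of the script:

# Hand-assembled approximation of the emitted TableGen builder text for a
# broadcast op; printing keeps the sketch self-contained and runnable.
emitted = '''OpBuilder<"OpBuilder &builder, OperationState &state, ValueRange operands, ArrayRef<NamedAttribute> attributes", [{
  auto lhsTy = operands[0].getType().cast<RankedTensorType>();
  auto rhsTy = operands[1].getType().cast<RankedTensorType>();
  auto elementType = getBroadcastedType(lhsTy, rhsTy);
  auto shapedType = elementType.dyn_cast_or_null<ShapedType>();
  if (!shapedType || !shapedType.hasStaticShape()) {
    elementType = UnrankedTensorType::get(
        operands[0].getType().cast<TensorType>().getElementType());
  }
  std::vector<mlir::Type> outputTypes;
  outputTypes.emplace_back(elementType);
  build(builder, state, outputTypes, operands, attributes);
}]>'''
print(emitted)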