[NFC] Fix comment style, shorten file names. (#169)

* Fix comment style, shorten file names.

* Update CMakeLists.txt

* Rename ONNXRewrite.td to Rewrite.td
This commit is contained in:
Tian Jin 2020-06-15 11:49:09 +08:00 committed by GitHub
parent 60c648ae39
commit a7781791e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 113 additions and 65 deletions

View File

@ -122,7 +122,7 @@ RankedTensorType getReductionOutputType(
//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations.
//
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvDilationParam(
T *op, Optional<ArrayAttr> kernelShape) {
@ -153,7 +153,7 @@ static LogicalResult processConvDilationParam(
//===----------------------------------------------------------------------===//
// Support function that computes default values for strides.
//
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvStrideParam(
T *op, Optional<ArrayAttr> kernelShape) {
@ -181,7 +181,7 @@ static LogicalResult processConvStrideParam(
//===----------------------------------------------------------------------===//
// Support function that computes default values for pads.
//
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
@ -273,8 +273,8 @@ static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
}
//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations, strides, and
// pads.
// Support function computing default values for dilations, strides, and pads.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
auto builder = mlir::Builder(op->getContext());
@ -305,7 +305,7 @@ static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
//===----------------------------------------------------------------------===//
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
//
//===----------------------------------------------------------------------===//
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
@ -335,6 +335,7 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
//===----------------------------------------------------------------------===//
// Support function that infers shape for RNN operations.
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult RNNShapeInference(T *op) {
Value X = op->X();
@ -516,6 +517,7 @@ LogicalResult ONNXExpOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Atan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAtanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAtanOp::inferShapes() {
@ -525,6 +527,7 @@ LogicalResult ONNXAtanOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Tan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXTanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanOp::inferShapes() {
@ -543,6 +546,7 @@ LogicalResult ONNXTanhOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sin
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinOp::inferShapes() {
@ -552,6 +556,7 @@ LogicalResult ONNXSinOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sinh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinhOp::inferShapes() {
@ -561,6 +566,7 @@ LogicalResult ONNXSinhOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Cosh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCoshOp::inferShapes() {
@ -570,6 +576,7 @@ LogicalResult ONNXCoshOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Cos
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCosOp::inferShapes() {
@ -579,6 +586,7 @@ LogicalResult ONNXCosOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Log
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXLogOp::inferShapes() {
@ -588,6 +596,7 @@ LogicalResult ONNXLogOp::inferShapes() {
//===----------------------------------------------------------------------===//
// HardSigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXHardSigmoidOp::inferShapes() {
@ -597,6 +606,7 @@ LogicalResult ONNXHardSigmoidOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSigmoidOp::inferShapes() {
@ -606,6 +616,7 @@ LogicalResult ONNXSigmoidOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Elu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXEluOp::inferShapes() {
@ -615,6 +626,7 @@ LogicalResult ONNXEluOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Relu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXReluOp::inferShapes() {
@ -624,6 +636,7 @@ LogicalResult ONNXReluOp::inferShapes() {
//===----------------------------------------------------------------------===//
// LeakyRelu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXLeakyReluOp::inferShapes() {
@ -633,6 +646,7 @@ LogicalResult ONNXLeakyReluOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Selu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSeluOp::inferShapes() {
@ -642,6 +656,7 @@ LogicalResult ONNXSeluOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Reciprocal
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXReciprocalOp::inferShapes() {
@ -651,6 +666,7 @@ LogicalResult ONNXReciprocalOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Softmax
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftmaxOp::inferShapes() {
@ -660,6 +676,7 @@ LogicalResult ONNXSoftmaxOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Softplus
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftplusOp::inferShapes() {
@ -669,6 +686,7 @@ LogicalResult ONNXSoftplusOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Softsign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftsignOp::inferShapes() {
@ -678,6 +696,7 @@ LogicalResult ONNXSoftsignOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sqrt
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSqrtOp::inferShapes() {
@ -687,6 +706,7 @@ LogicalResult ONNXSqrtOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSignOp::inferShapes() {
@ -696,6 +716,7 @@ LogicalResult ONNXSignOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Abs
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAbsOp::inferShapes() {
@ -705,6 +726,7 @@ LogicalResult ONNXAbsOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Add
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAddOp::inferShapes() {
@ -719,6 +741,7 @@ LogicalResult ONNXAddOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Mul
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMulOp::inferShapes() {
@ -733,6 +756,7 @@ LogicalResult ONNXMulOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Div
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXDivOp::inferShapes() {
@ -747,6 +771,7 @@ LogicalResult ONNXDivOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Sub
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSubOp::inferShapes() {
@ -761,6 +786,7 @@ LogicalResult ONNXSubOp::inferShapes() {
//===----------------------------------------------------------------------===//
// And
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAndOp::inferShapes() {
@ -775,6 +801,7 @@ LogicalResult ONNXAndOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Or
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXOrOp::inferShapes() {
@ -789,6 +816,7 @@ LogicalResult ONNXOrOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Xor
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXXorOp::inferShapes() {
@ -801,10 +829,9 @@ LogicalResult ONNXXorOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Sum
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSumOp::inferShapes() {
@ -823,6 +850,7 @@ LogicalResult ONNXSumOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Max
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMaxOp::inferShapes() {
@ -841,6 +869,7 @@ LogicalResult ONNXMaxOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Min
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMinOp::inferShapes() {
@ -859,6 +888,7 @@ LogicalResult ONNXMinOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Neg
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXNegOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXNegOp::inferShapes() {
@ -868,6 +898,7 @@ LogicalResult ONNXNegOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Identity
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXIdentityOp::inferShapes() {
@ -876,8 +907,8 @@ LogicalResult ONNXIdentityOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// MatMul
//===----------------------------------------------------------------------===//
LogicalResult ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
@ -1006,10 +1037,7 @@ LogicalResult ONNXMatMulOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// Gemm
LogicalResult ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
@ -1104,8 +1132,8 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
// Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===//
// Reshape
//===----------------------------------------------------------------------===//
LogicalResult ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
@ -1177,8 +1205,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// Transpose
LogicalResult ONNXTransposeOp::inferShapes() {
@ -1206,8 +1232,8 @@ LogicalResult ONNXTransposeOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// ReduceMax
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
@ -1219,8 +1245,8 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// ReduceMin
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
@ -1232,8 +1258,8 @@ LogicalResult ONNXReduceMinOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// ReduceProd
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
@ -1245,8 +1271,8 @@ LogicalResult ONNXReduceProdOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// ReduceSum
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>())
@ -1258,8 +1284,8 @@ LogicalResult ONNXReduceSumOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// Conv
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
@ -1373,8 +1399,8 @@ LogicalResult ONNXConvOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// ConvTranspose
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
@ -1505,8 +1531,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// AveragePool
//===----------------------------------------------------------------------===//
// Infer shape attributes output:
// - auto_pad set to NOTSET;
// - strides: set to 1 if not defined by user;
@ -1557,8 +1584,9 @@ LogicalResult ONNXAveragePoolOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// MaxPoolSingleOut
//===----------------------------------------------------------------------===//
// Infer shape attributes output:
// - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user;
@ -1607,6 +1635,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// Pad
//===----------------------------------------------------------------------===//
LogicalResult ONNXPadOp::inferShapes() {
@ -1678,7 +1708,9 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
}
}
//===----------------------------------------------------------------------===//
// PadConstantPad
//===----------------------------------------------------------------------===//
LogicalResult ONNXPadConstantPadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
@ -1689,8 +1721,8 @@ LogicalResult ONNXPadConstantPadOp::inferShapes() {
}
//===----------------------------------------------------------------------===//
// PadConstantValuePad
//===----------------------------------------------------------------------===//
LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
@ -1711,8 +1743,8 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
}
//===----------------------------------------------------------------------===//
// Unsqueeze
//===----------------------------------------------------------------------===//
LogicalResult ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
@ -1753,6 +1785,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Cast
//===----------------------------------------------------------------------===//
LogicalResult ONNXCastOp::inferShapes() {
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
@ -1781,6 +1814,7 @@ LogicalResult ONNXCastOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
LogicalResult ONNXConstantOp::inferShapes() {
if ((sparse_value().hasValue() && value().hasValue()) ||
@ -1796,7 +1830,9 @@ LogicalResult ONNXConstantOp::inferShapes() {
return success();
}
//===----------------------------------------------------------------------===//
// Concat
//===----------------------------------------------------------------------===//
LogicalResult ONNXConcatOp::inferShapes() {
int inputNum = getNumOperands();
@ -1854,21 +1890,25 @@ LogicalResult ONNXConcatOp::inferShapes() {
//===----------------------------------------------------------------------===//
// RNN
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXRNNOp by delegating to the shared
/// RNN-family shape inference helper.
LogicalResult ONNXRNNOp::inferShapes() {
  return RNNShapeInference<>(this);
}
//===----------------------------------------------------------------------===//
// LSTM
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLSTMOp by delegating to the shared
/// RNN-family shape inference helper.
LogicalResult ONNXLSTMOp::inferShapes() {
  return RNNShapeInference<>(this);
}
//===----------------------------------------------------------------------===//
// GRU
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXGRUOp by delegating to the shared
/// RNN-family shape inference helper.
LogicalResult ONNXGRUOp::inferShapes() {
  return RNNShapeInference<>(this);
}
//===----------------------------------------------------------------------===//
// Split
//===----------------------------------------------------------------------===//
LogicalResult ONNXSplitOp::inferShapes() {
if (!getOperand().getType().cast<RankedTensorType>())
@ -1933,6 +1973,7 @@ LogicalResult ONNXSplitOp::inferShapes() {
//===----------------------------------------------------------------------===//
// Flatten
//===----------------------------------------------------------------------===//
LogicalResult ONNXFlattenOp::inferShapes() {
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
@ -1967,6 +2008,7 @@ LogicalResult ONNXFlattenOp::inferShapes() {
//===----------------------------------------------------------------------===//
// DynamicQuantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2000,6 +2042,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===//
// QuantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2022,6 +2065,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===//
// DequantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2042,6 +2086,7 @@ LogicalResult ONNXDequantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===//
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
//===----------------------------------------------------------------------===//
LogicalResult ONNXConvIntegerOp::inferShapes() {
// Generic shape for data input X, weight tensor W

View File

@ -26,27 +26,27 @@ add_dependencies(OMElideConstants OMONNXOps)
target_link_libraries(OMElideConstants
onnx)
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
set(LLVM_TARGET_DEFINITIONS Rewrite.td)
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
add_public_tablegen_target(OMONNXRewriteIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXCombine.td)
set(LLVM_TARGET_DEFINITIONS Combine.td)
onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters)
add_public_tablegen_target(OMONNXCombineIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
set(LLVM_TARGET_DEFINITIONS Decompose.td)
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
add_public_tablegen_target(OMONNXDecomposeIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
set(LLVM_TARGET_DEFINITIONS ConstProp.td)
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
add_public_tablegen_target(OMONNXConstPropIncGen)
add_library(OMONNXRewrite
ONNXRewrite.cpp
ONNXCombine.cpp
ONNXDecompose.cpp
ONNXConstProp.cpp)
Rewrite.cpp
Combine.cpp
Decompose.cpp
ConstProp.cpp)
target_include_directories(OMONNXRewrite
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNF_MLIR_SRC_ROOT})

View File

@ -25,16 +25,17 @@ using namespace mlir;
namespace {
// =============================================================================
// Instructions to add a constant operation. There is currently support for
// adding constant propagation for unary and binary arithmetic ops (binary ops
// support broadcast). To add an operation, you simply have to add a templated
// method on how to compute the result in terms of one or two inputs. Values
// come as Attributes, and the return is also an Attribute. In that function,
// presumably you will need different methods to handle int / float /
// strings... Note that these methods cannot fail. It is your responsibility to
// test which data types are supported in the rules directly. Specific type
// restrictions can be added in the DRR files.
//===----------------------------------------------------------------------===//
// Instructions to add a constant operation.
//===----------------------------------------------------------------------===//
// There is currently support for adding constant propagation for unary and
// binary arithmetic ops (binary ops support broadcast). To add an operation,
// you simply have to add a templated method on how to compute the result in
// terms of one or two inputs. Values come as Attributes, and the return is also
// an Attribute. In that function, presumably you will need different methods to
// handle int / float / strings... Note that these methods cannot fail. It is
// your responsibility to test which data types are supported in the rules
// directly. Specific type restrictions can be added in the DRR files.
// The methods are:
//
@ -42,13 +43,12 @@ namespace {
// and they need to be templated with an ONNX Operation (presumably).
//
// Then you need to add rules on how to transform the patterns; look into
// ONNXConstProp.td for example.
// ConstProp.td for example.
//
// =============================================================================
// =============================================================================
//===----------------------------------------------------------------------===//
// Code to perform constant propagation for binary in presence of broadcast.
// =============================================================================
//===----------------------------------------------------------------------===//
// Template to generate binary operation results. It takes as input
// the element type as well as the two element attributes for the
@ -223,9 +223,9 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
return DenseElementsAttr::get(resType, resRef);
}
// =============================================================================
//===----------------------------------------------------------------------===//
// Code to perform constant propagation for unary operation.
// =============================================================================
//===----------------------------------------------------------------------===//
template <typename OP>
Attribute ComputeConstProppElementwiseUnary(
@ -295,15 +295,15 @@ DenseElementsAttr ConstPropElementwiseUnary(
return DenseElementsAttr::get(resType, resRef);
}
// =============================================================================
//===----------------------------------------------------------------------===//
// Pattern definition.
// =============================================================================
//===----------------------------------------------------------------------===//
#include "src/Transform/ONNX/ONNXConstProp.inc"
// =============================================================================
//===----------------------------------------------------------------------===//
// Code to manage the pass.
// =============================================================================
//===----------------------------------------------------------------------===//
struct ConstPropONNXToONNXPass
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {

View File

@ -16,13 +16,14 @@
include "src/Dialect/ONNX/ONNXOps.td"
#endif // OP_BASE
// =============================================================================
// Instructions to add new constant operation rules. Minimally, you will have
// added an operation in ONNXConstProp.cpp to perform the element-wise single
// value handling of the new operator that you are dealing with. You will need
// to generate a call to the method that handles the tensor constant prop. Here
// is the call for a unary and a binary operation. Adapt to your new operator:
//===----------------------------------------------------------------------===//
// Instruction to add new constant operation rules.
//===----------------------------------------------------------------------===//
// Minimally, you will have added an operation in ONNXConstProp.cpp to perform
// the element-wise single value handling of the new operator that you are
// dealing with. You will need to generate a call to the method that handles the
// tensor constant prop. Here is the call for a unary and a binary operation.
// Adapt to your new operator:
//
// def CreateAddOfTwoConst :
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
@ -34,7 +35,6 @@ include "src/Dialect/ONNX/ONNXOps.td"
// a new def name.
//
// Then you will need to add substitution rules, see examples below.
// =============================================================================
// Useful test definitions.
@ -63,8 +63,9 @@ def CreateNegOfConst :
def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
// =============================================================================
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise ADD operations.
//===----------------------------------------------------------------------===//
// Use commutativity to normalize constants in the second position of Add.
def AddConstCommutative1 : Pat<
@ -96,8 +97,9 @@ def AddConstProp : Pat<
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
// =============================================================================
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise SUB / NEG operations.
//===----------------------------------------------------------------------===//
// Constant Propagation for Sub
def SubConstProp : Pat<
@ -125,9 +127,10 @@ def SubConstToNeg : Pat<
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
// =============================================================================
//===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise MUL operations.
// Exactly the same pattern as for the elementwise ADD operations.
//===----------------------------------------------------------------------===//
// Use commutativity to normalize constants in the second position of Mul.
def MulConstCommutative1 : Pat<