[NFC] Fix comment style, shorten file names. (#169)
* Fix comment style, shorten file names. * Update CMakeLists.txt * Rename ONNXRewrite.td to Rewrite.td
This commit is contained in:
parent
60c648ae39
commit
a7781791e9
|
@ -122,7 +122,7 @@ RankedTensorType getReductionOutputType(
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that computes default values for dilations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <class T>
|
||||
static LogicalResult processConvDilationParam(
|
||||
T *op, Optional<ArrayAttr> kernelShape) {
|
||||
|
@ -153,7 +153,7 @@ static LogicalResult processConvDilationParam(
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that computes default values for strides.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <class T>
|
||||
static LogicalResult processConvStrideParam(
|
||||
T *op, Optional<ArrayAttr> kernelShape) {
|
||||
|
@ -181,7 +181,7 @@ static LogicalResult processConvStrideParam(
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that computes default values for pads.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <class T>
|
||||
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
||||
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
|
||||
|
@ -273,8 +273,8 @@ static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that computes default values for dilations, strides, and
|
||||
// pads.
|
||||
// Support function computing default values for dilations, strides, and pads.
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <class T>
|
||||
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
|
||||
auto builder = mlir::Builder(op->getContext());
|
||||
|
@ -305,7 +305,7 @@ static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
||||
Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
|
||||
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
|
||||
|
@ -335,6 +335,7 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Support function that infers shape for RNN operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
template <typename T>
|
||||
static LogicalResult RNNShapeInference(T *op) {
|
||||
Value X = op->X();
|
||||
|
@ -516,6 +517,7 @@ LogicalResult ONNXExpOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Atan
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXAtanOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXAtanOp::inferShapes() {
|
||||
|
@ -525,6 +527,7 @@ LogicalResult ONNXAtanOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tan
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXTanOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXTanOp::inferShapes() {
|
||||
|
@ -543,6 +546,7 @@ LogicalResult ONNXTanhOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sin
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSinOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXSinOp::inferShapes() {
|
||||
|
@ -552,6 +556,7 @@ LogicalResult ONNXSinOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sinh
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXSinhOp::inferShapes() {
|
||||
|
@ -561,6 +566,7 @@ LogicalResult ONNXSinhOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cosh
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXCoshOp::inferShapes() {
|
||||
|
@ -570,6 +576,7 @@ LogicalResult ONNXCoshOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cos
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXCosOp::inferShapes() {
|
||||
|
@ -579,6 +586,7 @@ LogicalResult ONNXCosOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Log
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXLogOp::inferShapes() {
|
||||
|
@ -588,6 +596,7 @@ LogicalResult ONNXLogOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HardSigmoid
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXHardSigmoidOp::inferShapes() {
|
||||
|
@ -597,6 +606,7 @@ LogicalResult ONNXHardSigmoidOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sigmoid
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXSigmoidOp::inferShapes() {
|
||||
|
@ -606,6 +616,7 @@ LogicalResult ONNXSigmoidOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elu
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXEluOp::inferShapes() {
|
||||
|
@ -615,6 +626,7 @@ LogicalResult ONNXEluOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Relu
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXReluOp::inferShapes() {
|
||||
|
@ -624,6 +636,7 @@ LogicalResult ONNXReluOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LeakyRelu
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXLeakyReluOp::inferShapes() {
|
||||
|
@ -633,6 +646,7 @@ LogicalResult ONNXLeakyReluOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Selu
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSeluOp::inferShapes() {
|
||||
|
@ -642,6 +656,7 @@ LogicalResult ONNXSeluOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reciprocal
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXReciprocalOp::inferShapes() {
|
||||
|
@ -651,6 +666,7 @@ LogicalResult ONNXReciprocalOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softmax
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSoftmaxOp::inferShapes() {
|
||||
|
@ -660,6 +676,7 @@ LogicalResult ONNXSoftmaxOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softplus
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSoftplusOp::inferShapes() {
|
||||
|
@ -669,6 +686,7 @@ LogicalResult ONNXSoftplusOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softsign
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSoftsignOp::inferShapes() {
|
||||
|
@ -678,6 +696,7 @@ LogicalResult ONNXSoftsignOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sqrt
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSqrtOp::inferShapes() {
|
||||
|
@ -687,6 +706,7 @@ LogicalResult ONNXSqrtOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sign
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSignOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
LogicalResult ONNXSignOp::inferShapes() {
|
||||
|
@ -696,6 +716,7 @@ LogicalResult ONNXSignOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Abs
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXAbsOp::inferShapes() {
|
||||
|
@ -705,6 +726,7 @@ LogicalResult ONNXAbsOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXAddOp::inferShapes() {
|
||||
|
@ -719,6 +741,7 @@ LogicalResult ONNXAddOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Mul
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXMulOp::inferShapes() {
|
||||
|
@ -733,6 +756,7 @@ LogicalResult ONNXMulOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Div
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXDivOp::inferShapes() {
|
||||
|
@ -747,6 +771,7 @@ LogicalResult ONNXDivOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sub
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXSubOp::inferShapes() {
|
||||
|
@ -761,6 +786,7 @@ LogicalResult ONNXSubOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// And
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXAndOp::inferShapes() {
|
||||
|
@ -775,6 +801,7 @@ LogicalResult ONNXAndOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Or
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXOrOp::inferShapes() {
|
||||
|
@ -789,6 +816,7 @@ LogicalResult ONNXOrOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Xor
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXXorOp::inferShapes() {
|
||||
|
@ -801,10 +829,9 @@ LogicalResult ONNXXorOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sum
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXSumOp::inferShapes() {
|
||||
|
@ -823,6 +850,7 @@ LogicalResult ONNXSumOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Max
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXMaxOp::inferShapes() {
|
||||
|
@ -841,6 +869,7 @@ LogicalResult ONNXMaxOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Min
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXMinOp::inferShapes() {
|
||||
|
@ -859,6 +888,7 @@ LogicalResult ONNXMinOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Neg
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXNegOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXNegOp::inferShapes() {
|
||||
|
@ -868,6 +898,7 @@ LogicalResult ONNXNegOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Identity
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
LogicalResult ONNXIdentityOp::inferShapes() {
|
||||
|
@ -876,8 +907,8 @@ LogicalResult ONNXIdentityOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MatMul
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXMatMulOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
|
@ -1006,10 +1037,7 @@ LogicalResult ONNXMatMulOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Gemm
|
||||
|
||||
LogicalResult ONNXGemmOp::inferShapes() {
|
||||
bool hasBias = !C().getType().isa<NoneType>();
|
||||
// Cannot infer shape if no shape exists.
|
||||
|
@ -1104,8 +1132,8 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|||
// Take into account the dimensionality of the matrix.
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Reshape
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXReshapeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape tensor is specified.
|
||||
|
@ -1177,8 +1205,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Transpose
|
||||
|
||||
LogicalResult ONNXTransposeOp::inferShapes() {
|
||||
|
@ -1206,8 +1232,8 @@ LogicalResult ONNXTransposeOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceMax
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXReduceMaxOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
|
@ -1219,8 +1245,8 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceMin
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXReduceMinOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
|
@ -1232,8 +1258,8 @@ LogicalResult ONNXReduceMinOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceProd
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXReduceProdOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
|
@ -1245,8 +1271,8 @@ LogicalResult ONNXReduceProdOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceSum
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXReduceSumOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
|
@ -1258,8 +1284,8 @@ LogicalResult ONNXReduceSumOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Conv
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// For this operation, we define the attributes once in the original Conv
|
||||
// operation class. There is no need to redefine the attribute names for the
|
||||
|
@ -1373,8 +1399,8 @@ LogicalResult ONNXConvOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ConvTranspose
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// For this operation, we define the attributes once in the original Conv
|
||||
// operation class. There is no need to redefine the attribute names for the
|
||||
|
@ -1505,8 +1531,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// AveragePool
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Infer shape attributes output:
|
||||
// - auto_pad set to NOTSET;
|
||||
// - strides: set to 1 if not defined by user;
|
||||
|
@ -1557,8 +1584,9 @@ LogicalResult ONNXAveragePoolOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MaxPoolSingleOut
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Infer shape attributes output:
|
||||
// - auto_pad set to NOTSET;
|
||||
// - dilations, strides: set to 1 if not defined by user;
|
||||
|
@ -1607,6 +1635,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pad
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXPadOp::inferShapes() {
|
||||
|
@ -1678,7 +1708,9 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
|||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PadConstantPad
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXPadConstantPadOp::inferShapes() {
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
|
@ -1689,8 +1721,8 @@ LogicalResult ONNXPadConstantPadOp::inferShapes() {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// PadConstantValuePad
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
|
@ -1711,8 +1743,8 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Unsqueeze
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
|
@ -1753,6 +1785,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cast
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXCastOp::inferShapes() {
|
||||
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
|
||||
|
@ -1781,6 +1814,7 @@ LogicalResult ONNXCastOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constant
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXConstantOp::inferShapes() {
|
||||
if ((sparse_value().hasValue() && value().hasValue()) ||
|
||||
|
@ -1796,7 +1830,9 @@ LogicalResult ONNXConstantOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Concat
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXConcatOp::inferShapes() {
|
||||
int inputNum = getNumOperands();
|
||||
|
@ -1854,21 +1890,25 @@ LogicalResult ONNXConcatOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RNN
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LSTM
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GRU
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Split
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXSplitOp::inferShapes() {
|
||||
if (!getOperand().getType().cast<RankedTensorType>())
|
||||
|
@ -1933,6 +1973,7 @@ LogicalResult ONNXSplitOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Flatten
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXFlattenOp::inferShapes() {
|
||||
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
|
||||
|
@ -1967,6 +2008,7 @@ LogicalResult ONNXFlattenOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DynamicQuantizeLinear
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -2000,6 +2042,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// QuantizeLinear
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -2022,6 +2065,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DequantizeLinear
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||
|
@ -2042,6 +2086,7 @@ LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
|||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ONNXConvIntegerOp::inferShapes() {
|
||||
// Generic shape for data input X, weight tensor W
|
||||
|
|
|
@ -26,27 +26,27 @@ add_dependencies(OMElideConstants OMONNXOps)
|
|||
target_link_libraries(OMElideConstants
|
||||
onnx)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
||||
set(LLVM_TARGET_DEFINITIONS Rewrite.td)
|
||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||
add_public_tablegen_target(OMONNXRewriteIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXCombine.td)
|
||||
set(LLVM_TARGET_DEFINITIONS Combine.td)
|
||||
onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters)
|
||||
add_public_tablegen_target(OMONNXCombineIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
|
||||
set(LLVM_TARGET_DEFINITIONS Decompose.td)
|
||||
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
|
||||
add_public_tablegen_target(OMONNXDecomposeIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
|
||||
set(LLVM_TARGET_DEFINITIONS ConstProp.td)
|
||||
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
|
||||
add_public_tablegen_target(OMONNXConstPropIncGen)
|
||||
|
||||
add_library(OMONNXRewrite
|
||||
ONNXRewrite.cpp
|
||||
ONNXCombine.cpp
|
||||
ONNXDecompose.cpp
|
||||
ONNXConstProp.cpp)
|
||||
Rewrite.cpp
|
||||
Combine.cpp
|
||||
Decompose.cpp
|
||||
ConstProp.cpp)
|
||||
target_include_directories(OMONNXRewrite
|
||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNF_MLIR_SRC_ROOT})
|
||||
|
|
|
@ -25,16 +25,17 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
// =============================================================================
|
||||
// Instructions to add a constant operation. There is currently support for
|
||||
// adding constant propagation for unary and binary athythmetic ops (binary ops
|
||||
// support broadcast). To add an operation, you simply have to add a templated
|
||||
// method on how to compute the result in terms of one or two inputs. Values
|
||||
// comes as Attribtues, and return is also an Attribute. In that function,
|
||||
// presumably you will need different methods to handle int / float /
|
||||
// strings... Note that these methods cannot fail. It is your responsablitity to
|
||||
// tests for which data type are supported in the rules directly. Specific type
|
||||
// restrictions can be added in the DRR files.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Instructions to add a constant operation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// There is currently support for adding constant propagation for unary and
|
||||
// binary athythmetic ops (binary ops support broadcast). To add an operation,
|
||||
// you simply have to add a templated method on how to compute the result in
|
||||
// terms of one or two inputs. Values comes as Attribtues, and return is also an
|
||||
// Attribute. In that function, presumably you will need different methods to
|
||||
// handle int / float / strings... Note that these methods cannot fail. It is
|
||||
// your responsablitity to tests for which data type are supported in the rules
|
||||
// directly. Specific type restrictions can be added in the DRR files.
|
||||
|
||||
// The methods are:
|
||||
//
|
||||
|
@ -42,13 +43,12 @@ namespace {
|
|||
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
||||
//
|
||||
// Then you need to add rules on how to transform the patterns; look into
|
||||
// ONNXConstProp.td for example.
|
||||
// ConstProp.td for example.
|
||||
//
|
||||
// =============================================================================
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code to perform constant propagation for binary in presence of broadcast.
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Template to generate binary operation results. It takes as inupt
|
||||
// the element type as well as the two element attributes for the
|
||||
|
@ -223,9 +223,9 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
|||
return DenseElementsAttr::get(resType, resRef);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code to perform constant propagation for unary operation.
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OP>
|
||||
Attribute ComputeConstProppElementwiseUnary(
|
||||
|
@ -295,15 +295,15 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
|||
return DenseElementsAttr::get(resType, resRef);
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern definition.
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "src/Transform/ONNX/ONNXConstProp.inc"
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Code to manage the pass.
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
struct ConstPropONNXToONNXPass
|
||||
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
|
|
@ -16,13 +16,14 @@
|
|||
include "src/Dialect/ONNX/ONNXOps.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
|
||||
// =============================================================================
|
||||
// Instruction to add new constant operation rules. Minimally, you will have added
|
||||
// operation in the ONNXConstProp.cpp to perform the element-wise single value
|
||||
// handling of the new operator that you are dealing with. You will need to
|
||||
// generate a call to the method that handle the tensor constant prop. Here
|
||||
// is the call for a unary and binary operation. Adapt to your new operator:
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Instruction to add new constant operation rules.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Minimally, you will have added operation in the ONNXConstProp.cpp to perform
|
||||
// the element-wise single value handling of the new operator that you are dealing
|
||||
// with. You will need to generate a call to the method that handle the tensor
|
||||
// constant prop. Here is the call for a unary and binary operation. Adapt to your
|
||||
// new operator:
|
||||
//
|
||||
// def CreateAddOfTwoConst :
|
||||
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
|
||||
|
@ -34,7 +35,6 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
|||
// a new def name.
|
||||
//
|
||||
// Then you will need to add substitution rules, see examples below.
|
||||
// =============================================================================
|
||||
|
||||
|
||||
// Useful test definitions.
|
||||
|
@ -63,8 +63,9 @@ def CreateNegOfConst :
|
|||
def CreateMulOfTwoConst :
|
||||
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns to enable opportunities with elementwise ADD operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Use commutativity to normalize constants in the second position of Add.
|
||||
def AddConstCommutative1 : Pat<
|
||||
|
@ -96,8 +97,9 @@ def AddConstProp : Pat<
|
|||
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns to enable opportunities with elementwise SUB / NEG operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Constant Propagation for Sub
|
||||
def SubConstProp : Pat<
|
||||
|
@ -125,9 +127,10 @@ def SubConstToNeg : Pat<
|
|||
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
||||
|
||||
|
||||
// =============================================================================
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns to enable opportunities with elementwise MUL operations.
|
||||
// Exactly the same pattern as for the elementwise ADD operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Use commutativity to normalize constants in the second position of Mul.
|
||||
def MulConstCommutative1 : Pat<
|
Loading…
Reference in New Issue