[NFC] Fix comment style, shorten file names. (#169)
* Fix comment style, shorten file names. * Update CMakeLists.txt * Rename ONNXRewrite.td to Rewrite.td
This commit is contained in:
parent
60c648ae39
commit
a7781791e9
|
@ -122,7 +122,7 @@ RankedTensorType getReductionOutputType(
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support function that computes default values for dilations.
|
// Support function that computes default values for dilations.
|
||||||
//
|
//===----------------------------------------------------------------------===//
|
||||||
template <class T>
|
template <class T>
|
||||||
static LogicalResult processConvDilationParam(
|
static LogicalResult processConvDilationParam(
|
||||||
T *op, Optional<ArrayAttr> kernelShape) {
|
T *op, Optional<ArrayAttr> kernelShape) {
|
||||||
|
@ -153,7 +153,7 @@ static LogicalResult processConvDilationParam(
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support function that computes default values for strides.
|
// Support function that computes default values for strides.
|
||||||
//
|
//===----------------------------------------------------------------------===//
|
||||||
template <class T>
|
template <class T>
|
||||||
static LogicalResult processConvStrideParam(
|
static LogicalResult processConvStrideParam(
|
||||||
T *op, Optional<ArrayAttr> kernelShape) {
|
T *op, Optional<ArrayAttr> kernelShape) {
|
||||||
|
@ -181,7 +181,7 @@ static LogicalResult processConvStrideParam(
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support function that computes default values for pads.
|
// Support function that computes default values for pads.
|
||||||
//
|
//===----------------------------------------------------------------------===//
|
||||||
template <class T>
|
template <class T>
|
||||||
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
||||||
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
|
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
|
||||||
|
@ -273,8 +273,8 @@ static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support function that computes default values for dilations, strides, and
|
// Support function computing default values for dilations, strides, and pads.
|
||||||
// pads.
|
//===----------------------------------------------------------------------===//
|
||||||
template <class T>
|
template <class T>
|
||||||
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
|
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
|
||||||
auto builder = mlir::Builder(op->getContext());
|
auto builder = mlir::Builder(op->getContext());
|
||||||
|
@ -305,7 +305,7 @@ static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
|
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
|
||||||
//
|
//===----------------------------------------------------------------------===//
|
||||||
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
||||||
Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
|
Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
|
||||||
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
|
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
|
||||||
|
@ -335,6 +335,7 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Support function that infers shape for RNN operations.
|
// Support function that infers shape for RNN operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static LogicalResult RNNShapeInference(T *op) {
|
static LogicalResult RNNShapeInference(T *op) {
|
||||||
Value X = op->X();
|
Value X = op->X();
|
||||||
|
@ -516,6 +517,7 @@ LogicalResult ONNXExpOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Atan
|
// Atan
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXAtanOp. This method is required by the
|
/// Infer the output shape of the ONNXAtanOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXAtanOp::inferShapes() {
|
LogicalResult ONNXAtanOp::inferShapes() {
|
||||||
|
@ -525,6 +527,7 @@ LogicalResult ONNXAtanOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tan
|
// Tan
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXTanOp. This method is required by the
|
/// Infer the output shape of the ONNXTanOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXTanOp::inferShapes() {
|
LogicalResult ONNXTanOp::inferShapes() {
|
||||||
|
@ -543,6 +546,7 @@ LogicalResult ONNXTanhOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sin
|
// Sin
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSinOp. This method is required by the
|
/// Infer the output shape of the ONNXSinOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXSinOp::inferShapes() {
|
LogicalResult ONNXSinOp::inferShapes() {
|
||||||
|
@ -552,6 +556,7 @@ LogicalResult ONNXSinOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sinh
|
// Sinh
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXSinhOp::inferShapes() {
|
LogicalResult ONNXSinhOp::inferShapes() {
|
||||||
|
@ -561,6 +566,7 @@ LogicalResult ONNXSinhOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Cosh
|
// Cosh
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXCoshOp::inferShapes() {
|
LogicalResult ONNXCoshOp::inferShapes() {
|
||||||
|
@ -570,6 +576,7 @@ LogicalResult ONNXCoshOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Cos
|
// Cos
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXCosOp::inferShapes() {
|
LogicalResult ONNXCosOp::inferShapes() {
|
||||||
|
@ -579,6 +586,7 @@ LogicalResult ONNXCosOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Log
|
// Log
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXLogOp::inferShapes() {
|
LogicalResult ONNXLogOp::inferShapes() {
|
||||||
|
@ -588,6 +596,7 @@ LogicalResult ONNXLogOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// HardSigmoid
|
// HardSigmoid
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXHardSigmoidOp::inferShapes() {
|
LogicalResult ONNXHardSigmoidOp::inferShapes() {
|
||||||
|
@ -597,6 +606,7 @@ LogicalResult ONNXHardSigmoidOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sigmoid
|
// Sigmoid
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXSigmoidOp::inferShapes() {
|
LogicalResult ONNXSigmoidOp::inferShapes() {
|
||||||
|
@ -606,6 +616,7 @@ LogicalResult ONNXSigmoidOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Elu
|
// Elu
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXEluOp::inferShapes() {
|
LogicalResult ONNXEluOp::inferShapes() {
|
||||||
|
@ -615,6 +626,7 @@ LogicalResult ONNXEluOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Relu
|
// Relu
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXReluOp::inferShapes() {
|
LogicalResult ONNXReluOp::inferShapes() {
|
||||||
|
@ -624,6 +636,7 @@ LogicalResult ONNXReluOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LeakyRelu
|
// LeakyRelu
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXLeakyReluOp::inferShapes() {
|
LogicalResult ONNXLeakyReluOp::inferShapes() {
|
||||||
|
@ -633,6 +646,7 @@ LogicalResult ONNXLeakyReluOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Selu
|
// Selu
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSeluOp::inferShapes() {
|
LogicalResult ONNXSeluOp::inferShapes() {
|
||||||
|
@ -642,6 +656,7 @@ LogicalResult ONNXSeluOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Reciprocal
|
// Reciprocal
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXReciprocalOp::inferShapes() {
|
LogicalResult ONNXReciprocalOp::inferShapes() {
|
||||||
|
@ -651,6 +666,7 @@ LogicalResult ONNXReciprocalOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softmax
|
// Softmax
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSoftmaxOp::inferShapes() {
|
LogicalResult ONNXSoftmaxOp::inferShapes() {
|
||||||
|
@ -660,6 +676,7 @@ LogicalResult ONNXSoftmaxOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softplus
|
// Softplus
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSoftplusOp::inferShapes() {
|
LogicalResult ONNXSoftplusOp::inferShapes() {
|
||||||
|
@ -669,6 +686,7 @@ LogicalResult ONNXSoftplusOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softsign
|
// Softsign
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSoftsignOp::inferShapes() {
|
LogicalResult ONNXSoftsignOp::inferShapes() {
|
||||||
|
@ -678,6 +696,7 @@ LogicalResult ONNXSoftsignOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sqrt
|
// Sqrt
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSqrtOp::inferShapes() {
|
LogicalResult ONNXSqrtOp::inferShapes() {
|
||||||
|
@ -687,6 +706,7 @@ LogicalResult ONNXSqrtOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sign
|
// Sign
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSignOp. This method is required by
|
/// Infer the output shape of the ONNXSignOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
LogicalResult ONNXSignOp::inferShapes() {
|
LogicalResult ONNXSignOp::inferShapes() {
|
||||||
|
@ -696,6 +716,7 @@ LogicalResult ONNXSignOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Abs
|
// Abs
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXAbsOp::inferShapes() {
|
LogicalResult ONNXAbsOp::inferShapes() {
|
||||||
|
@ -705,6 +726,7 @@ LogicalResult ONNXAbsOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXAddOp::inferShapes() {
|
LogicalResult ONNXAddOp::inferShapes() {
|
||||||
|
@ -719,6 +741,7 @@ LogicalResult ONNXAddOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Mul
|
// Mul
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXMulOp::inferShapes() {
|
LogicalResult ONNXMulOp::inferShapes() {
|
||||||
|
@ -733,6 +756,7 @@ LogicalResult ONNXMulOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Div
|
// Div
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXDivOp::inferShapes() {
|
LogicalResult ONNXDivOp::inferShapes() {
|
||||||
|
@ -747,6 +771,7 @@ LogicalResult ONNXDivOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sub
|
// Sub
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXSubOp::inferShapes() {
|
LogicalResult ONNXSubOp::inferShapes() {
|
||||||
|
@ -761,6 +786,7 @@ LogicalResult ONNXSubOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// And
|
// And
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXAndOp::inferShapes() {
|
LogicalResult ONNXAndOp::inferShapes() {
|
||||||
|
@ -775,6 +801,7 @@ LogicalResult ONNXAndOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Or
|
// Or
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXOrOp::inferShapes() {
|
LogicalResult ONNXOrOp::inferShapes() {
|
||||||
|
@ -789,6 +816,7 @@ LogicalResult ONNXOrOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Xor
|
// Xor
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXXorOp::inferShapes() {
|
LogicalResult ONNXXorOp::inferShapes() {
|
||||||
|
@ -801,10 +829,9 @@ LogicalResult ONNXXorOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sum
|
// Sum
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXSumOp::inferShapes() {
|
LogicalResult ONNXSumOp::inferShapes() {
|
||||||
|
@ -823,6 +850,7 @@ LogicalResult ONNXSumOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Max
|
// Max
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXMaxOp::inferShapes() {
|
LogicalResult ONNXMaxOp::inferShapes() {
|
||||||
|
@ -841,6 +869,7 @@ LogicalResult ONNXMaxOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Min
|
// Min
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXMinOp::inferShapes() {
|
LogicalResult ONNXMinOp::inferShapes() {
|
||||||
|
@ -859,6 +888,7 @@ LogicalResult ONNXMinOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Neg
|
// Neg
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXNegOp. This method is required by the
|
/// Infer the output shape of the ONNXNegOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXNegOp::inferShapes() {
|
LogicalResult ONNXNegOp::inferShapes() {
|
||||||
|
@ -868,6 +898,7 @@ LogicalResult ONNXNegOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Identity
|
// Identity
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
LogicalResult ONNXIdentityOp::inferShapes() {
|
LogicalResult ONNXIdentityOp::inferShapes() {
|
||||||
|
@ -876,8 +907,8 @@ LogicalResult ONNXIdentityOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// MatMul
|
// MatMul
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXMatMulOp::inferShapes() {
|
LogicalResult ONNXMatMulOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
|
@ -1006,10 +1037,7 @@ LogicalResult ONNXMatMulOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Gemm
|
// Gemm
|
||||||
|
|
||||||
LogicalResult ONNXGemmOp::inferShapes() {
|
LogicalResult ONNXGemmOp::inferShapes() {
|
||||||
bool hasBias = !C().getType().isa<NoneType>();
|
bool hasBias = !C().getType().isa<NoneType>();
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
|
@ -1104,8 +1132,8 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
// Take into account the dimensionality of the matrix.
|
// Take into account the dimensionality of the matrix.
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Reshape
|
// Reshape
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXReshapeOp::inferShapes() {
|
LogicalResult ONNXReshapeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape tensor is specified.
|
// Cannot infer shape if no shape tensor is specified.
|
||||||
|
@ -1177,8 +1205,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// Transpose
|
// Transpose
|
||||||
|
|
||||||
LogicalResult ONNXTransposeOp::inferShapes() {
|
LogicalResult ONNXTransposeOp::inferShapes() {
|
||||||
|
@ -1206,8 +1232,8 @@ LogicalResult ONNXTransposeOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceMax
|
// ReduceMax
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXReduceMaxOp::inferShapes() {
|
LogicalResult ONNXReduceMaxOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
|
@ -1219,8 +1245,8 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceMin
|
// ReduceMin
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXReduceMinOp::inferShapes() {
|
LogicalResult ONNXReduceMinOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
|
@ -1232,8 +1258,8 @@ LogicalResult ONNXReduceMinOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceProd
|
// ReduceProd
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXReduceProdOp::inferShapes() {
|
LogicalResult ONNXReduceProdOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
|
@ -1245,8 +1271,8 @@ LogicalResult ONNXReduceProdOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceSum
|
// ReduceSum
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXReduceSumOp::inferShapes() {
|
LogicalResult ONNXReduceSumOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>())
|
if (!getOperand().getType().isa<RankedTensorType>())
|
||||||
|
@ -1258,8 +1284,8 @@ LogicalResult ONNXReduceSumOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Conv
|
// Conv
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// For this operation, we define the attributes once in the original Conv
|
// For this operation, we define the attributes once in the original Conv
|
||||||
// operation class. There is no need to redefine the attribute names for the
|
// operation class. There is no need to redefine the attribute names for the
|
||||||
|
@ -1373,8 +1399,8 @@ LogicalResult ONNXConvOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ConvTranspose
|
// ConvTranspose
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// For this operation, we define the attributes once in the original Conv
|
// For this operation, we define the attributes once in the original Conv
|
||||||
// operation class. There is no need to redefine the attribute names for the
|
// operation class. There is no need to redefine the attribute names for the
|
||||||
|
@ -1505,8 +1531,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// AveragePool
|
// AveragePool
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Infer shape attributes output:
|
// Infer shape attributes output:
|
||||||
// - auto_pad set to NOTSET;
|
// - auto_pad set to NOTSET;
|
||||||
// - strides: set to 1 if not defined by user;
|
// - strides: set to 1 if not defined by user;
|
||||||
|
@ -1557,8 +1584,9 @@ LogicalResult ONNXAveragePoolOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// MaxPoolSingleOut
|
// MaxPoolSingleOut
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Infer shape attributes output:
|
// Infer shape attributes output:
|
||||||
// - auto_pad set to NOTSET;
|
// - auto_pad set to NOTSET;
|
||||||
// - dilations, strides: set to 1 if not defined by user;
|
// - dilations, strides: set to 1 if not defined by user;
|
||||||
|
@ -1607,6 +1635,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Pad
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXPadOp::inferShapes() {
|
LogicalResult ONNXPadOp::inferShapes() {
|
||||||
|
@ -1678,7 +1708,9 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
// PadConstantPad
|
// PadConstantPad
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXPadConstantPadOp::inferShapes() {
|
LogicalResult ONNXPadConstantPadOp::inferShapes() {
|
||||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
|
@ -1689,8 +1721,8 @@ LogicalResult ONNXPadConstantPadOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// PadConstantValuePad
|
// PadConstantValuePad
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
|
LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
|
||||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
|
@ -1711,8 +1743,8 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Unsqueeze
|
// Unsqueeze
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
||||||
if (!data().getType().isa<RankedTensorType>())
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
|
@ -1753,6 +1785,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Cast
|
// Cast
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXCastOp::inferShapes() {
|
LogicalResult ONNXCastOp::inferShapes() {
|
||||||
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
|
ShapedType inputType = input().getType().dyn_cast<ShapedType>();
|
||||||
|
@ -1781,6 +1814,7 @@ LogicalResult ONNXCastOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Constant
|
// Constant
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXConstantOp::inferShapes() {
|
LogicalResult ONNXConstantOp::inferShapes() {
|
||||||
if ((sparse_value().hasValue() && value().hasValue()) ||
|
if ((sparse_value().hasValue() && value().hasValue()) ||
|
||||||
|
@ -1796,7 +1830,9 @@ LogicalResult ONNXConstantOp::inferShapes() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
// Concat
|
// Concat
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXConcatOp::inferShapes() {
|
LogicalResult ONNXConcatOp::inferShapes() {
|
||||||
int inputNum = getNumOperands();
|
int inputNum = getNumOperands();
|
||||||
|
@ -1854,21 +1890,25 @@ LogicalResult ONNXConcatOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// RNN
|
// RNN
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
|
LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LSTM
|
// LSTM
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
|
LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GRU
|
// GRU
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
|
LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Split
|
// Split
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXSplitOp::inferShapes() {
|
LogicalResult ONNXSplitOp::inferShapes() {
|
||||||
if (!getOperand().getType().cast<RankedTensorType>())
|
if (!getOperand().getType().cast<RankedTensorType>())
|
||||||
|
@ -1933,6 +1973,7 @@ LogicalResult ONNXSplitOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Flatten
|
// Flatten
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXFlattenOp::inferShapes() {
|
LogicalResult ONNXFlattenOp::inferShapes() {
|
||||||
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
|
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
|
||||||
|
@ -1967,6 +2008,7 @@ LogicalResult ONNXFlattenOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DynamicQuantizeLinear
|
// DynamicQuantizeLinear
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
||||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -2000,6 +2042,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// QuantizeLinear
|
// QuantizeLinear
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
||||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -2022,6 +2065,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DequantizeLinear
|
// DequantizeLinear
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
||||||
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
auto inTy = x().getType().dyn_cast<RankedTensorType>();
|
||||||
|
@ -2042,6 +2086,7 @@ LogicalResult ONNXDequantizeLinearOp::inferShapes() {
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
|
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ONNXConvIntegerOp::inferShapes() {
|
LogicalResult ONNXConvIntegerOp::inferShapes() {
|
||||||
// Generic shape for data input X, weight tensor W
|
// Generic shape for data input X, weight tensor W
|
||||||
|
|
|
@ -26,27 +26,27 @@ add_dependencies(OMElideConstants OMONNXOps)
|
||||||
target_link_libraries(OMElideConstants
|
target_link_libraries(OMElideConstants
|
||||||
onnx)
|
onnx)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
set(LLVM_TARGET_DEFINITIONS Rewrite.td)
|
||||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXRewriteIncGen)
|
add_public_tablegen_target(OMONNXRewriteIncGen)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXCombine.td)
|
set(LLVM_TARGET_DEFINITIONS Combine.td)
|
||||||
onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXCombineIncGen)
|
add_public_tablegen_target(OMONNXCombineIncGen)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td)
|
set(LLVM_TARGET_DEFINITIONS Decompose.td)
|
||||||
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXDecomposeIncGen)
|
add_public_tablegen_target(OMONNXDecomposeIncGen)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td)
|
set(LLVM_TARGET_DEFINITIONS ConstProp.td)
|
||||||
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
|
||||||
add_public_tablegen_target(OMONNXConstPropIncGen)
|
add_public_tablegen_target(OMONNXConstPropIncGen)
|
||||||
|
|
||||||
add_library(OMONNXRewrite
|
add_library(OMONNXRewrite
|
||||||
ONNXRewrite.cpp
|
Rewrite.cpp
|
||||||
ONNXCombine.cpp
|
Combine.cpp
|
||||||
ONNXDecompose.cpp
|
Decompose.cpp
|
||||||
ONNXConstProp.cpp)
|
ConstProp.cpp)
|
||||||
target_include_directories(OMONNXRewrite
|
target_include_directories(OMONNXRewrite
|
||||||
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNF_MLIR_SRC_ROOT})
|
${ONNF_MLIR_SRC_ROOT})
|
||||||
|
|
|
@ -25,16 +25,17 @@ using namespace mlir;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Instructions to add a constant operation. There is currently support for
|
// Instructions to add a constant operation.
|
||||||
// adding constant propagation for unary and binary athythmetic ops (binary ops
|
//===----------------------------------------------------------------------===//
|
||||||
// support broadcast). To add an operation, you simply have to add a templated
|
// There is currently support for adding constant propagation for unary and
|
||||||
// method on how to compute the result in terms of one or two inputs. Values
|
// binary athythmetic ops (binary ops support broadcast). To add an operation,
|
||||||
// comes as Attribtues, and return is also an Attribute. In that function,
|
// you simply have to add a templated method on how to compute the result in
|
||||||
// presumably you will need different methods to handle int / float /
|
// terms of one or two inputs. Values comes as Attribtues, and return is also an
|
||||||
// strings... Note that these methods cannot fail. It is your responsablitity to
|
// Attribute. In that function, presumably you will need different methods to
|
||||||
// tests for which data type are supported in the rules directly. Specific type
|
// handle int / float / strings... Note that these methods cannot fail. It is
|
||||||
// restrictions can be added in the DRR files.
|
// your responsablitity to tests for which data type are supported in the rules
|
||||||
|
// directly. Specific type restrictions can be added in the DRR files.
|
||||||
|
|
||||||
// The methods are:
|
// The methods are:
|
||||||
//
|
//
|
||||||
|
@ -42,13 +43,12 @@ namespace {
|
||||||
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
// and they need to be tempalted wtih an ONNX Operation (presuably).
|
||||||
//
|
//
|
||||||
// Then you need to add rules on how to transform the patterns; look into
|
// Then you need to add rules on how to transform the patterns; look into
|
||||||
// ONNXConstProp.td for example.
|
// ConstProp.td for example.
|
||||||
//
|
//
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Code to perform constant propagation for binary in presence of broadcast.
|
// Code to perform constant propagation for binary in presence of broadcast.
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Template to generate binary operation results. It takes as inupt
|
// Template to generate binary operation results. It takes as inupt
|
||||||
// the element type as well as the two element attributes for the
|
// the element type as well as the two element attributes for the
|
||||||
|
@ -223,9 +223,9 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
|
||||||
return DenseElementsAttr::get(resType, resRef);
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Code to perform constant propagation for unary operation.
|
// Code to perform constant propagation for unary operation.
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
template <typename OP>
|
template <typename OP>
|
||||||
Attribute ComputeConstProppElementwiseUnary(
|
Attribute ComputeConstProppElementwiseUnary(
|
||||||
|
@ -295,15 +295,15 @@ DenseElementsAttr ConstPropElementwiseUnary(
|
||||||
return DenseElementsAttr::get(resType, resRef);
|
return DenseElementsAttr::get(resType, resRef);
|
||||||
}
|
}
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Pattern definition.
|
// Pattern definition.
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "src/Transform/ONNX/ONNXConstProp.inc"
|
#include "src/Transform/ONNX/ONNXConstProp.inc"
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Code to manage the pass.
|
// Code to manage the pass.
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
struct ConstPropONNXToONNXPass
|
struct ConstPropONNXToONNXPass
|
||||||
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
|
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {
|
|
@ -16,13 +16,14 @@
|
||||||
include "src/Dialect/ONNX/ONNXOps.td"
|
include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
// =============================================================================
|
// Instruction to add new constant operation rules.
|
||||||
// Instruction to add new constant operation rules. Minimally, you will have added
|
//===----------------------------------------------------------------------===//
|
||||||
// operation in the ONNXConstProp.cpp to perform the element-wise single value
|
// Minimally, you will have added operation in the ONNXConstProp.cpp to perform
|
||||||
// handling of the new operator that you are dealing with. You will need to
|
// the element-wise single value handling of the new operator that you are dealing
|
||||||
// generate a call to the method that handle the tensor constant prop. Here
|
// with. You will need to generate a call to the method that handle the tensor
|
||||||
// is the call for a unary and binary operation. Adapt to your new operator:
|
// constant prop. Here is the call for a unary and binary operation. Adapt to your
|
||||||
|
// new operator:
|
||||||
//
|
//
|
||||||
// def CreateAddOfTwoConst :
|
// def CreateAddOfTwoConst :
|
||||||
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
|
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
|
||||||
|
@ -34,7 +35,6 @@ include "src/Dialect/ONNX/ONNXOps.td"
|
||||||
// a new def name.
|
// a new def name.
|
||||||
//
|
//
|
||||||
// Then you will need to add substitution rules, see examples below.
|
// Then you will need to add substitution rules, see examples below.
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
// Useful test definitions.
|
// Useful test definitions.
|
||||||
|
@ -63,8 +63,9 @@ def CreateNegOfConst :
|
||||||
def CreateMulOfTwoConst :
|
def CreateMulOfTwoConst :
|
||||||
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise ADD operations.
|
// Patterns to enable opportunities with elementwise ADD operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Use commutativity to normalize constants in the second position of Add.
|
// Use commutativity to normalize constants in the second position of Add.
|
||||||
def AddConstCommutative1 : Pat<
|
def AddConstCommutative1 : Pat<
|
||||||
|
@ -96,8 +97,9 @@ def AddConstProp : Pat<
|
||||||
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
|
||||||
|
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise SUB / NEG operations.
|
// Patterns to enable opportunities with elementwise SUB / NEG operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Constant Propagation for Sub
|
// Constant Propagation for Sub
|
||||||
def SubConstProp : Pat<
|
def SubConstProp : Pat<
|
||||||
|
@ -125,9 +127,10 @@ def SubConstToNeg : Pat<
|
||||||
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
|
||||||
|
|
||||||
|
|
||||||
// =============================================================================
|
//===----------------------------------------------------------------------===//
|
||||||
// Patterns to enable opportunities with elementwise MUL operations.
|
// Patterns to enable opportunities with elementwise MUL operations.
|
||||||
// Exactly the same pattern as for the elementwise ADD operations.
|
// Exactly the same pattern as for the elementwise ADD operations.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Use commutativity to normalize constants in the second position of Mul.
|
// Use commutativity to normalize constants in the second position of Mul.
|
||||||
def MulConstCommutative1 : Pat<
|
def MulConstCommutative1 : Pat<
|
Loading…
Reference in New Issue