[NFC] Fix comment style, shorten file names. (#169)

* Fix comment style, shorten file names.

* Update CMakeLists.txt

* Rename ONNXRewrite.td to Rewrite.td
This commit is contained in:
Tian Jin 2020-06-15 11:49:09 +08:00 committed by GitHub
parent 60c648ae39
commit a7781791e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 113 additions and 65 deletions

View File

@ -122,7 +122,7 @@ RankedTensorType getReductionOutputType(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support function that computes default values for dilations. // Support function that computes default values for dilations.
// //===----------------------------------------------------------------------===//
template <class T> template <class T>
static LogicalResult processConvDilationParam( static LogicalResult processConvDilationParam(
T *op, Optional<ArrayAttr> kernelShape) { T *op, Optional<ArrayAttr> kernelShape) {
@ -153,7 +153,7 @@ static LogicalResult processConvDilationParam(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support function that computes default values for strides. // Support function that computes default values for strides.
// //===----------------------------------------------------------------------===//
template <class T> template <class T>
static LogicalResult processConvStrideParam( static LogicalResult processConvStrideParam(
T *op, Optional<ArrayAttr> kernelShape) { T *op, Optional<ArrayAttr> kernelShape) {
@ -181,7 +181,7 @@ static LogicalResult processConvStrideParam(
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support function that computes default values for pads. // Support function that computes default values for pads.
// //===----------------------------------------------------------------------===//
template <class T> template <class T>
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape, static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt, Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
@ -273,8 +273,8 @@ static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support function that computes default values for dilations, strides, and // Support function computing default values for dilations, strides, and pads.
// pads. //===----------------------------------------------------------------------===//
template <class T> template <class T>
static LogicalResult processConvTypeParams(T *op, Value inputOperand) { static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
auto builder = mlir::Builder(op->getContext()); auto builder = mlir::Builder(op->getContext());
@ -305,7 +305,7 @@ static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Compute spatial dimensions given dilations, strides, pads, and ceil mode. // Compute spatial dimensions given dilations, strides, pads, and ceil mode.
// //===----------------------------------------------------------------------===//
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims, static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape, Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt, Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
@ -335,6 +335,7 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Support function that infers shape for RNN operations. // Support function that infers shape for RNN operations.
//===----------------------------------------------------------------------===//
template <typename T> template <typename T>
static LogicalResult RNNShapeInference(T *op) { static LogicalResult RNNShapeInference(T *op) {
Value X = op->X(); Value X = op->X();
@ -516,6 +517,7 @@ LogicalResult ONNXExpOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Atan // Atan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAtanOp. This method is required by the /// Infer the output shape of the ONNXAtanOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXAtanOp::inferShapes() { LogicalResult ONNXAtanOp::inferShapes() {
@ -525,6 +527,7 @@ LogicalResult ONNXAtanOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tan // Tan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXTanOp. This method is required by the /// Infer the output shape of the ONNXTanOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXTanOp::inferShapes() { LogicalResult ONNXTanOp::inferShapes() {
@ -543,6 +546,7 @@ LogicalResult ONNXTanhOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sin // Sin
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinOp. This method is required by the /// Infer the output shape of the ONNXSinOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXSinOp::inferShapes() { LogicalResult ONNXSinOp::inferShapes() {
@ -552,6 +556,7 @@ LogicalResult ONNXSinOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sinh // Sinh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXSinhOp::inferShapes() { LogicalResult ONNXSinhOp::inferShapes() {
@ -561,6 +566,7 @@ LogicalResult ONNXSinhOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cosh // Cosh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCoshOp. This method is required by the /// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXCoshOp::inferShapes() { LogicalResult ONNXCoshOp::inferShapes() {
@ -570,6 +576,7 @@ LogicalResult ONNXCoshOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cos // Cos
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCosOp. This method is required by the /// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXCosOp::inferShapes() { LogicalResult ONNXCosOp::inferShapes() {
@ -579,6 +586,7 @@ LogicalResult ONNXCosOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Log // Log
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLogOp. This method is required by the /// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXLogOp::inferShapes() { LogicalResult ONNXLogOp::inferShapes() {
@ -588,6 +596,7 @@ LogicalResult ONNXLogOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// HardSigmoid // HardSigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXHardSigmoidOp::inferShapes() { LogicalResult ONNXHardSigmoidOp::inferShapes() {
@ -597,6 +606,7 @@ LogicalResult ONNXHardSigmoidOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sigmoid // Sigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXSigmoidOp::inferShapes() { LogicalResult ONNXSigmoidOp::inferShapes() {
@ -606,6 +616,7 @@ LogicalResult ONNXSigmoidOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Elu // Elu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXEluOp. This method is required by the /// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXEluOp::inferShapes() { LogicalResult ONNXEluOp::inferShapes() {
@ -615,6 +626,7 @@ LogicalResult ONNXEluOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Relu // Relu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReluOp. This method is required by the /// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXReluOp::inferShapes() { LogicalResult ONNXReluOp::inferShapes() {
@ -624,6 +636,7 @@ LogicalResult ONNXReluOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LeakyRelu // LeakyRelu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXLeakyReluOp::inferShapes() { LogicalResult ONNXLeakyReluOp::inferShapes() {
@ -633,6 +646,7 @@ LogicalResult ONNXLeakyReluOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Selu // Selu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSeluOp. This method is required by /// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSeluOp::inferShapes() { LogicalResult ONNXSeluOp::inferShapes() {
@ -642,6 +656,7 @@ LogicalResult ONNXSeluOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reciprocal // Reciprocal
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReciprocalOp. This method is required by /// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXReciprocalOp::inferShapes() { LogicalResult ONNXReciprocalOp::inferShapes() {
@ -651,6 +666,7 @@ LogicalResult ONNXReciprocalOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softmax // Softmax
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by /// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSoftmaxOp::inferShapes() { LogicalResult ONNXSoftmaxOp::inferShapes() {
@ -660,6 +676,7 @@ LogicalResult ONNXSoftmaxOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softplus // Softplus
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftplusOp. This method is required by /// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSoftplusOp::inferShapes() { LogicalResult ONNXSoftplusOp::inferShapes() {
@ -669,6 +686,7 @@ LogicalResult ONNXSoftplusOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softsign // Softsign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftsignOp. This method is required by /// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSoftsignOp::inferShapes() { LogicalResult ONNXSoftsignOp::inferShapes() {
@ -678,6 +696,7 @@ LogicalResult ONNXSoftsignOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sqrt // Sqrt
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSqrtOp. This method is required by /// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSqrtOp::inferShapes() { LogicalResult ONNXSqrtOp::inferShapes() {
@ -687,6 +706,7 @@ LogicalResult ONNXSqrtOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sign // Sign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSignOp. This method is required by /// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
LogicalResult ONNXSignOp::inferShapes() { LogicalResult ONNXSignOp::inferShapes() {
@ -696,6 +716,7 @@ LogicalResult ONNXSignOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Abs // Abs
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAbsOp. This method is required by the /// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXAbsOp::inferShapes() { LogicalResult ONNXAbsOp::inferShapes() {
@ -705,6 +726,7 @@ LogicalResult ONNXAbsOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXAddOp::inferShapes() { LogicalResult ONNXAddOp::inferShapes() {
@ -719,6 +741,7 @@ LogicalResult ONNXAddOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Mul // Mul
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMulOp. This method is required by the /// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXMulOp::inferShapes() { LogicalResult ONNXMulOp::inferShapes() {
@ -733,6 +756,7 @@ LogicalResult ONNXMulOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Div // Div
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXDivOp. This method is required by the /// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXDivOp::inferShapes() { LogicalResult ONNXDivOp::inferShapes() {
@ -747,6 +771,7 @@ LogicalResult ONNXDivOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sub // Sub
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSubOp. This method is required by the /// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXSubOp::inferShapes() { LogicalResult ONNXSubOp::inferShapes() {
@ -761,6 +786,7 @@ LogicalResult ONNXSubOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// And // And
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAndOp. This method is required by the /// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXAndOp::inferShapes() { LogicalResult ONNXAndOp::inferShapes() {
@ -775,6 +801,7 @@ LogicalResult ONNXAndOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Or // Or
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXOrOp. This method is required by the /// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXOrOp::inferShapes() { LogicalResult ONNXOrOp::inferShapes() {
@ -789,6 +816,7 @@ LogicalResult ONNXOrOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Xor // Xor
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXXorOp. This method is required by the /// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXXorOp::inferShapes() { LogicalResult ONNXXorOp::inferShapes() {
@ -801,10 +829,9 @@ LogicalResult ONNXXorOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sum // Sum
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSumOp. This method is required by the /// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXSumOp::inferShapes() { LogicalResult ONNXSumOp::inferShapes() {
@ -823,6 +850,7 @@ LogicalResult ONNXSumOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Max // Max
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMaxOp. This method is required by the /// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXMaxOp::inferShapes() { LogicalResult ONNXMaxOp::inferShapes() {
@ -841,6 +869,7 @@ LogicalResult ONNXMaxOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Min // Min
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMinOp. This method is required by the /// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXMinOp::inferShapes() { LogicalResult ONNXMinOp::inferShapes() {
@ -859,6 +888,7 @@ LogicalResult ONNXMinOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Neg // Neg
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXNegOp. This method is required by the /// Infer the output shape of the ONNXNegOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXNegOp::inferShapes() { LogicalResult ONNXNegOp::inferShapes() {
@ -868,6 +898,7 @@ LogicalResult ONNXNegOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Identity // Identity
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXIdentityOp. This method is required by the /// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface. /// shape inference interface.
LogicalResult ONNXIdentityOp::inferShapes() { LogicalResult ONNXIdentityOp::inferShapes() {
@ -876,8 +907,8 @@ LogicalResult ONNXIdentityOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MatMul // MatMul
//===----------------------------------------------------------------------===//
LogicalResult ONNXMatMulOp::inferShapes() { LogicalResult ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
@ -1006,10 +1037,7 @@ LogicalResult ONNXMatMulOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Gemm // Gemm
LogicalResult ONNXGemmOp::inferShapes() { LogicalResult ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>(); bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
@ -1104,8 +1132,8 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
// Take into account the dimensionality of the matrix. // Take into account the dimensionality of the matrix.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reshape // Reshape
//===----------------------------------------------------------------------===//
LogicalResult ONNXReshapeOp::inferShapes() { LogicalResult ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
@ -1177,8 +1205,6 @@ LogicalResult ONNXReshapeOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Transpose // Transpose
LogicalResult ONNXTransposeOp::inferShapes() { LogicalResult ONNXTransposeOp::inferShapes() {
@ -1206,8 +1232,8 @@ LogicalResult ONNXTransposeOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMax // ReduceMax
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMaxOp::inferShapes() { LogicalResult ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
@ -1219,8 +1245,8 @@ LogicalResult ONNXReduceMaxOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMin // ReduceMin
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMinOp::inferShapes() { LogicalResult ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
@ -1232,8 +1258,8 @@ LogicalResult ONNXReduceMinOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceProd // ReduceProd
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceProdOp::inferShapes() { LogicalResult ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
@ -1245,8 +1271,8 @@ LogicalResult ONNXReduceProdOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceSum // ReduceSum
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceSumOp::inferShapes() { LogicalResult ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
@ -1258,8 +1284,8 @@ LogicalResult ONNXReduceSumOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conv // Conv
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv // For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the // operation class. There is no need to redefine the attribute names for the
@ -1373,8 +1399,8 @@ LogicalResult ONNXConvOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConvTranspose // ConvTranspose
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv // For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the // operation class. There is no need to redefine the attribute names for the
@ -1505,8 +1531,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AveragePool // AveragePool
//===----------------------------------------------------------------------===//
// Infer shape attributes output: // Infer shape attributes output:
// - auto_pad set to NOTSET; // - auto_pad set to NOTSET;
// - strides: set to 1 if not defined by user; // - strides: set to 1 if not defined by user;
@ -1557,8 +1584,9 @@ LogicalResult ONNXAveragePoolOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MaxPoolSingleOut // MaxPoolSingleOut
//===----------------------------------------------------------------------===//
// Infer shape attributes output: // Infer shape attributes output:
// - auto_pad set to NOTSET; // - auto_pad set to NOTSET;
// - dilations, strides: set to 1 if not defined by user; // - dilations, strides: set to 1 if not defined by user;
@ -1607,6 +1635,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Pad
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
LogicalResult ONNXPadOp::inferShapes() { LogicalResult ONNXPadOp::inferShapes() {
@ -1678,7 +1708,9 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
} }
} }
//===----------------------------------------------------------------------===//
// PadConstantPad // PadConstantPad
//===----------------------------------------------------------------------===//
LogicalResult ONNXPadConstantPadOp::inferShapes() { LogicalResult ONNXPadConstantPadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads()); auto outputType = padShapeInferenceHelper(data(), pads());
@ -1689,8 +1721,8 @@ LogicalResult ONNXPadConstantPadOp::inferShapes() {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PadConstantValuePad // PadConstantValuePad
//===----------------------------------------------------------------------===//
LogicalResult ONNXPadConstantValuePadOp::inferShapes() { LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads()); auto outputType = padShapeInferenceHelper(data(), pads());
@ -1711,8 +1743,8 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Unsqueeze // Unsqueeze
//===----------------------------------------------------------------------===//
LogicalResult ONNXUnsqueezeOp::inferShapes() { LogicalResult ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
@ -1753,6 +1785,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cast // Cast
//===----------------------------------------------------------------------===//
LogicalResult ONNXCastOp::inferShapes() { LogicalResult ONNXCastOp::inferShapes() {
ShapedType inputType = input().getType().dyn_cast<ShapedType>(); ShapedType inputType = input().getType().dyn_cast<ShapedType>();
@ -1781,6 +1814,7 @@ LogicalResult ONNXCastOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Constant // Constant
//===----------------------------------------------------------------------===//
LogicalResult ONNXConstantOp::inferShapes() { LogicalResult ONNXConstantOp::inferShapes() {
if ((sparse_value().hasValue() && value().hasValue()) || if ((sparse_value().hasValue() && value().hasValue()) ||
@ -1796,7 +1830,9 @@ LogicalResult ONNXConstantOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// Concat // Concat
//===----------------------------------------------------------------------===//
LogicalResult ONNXConcatOp::inferShapes() { LogicalResult ONNXConcatOp::inferShapes() {
int inputNum = getNumOperands(); int inputNum = getNumOperands();
@ -1854,21 +1890,25 @@ LogicalResult ONNXConcatOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// RNN // RNN
//===----------------------------------------------------------------------===//
LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); } LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LSTM // LSTM
//===----------------------------------------------------------------------===//
LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); } LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// GRU // GRU
//===----------------------------------------------------------------------===//
LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); } LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Split // Split
//===----------------------------------------------------------------------===//
LogicalResult ONNXSplitOp::inferShapes() { LogicalResult ONNXSplitOp::inferShapes() {
if (!getOperand().getType().cast<RankedTensorType>()) if (!getOperand().getType().cast<RankedTensorType>())
@ -1933,6 +1973,7 @@ LogicalResult ONNXSplitOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Flatten // Flatten
//===----------------------------------------------------------------------===//
LogicalResult ONNXFlattenOp::inferShapes() { LogicalResult ONNXFlattenOp::inferShapes() {
assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now"); assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
@ -1967,6 +2008,7 @@ LogicalResult ONNXFlattenOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DynamicQuantizeLinear // DynamicQuantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() { LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>(); auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2000,6 +2042,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// QuantizeLinear // QuantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXQuantizeLinearOp::inferShapes() { LogicalResult ONNXQuantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>(); auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2022,6 +2065,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// DequantizeLinear // DequantizeLinear
//===----------------------------------------------------------------------===//
LogicalResult ONNXDequantizeLinearOp::inferShapes() { LogicalResult ONNXDequantizeLinearOp::inferShapes() {
auto inTy = x().getType().dyn_cast<RankedTensorType>(); auto inTy = x().getType().dyn_cast<RankedTensorType>();
@ -2042,6 +2086,7 @@ LogicalResult ONNXDequantizeLinearOp::inferShapes() {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias) // ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
//===----------------------------------------------------------------------===//
LogicalResult ONNXConvIntegerOp::inferShapes() { LogicalResult ONNXConvIntegerOp::inferShapes() {
// Generic shape for data input X, weight tensor W // Generic shape for data input X, weight tensor W

View File

@ -26,27 +26,27 @@ add_dependencies(OMElideConstants OMONNXOps)
target_link_libraries(OMElideConstants target_link_libraries(OMElideConstants
onnx) onnx)
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td) set(LLVM_TARGET_DEFINITIONS Rewrite.td)
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters) onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
add_public_tablegen_target(OMONNXRewriteIncGen) add_public_tablegen_target(OMONNXRewriteIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXCombine.td) set(LLVM_TARGET_DEFINITIONS Combine.td)
onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters) onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters)
add_public_tablegen_target(OMONNXCombineIncGen) add_public_tablegen_target(OMONNXCombineIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td) set(LLVM_TARGET_DEFINITIONS Decompose.td)
onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters) onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters)
add_public_tablegen_target(OMONNXDecomposeIncGen) add_public_tablegen_target(OMONNXDecomposeIncGen)
set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td) set(LLVM_TARGET_DEFINITIONS ConstProp.td)
onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters) onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters)
add_public_tablegen_target(OMONNXConstPropIncGen) add_public_tablegen_target(OMONNXConstPropIncGen)
add_library(OMONNXRewrite add_library(OMONNXRewrite
ONNXRewrite.cpp Rewrite.cpp
ONNXCombine.cpp Combine.cpp
ONNXDecompose.cpp Decompose.cpp
ONNXConstProp.cpp) ConstProp.cpp)
target_include_directories(OMONNXRewrite target_include_directories(OMONNXRewrite
PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT}
${ONNF_MLIR_SRC_ROOT}) ${ONNF_MLIR_SRC_ROOT})

View File

@ -25,16 +25,17 @@ using namespace mlir;
namespace { namespace {
// ============================================================================= //===----------------------------------------------------------------------===//
// Instructions to add a constant operation. There is currently support for // Instructions to add a constant operation.
// adding constant propagation for unary and binary athythmetic ops (binary ops //===----------------------------------------------------------------------===//
// support broadcast). To add an operation, you simply have to add a templated // There is currently support for adding constant propagation for unary and
// method on how to compute the result in terms of one or two inputs. Values // binary athythmetic ops (binary ops support broadcast). To add an operation,
// comes as Attribtues, and return is also an Attribute. In that function, // you simply have to add a templated method on how to compute the result in
// presumably you will need different methods to handle int / float / // terms of one or two inputs. Values comes as Attribtues, and return is also an
// strings... Note that these methods cannot fail. It is your responsablitity to // Attribute. In that function, presumably you will need different methods to
// tests for which data type are supported in the rules directly. Specific type // handle int / float / strings... Note that these methods cannot fail. It is
// restrictions can be added in the DRR files. // your responsablitity to tests for which data type are supported in the rules
// directly. Specific type restrictions can be added in the DRR files.
// The methods are: // The methods are:
// //
@ -42,13 +43,12 @@ namespace {
// and they need to be tempalted wtih an ONNX Operation (presuably). // and they need to be tempalted wtih an ONNX Operation (presuably).
// //
// Then you need to add rules on how to transform the patterns; look into // Then you need to add rules on how to transform the patterns; look into
// ONNXConstProp.td for example. // ConstProp.td for example.
// //
// =============================================================================
// ============================================================================= //===----------------------------------------------------------------------===//
// Code to perform constant propagation for binary in presence of broadcast. // Code to perform constant propagation for binary in presence of broadcast.
// ============================================================================= //===----------------------------------------------------------------------===//
// Template to generate binary operation results. It takes as inupt // Template to generate binary operation results. It takes as inupt
// the element type as well as the two element attributes for the // the element type as well as the two element attributes for the
@ -223,9 +223,9 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
return DenseElementsAttr::get(resType, resRef); return DenseElementsAttr::get(resType, resRef);
} }
// ============================================================================= //===----------------------------------------------------------------------===//
// Code to perform constant propagation for unary operation. // Code to perform constant propagation for unary operation.
// ============================================================================= //===----------------------------------------------------------------------===//
template <typename OP> template <typename OP>
Attribute ComputeConstProppElementwiseUnary( Attribute ComputeConstProppElementwiseUnary(
@ -295,15 +295,15 @@ DenseElementsAttr ConstPropElementwiseUnary(
return DenseElementsAttr::get(resType, resRef); return DenseElementsAttr::get(resType, resRef);
} }
// ============================================================================= //===----------------------------------------------------------------------===//
// Pattern definition. // Pattern definition.
// ============================================================================= //===----------------------------------------------------------------------===//
#include "src/Transform/ONNX/ONNXConstProp.inc" #include "src/Transform/ONNX/ONNXConstProp.inc"
// ============================================================================= //===----------------------------------------------------------------------===//
// Code to manage the pass. // Code to manage the pass.
// ============================================================================= //===----------------------------------------------------------------------===//
struct ConstPropONNXToONNXPass struct ConstPropONNXToONNXPass
: public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> { : public PassWrapper<ConstPropONNXToONNXPass, FunctionPass> {

View File

@ -16,13 +16,14 @@
include "src/Dialect/ONNX/ONNXOps.td" include "src/Dialect/ONNX/ONNXOps.td"
#endif // OP_BASE #endif // OP_BASE
//===----------------------------------------------------------------------===//
// ============================================================================= // Instruction to add new constant operation rules.
// Instruction to add new constant operation rules. Minimally, you will have added //===----------------------------------------------------------------------===//
// operation in the ONNXConstProp.cpp to perform the element-wise single value // Minimally, you will have added operation in the ONNXConstProp.cpp to perform
// handling of the new operator that you are dealing with. You will need to // the element-wise single value handling of the new operator that you are dealing
// generate a call to the method that handle the tensor constant prop. Here // with. You will need to generate a call to the method that handle the tensor
// is the call for a unary and binary operation. Adapt to your new operator: // constant prop. Here is the call for a unary and binary operation. Adapt to your
// new operator:
// //
// def CreateAddOfTwoConst : // def CreateAddOfTwoConst :
// NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">; // NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
@ -34,7 +35,6 @@ include "src/Dialect/ONNX/ONNXOps.td"
// a new def name. // a new def name.
// //
// Then you will need to add substitution rules, see examples below. // Then you will need to add substitution rules, see examples below.
// =============================================================================
// Useful test definitions. // Useful test definitions.
@ -63,8 +63,9 @@ def CreateNegOfConst :
def CreateMulOfTwoConst : def CreateMulOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">; NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXMulOp>($_builder, $0, $1, $2)">;
// ============================================================================= //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise ADD operations. // Patterns to enable opportunities with elementwise ADD operations.
//===----------------------------------------------------------------------===//
// Use commutativity to normalize constants in the second position of Add. // Use commutativity to normalize constants in the second position of Add.
def AddConstCommutative1 : Pat< def AddConstCommutative1 : Pat<
@ -96,8 +97,9 @@ def AddConstProp : Pat<
[(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>;
// ============================================================================= //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise SUB / NEG operations. // Patterns to enable opportunities with elementwise SUB / NEG operations.
//===----------------------------------------------------------------------===//
// Constant Propagation for Sub // Constant Propagation for Sub
def SubConstProp : Pat< def SubConstProp : Pat<
@ -125,9 +127,10 @@ def SubConstToNeg : Pat<
[(IsNotAConstant:$x), (AttributeIsNull:$s)]>; [(IsNotAConstant:$x), (AttributeIsNull:$s)]>;
// ============================================================================= //===----------------------------------------------------------------------===//
// Patterns to enable opportunities with elementwise MUL operations. // Patterns to enable opportunities with elementwise MUL operations.
// Exactly the same pattern as for the elementwise ADD operations. // Exactly the same pattern as for the elementwise ADD operations.
//===----------------------------------------------------------------------===//
// Use commutativity to normalize constants in the second position of Mul. // Use commutativity to normalize constants in the second position of Mul.
def MulConstCommutative1 : Pat< def MulConstCommutative1 : Pat<