diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 99a6c78..4e22262 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -122,7 +122,7 @@ RankedTensorType getReductionOutputType( //===----------------------------------------------------------------------===// // Support function that computes default values for dilations. -// +//===----------------------------------------------------------------------===// template static LogicalResult processConvDilationParam( T *op, Optional kernelShape) { @@ -153,7 +153,7 @@ static LogicalResult processConvDilationParam( //===----------------------------------------------------------------------===// // Support function that computes default values for strides. -// +//===----------------------------------------------------------------------===// template static LogicalResult processConvStrideParam( T *op, Optional kernelShape) { @@ -181,7 +181,7 @@ static LogicalResult processConvStrideParam( //===----------------------------------------------------------------------===// // Support function that computes default values for pads. -// +//===----------------------------------------------------------------------===// template static LogicalResult processConvPadParam(T *op, ArrayRef inputShape, Optional kernelShape, Optional stridesOpt, @@ -273,8 +273,8 @@ static LogicalResult processConvPadParam(T *op, ArrayRef inputShape, } //===----------------------------------------------------------------------===// -// Support function that computes default values for dilations, strides, and -// pads. +// Support function computing default values for dilations, strides, and pads. +//===----------------------------------------------------------------------===// template static LogicalResult processConvTypeParams(T *op, Value inputOperand) { auto builder = mlir::Builder(op->getContext()); @@ -305,7 +305,7 @@ static LogicalResult processConvTypeParams(T *op, Value inputOperand) { //===----------------------------------------------------------------------===// // Compute spatial dimensions given dilations, strides, pads, and ceil mode. -// +//===----------------------------------------------------------------------===// static void insertConvSpatialDim(SmallVector *outputDims, Builder &builder, ArrayRef xShape, Optional kernelShape, Optional padsOpt, Optional stridesOpt, @@ -335,6 +335,7 @@ static void insertConvSpatialDim(SmallVector *outputDims, //===----------------------------------------------------------------------===// // Support function that infers shape for RNN operations. +//===----------------------------------------------------------------------===// template static LogicalResult RNNShapeInference(T *op) { Value X = op->X(); @@ -516,6 +517,7 @@ LogicalResult ONNXExpOp::inferShapes() { //===----------------------------------------------------------------------===// // Atan +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXAtanOp. This method is required by the /// shape inference interface. LogicalResult ONNXAtanOp::inferShapes() { @@ -525,6 +527,7 @@ LogicalResult ONNXAtanOp::inferShapes() { //===----------------------------------------------------------------------===// // Tan +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXTanOp. This method is required by the /// shape inference interface. LogicalResult ONNXTanOp::inferShapes() { @@ -543,6 +546,7 @@ LogicalResult ONNXTanhOp::inferShapes() { //===----------------------------------------------------------------------===// // Sin +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSinOp. This method is required by the /// shape inference interface. LogicalResult ONNXSinOp::inferShapes() { @@ -552,6 +556,7 @@ LogicalResult ONNXSinOp::inferShapes() { //===----------------------------------------------------------------------===// // Sinh +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSinhOp. This method is required by the /// shape inference interface. LogicalResult ONNXSinhOp::inferShapes() { @@ -561,6 +566,7 @@ LogicalResult ONNXSinhOp::inferShapes() { //===----------------------------------------------------------------------===// // Cosh +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXCoshOp. This method is required by the /// shape inference interface. LogicalResult ONNXCoshOp::inferShapes() { @@ -570,6 +576,7 @@ LogicalResult ONNXCoshOp::inferShapes() { //===----------------------------------------------------------------------===// // Cos +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXCosOp. This method is required by the /// shape inference interface. LogicalResult ONNXCosOp::inferShapes() { @@ -579,6 +586,7 @@ LogicalResult ONNXCosOp::inferShapes() { //===----------------------------------------------------------------------===// // Log +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXLogOp. This method is required by the /// shape inference interface. LogicalResult ONNXLogOp::inferShapes() { @@ -588,6 +596,7 @@ LogicalResult ONNXLogOp::inferShapes() { //===----------------------------------------------------------------------===// // HardSigmoid +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by /// the shape inference interface. LogicalResult ONNXHardSigmoidOp::inferShapes() { @@ -597,6 +606,7 @@ LogicalResult ONNXHardSigmoidOp::inferShapes() { //===----------------------------------------------------------------------===// // Sigmoid +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSigmoidOp. This method is required by the /// shape inference interface. LogicalResult ONNXSigmoidOp::inferShapes() { @@ -606,6 +616,7 @@ LogicalResult ONNXSigmoidOp::inferShapes() { //===----------------------------------------------------------------------===// // Elu +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXEluOp. This method is required by the /// shape inference interface. LogicalResult ONNXEluOp::inferShapes() { @@ -615,6 +626,7 @@ LogicalResult ONNXEluOp::inferShapes() { //===----------------------------------------------------------------------===// // Relu +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXReluOp. This method is required by the /// shape inference interface. LogicalResult ONNXReluOp::inferShapes() { @@ -624,6 +636,7 @@ LogicalResult ONNXReluOp::inferShapes() { //===----------------------------------------------------------------------===// // LeakyRelu +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXLeakyReluOp. This method is required by /// the shape inference interface. LogicalResult ONNXLeakyReluOp::inferShapes() { @@ -633,6 +646,7 @@ LogicalResult ONNXLeakyReluOp::inferShapes() { //===----------------------------------------------------------------------===// // Selu +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSeluOp. This method is required by /// the shape inference interface. LogicalResult ONNXSeluOp::inferShapes() { @@ -642,6 +656,7 @@ LogicalResult ONNXSeluOp::inferShapes() { //===----------------------------------------------------------------------===// // Reciprocal +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXReciprocalOp. This method is required by /// the shape inference interface. LogicalResult ONNXReciprocalOp::inferShapes() { @@ -651,6 +666,7 @@ LogicalResult ONNXReciprocalOp::inferShapes() { //===----------------------------------------------------------------------===// // Softmax +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSoftmaxOp. This method is required by /// the shape inference interface. LogicalResult ONNXSoftmaxOp::inferShapes() { @@ -660,6 +676,7 @@ LogicalResult ONNXSoftmaxOp::inferShapes() { //===----------------------------------------------------------------------===// // Softplus +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSoftplusOp. This method is required by /// the shape inference interface. LogicalResult ONNXSoftplusOp::inferShapes() { @@ -669,6 +686,7 @@ LogicalResult ONNXSoftplusOp::inferShapes() { //===----------------------------------------------------------------------===// // Softsign +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSoftsignOp. This method is required by /// the shape inference interface. LogicalResult ONNXSoftsignOp::inferShapes() { @@ -678,6 +696,7 @@ LogicalResult ONNXSoftsignOp::inferShapes() { //===----------------------------------------------------------------------===// // Sqrt +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSqrtOp. This method is required by /// the shape inference interface. LogicalResult ONNXSqrtOp::inferShapes() { @@ -687,6 +706,7 @@ LogicalResult ONNXSqrtOp::inferShapes() { //===----------------------------------------------------------------------===// // Sign +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSignOp. This method is required by /// the shape inference interface. LogicalResult ONNXSignOp::inferShapes() { @@ -696,6 +716,7 @@ LogicalResult ONNXSignOp::inferShapes() { //===----------------------------------------------------------------------===// // Abs +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXAbsOp. This method is required by the /// shape inference interface. LogicalResult ONNXAbsOp::inferShapes() { @@ -705,6 +726,7 @@ LogicalResult ONNXAbsOp::inferShapes() { //===----------------------------------------------------------------------===// // Add +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXAddOp. This method is required by the /// shape inference interface. LogicalResult ONNXAddOp::inferShapes() { @@ -719,6 +741,7 @@ LogicalResult ONNXAddOp::inferShapes() { //===----------------------------------------------------------------------===// // Mul +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXMulOp. This method is required by the /// shape inference interface. LogicalResult ONNXMulOp::inferShapes() { @@ -733,6 +756,7 @@ LogicalResult ONNXMulOp::inferShapes() { //===----------------------------------------------------------------------===// // Div +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXDivOp. This method is required by the /// shape inference interface. LogicalResult ONNXDivOp::inferShapes() { @@ -747,6 +771,7 @@ LogicalResult ONNXDivOp::inferShapes() { //===----------------------------------------------------------------------===// // Sub +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSubOp. This method is required by the /// shape inference interface. LogicalResult ONNXSubOp::inferShapes() { @@ -761,6 +786,7 @@ LogicalResult ONNXSubOp::inferShapes() { //===----------------------------------------------------------------------===// // And +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXAndOp. This method is required by the /// shape inference interface. LogicalResult ONNXAndOp::inferShapes() { @@ -775,6 +801,7 @@ LogicalResult ONNXAndOp::inferShapes() { //===----------------------------------------------------------------------===// // Or +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXOrOp. This method is required by the /// shape inference interface. LogicalResult ONNXOrOp::inferShapes() { @@ -789,6 +816,7 @@ LogicalResult ONNXOrOp::inferShapes() { //===----------------------------------------------------------------------===// // Xor +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXXorOp. This method is required by the /// shape inference interface. LogicalResult ONNXXorOp::inferShapes() { @@ -801,10 +829,9 @@ LogicalResult ONNXXorOp::inferShapes() { return success(); } -//===----------------------------------------------------------------------===// - //===----------------------------------------------------------------------===// // Sum +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXSumOp. This method is required by the /// shape inference interface. LogicalResult ONNXSumOp::inferShapes() { @@ -823,6 +850,7 @@ LogicalResult ONNXSumOp::inferShapes() { //===----------------------------------------------------------------------===// // Max +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXMaxOp. This method is required by the /// shape inference interface. LogicalResult ONNXMaxOp::inferShapes() { @@ -841,6 +869,7 @@ LogicalResult ONNXMaxOp::inferShapes() { //===----------------------------------------------------------------------===// // Min +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXMinOp. This method is required by the /// shape inference interface. LogicalResult ONNXMinOp::inferShapes() { @@ -859,6 +888,7 @@ LogicalResult ONNXMinOp::inferShapes() { //===----------------------------------------------------------------------===// // Neg +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXNegOp. This method is required by the /// shape inference interface. LogicalResult ONNXNegOp::inferShapes() { @@ -868,6 +898,7 @@ LogicalResult ONNXNegOp::inferShapes() { //===----------------------------------------------------------------------===// // Identity +//===----------------------------------------------------------------------===// /// Infer the output shape of the ONNXIdentityOp. This method is required by the /// shape inference interface. LogicalResult ONNXIdentityOp::inferShapes() { @@ -876,8 +907,8 @@ LogicalResult ONNXIdentityOp::inferShapes() { } //===----------------------------------------------------------------------===// - // MatMul +//===----------------------------------------------------------------------===// LogicalResult ONNXMatMulOp::inferShapes() { // Cannot infer shape if no shape exists. @@ -1006,10 +1037,7 @@ LogicalResult ONNXMatMulOp::inferShapes() { return success(); } -//===----------------------------------------------------------------------===// - // Gemm - LogicalResult ONNXGemmOp::inferShapes() { bool hasBias = !C().getType().isa(); // Cannot infer shape if no shape exists. @@ -1104,8 +1132,8 @@ LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() { // Take into account the dimensionality of the matrix. //===----------------------------------------------------------------------===// - // Reshape +//===----------------------------------------------------------------------===// LogicalResult ONNXReshapeOp::inferShapes() { // Cannot infer shape if no shape tensor is specified. @@ -1177,8 +1205,6 @@ LogicalResult ONNXReshapeOp::inferShapes() { return success(); } -//===----------------------------------------------------------------------===// - // Transpose LogicalResult ONNXTransposeOp::inferShapes() { @@ -1206,8 +1232,8 @@ LogicalResult ONNXTransposeOp::inferShapes() { } //===----------------------------------------------------------------------===// - // ReduceMax +//===----------------------------------------------------------------------===// LogicalResult ONNXReduceMaxOp::inferShapes() { if (!getOperand().getType().isa()) @@ -1219,8 +1245,8 @@ LogicalResult ONNXReduceMaxOp::inferShapes() { } //===----------------------------------------------------------------------===// - // ReduceMin +//===----------------------------------------------------------------------===// LogicalResult ONNXReduceMinOp::inferShapes() { if (!getOperand().getType().isa()) @@ -1232,8 +1258,8 @@ LogicalResult ONNXReduceMinOp::inferShapes() { } //===----------------------------------------------------------------------===// - // ReduceProd +//===----------------------------------------------------------------------===// LogicalResult ONNXReduceProdOp::inferShapes() { if (!getOperand().getType().isa()) @@ -1245,8 +1271,8 @@ LogicalResult ONNXReduceProdOp::inferShapes() { } //===----------------------------------------------------------------------===// - // ReduceSum +//===----------------------------------------------------------------------===// LogicalResult ONNXReduceSumOp::inferShapes() { if (!getOperand().getType().isa()) @@ -1258,8 +1284,8 @@ LogicalResult ONNXReduceSumOp::inferShapes() { } //===----------------------------------------------------------------------===// - // Conv +//===----------------------------------------------------------------------===// // For this operation, we define the attributes once in the original Conv // operation class. There is no need to redefine the attribute names for the @@ -1373,8 +1399,8 @@ LogicalResult ONNXConvOp::inferShapes() { } //===----------------------------------------------------------------------===// - // ConvTranspose +//===----------------------------------------------------------------------===// // For this operation, we define the attributes once in the original Conv // operation class. There is no need to redefine the attribute names for the @@ -1505,8 +1531,9 @@ LogicalResult ONNXConvTransposeOp::inferShapes() { } //===----------------------------------------------------------------------===// - // AveragePool +//===----------------------------------------------------------------------===// + // Infer shape attributes output: // - auto_pad set to NOTSET; // - strides: set to 1 if not defined by user; @@ -1557,8 +1584,9 @@ LogicalResult ONNXAveragePoolOp::inferShapes() { } //===----------------------------------------------------------------------===// - // MaxPoolSingleOut +//===----------------------------------------------------------------------===// + // Infer shape attributes output: // - auto_pad set to NOTSET; // - dilations, strides: set to 1 if not defined by user; @@ -1607,6 +1635,8 @@ LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Pad //===----------------------------------------------------------------------===// LogicalResult ONNXPadOp::inferShapes() { @@ -1678,7 +1708,9 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) { } } +//===----------------------------------------------------------------------===// // PadConstantPad +//===----------------------------------------------------------------------===// LogicalResult ONNXPadConstantPadOp::inferShapes() { auto outputType = padShapeInferenceHelper(data(), pads()); @@ -1689,8 +1721,8 @@ LogicalResult ONNXPadConstantPadOp::inferShapes() { } //===----------------------------------------------------------------------===// - // PadConstantValuePad +//===----------------------------------------------------------------------===// LogicalResult ONNXPadConstantValuePadOp::inferShapes() { auto outputType = padShapeInferenceHelper(data(), pads()); @@ -1711,8 +1743,8 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state, } //===----------------------------------------------------------------------===// - // Unsqueeze +//===----------------------------------------------------------------------===// LogicalResult ONNXUnsqueezeOp::inferShapes() { if (!data().getType().isa()) @@ -1753,6 +1785,7 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() { //===----------------------------------------------------------------------===// // Cast +//===----------------------------------------------------------------------===// LogicalResult ONNXCastOp::inferShapes() { ShapedType inputType = input().getType().dyn_cast(); @@ -1781,6 +1814,7 @@ LogicalResult ONNXCastOp::inferShapes() { //===----------------------------------------------------------------------===// // Constant +//===----------------------------------------------------------------------===// LogicalResult ONNXConstantOp::inferShapes() { if ((sparse_value().hasValue() && value().hasValue()) || @@ -1796,7 +1830,9 @@ LogicalResult ONNXConstantOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// // Concat +//===----------------------------------------------------------------------===// LogicalResult ONNXConcatOp::inferShapes() { int inputNum = getNumOperands(); @@ -1854,21 +1890,25 @@ LogicalResult ONNXConcatOp::inferShapes() { //===----------------------------------------------------------------------===// // RNN +//===----------------------------------------------------------------------===// LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); } //===----------------------------------------------------------------------===// // LSTM +//===----------------------------------------------------------------------===// LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); } //===----------------------------------------------------------------------===// // GRU +//===----------------------------------------------------------------------===// LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); } //===----------------------------------------------------------------------===// // Split +//===----------------------------------------------------------------------===// LogicalResult ONNXSplitOp::inferShapes() { if (!getOperand().getType().cast()) @@ -1933,6 +1973,7 @@ LogicalResult ONNXSplitOp::inferShapes() { //===----------------------------------------------------------------------===// // Flatten +//===----------------------------------------------------------------------===// LogicalResult ONNXFlattenOp::inferShapes() { assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now"); @@ -1967,6 +2008,7 @@ LogicalResult ONNXFlattenOp::inferShapes() { //===----------------------------------------------------------------------===// // DynamicQuantizeLinear +//===----------------------------------------------------------------------===// LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() { auto inTy = x().getType().dyn_cast(); @@ -2000,6 +2042,7 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() { //===----------------------------------------------------------------------===// // QuantizeLinear +//===----------------------------------------------------------------------===// LogicalResult ONNXQuantizeLinearOp::inferShapes() { auto inTy = x().getType().dyn_cast(); @@ -2022,6 +2065,7 @@ LogicalResult ONNXQuantizeLinearOp::inferShapes() { //===----------------------------------------------------------------------===// // DequantizeLinear +//===----------------------------------------------------------------------===// LogicalResult ONNXDequantizeLinearOp::inferShapes() { auto inTy = x().getType().dyn_cast(); @@ -2042,6 +2086,7 @@ LogicalResult ONNXDequantizeLinearOp::inferShapes() { //===----------------------------------------------------------------------===// // ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias) +//===----------------------------------------------------------------------===// LogicalResult ONNXConvIntegerOp::inferShapes() { // Generic shape for data input X, weight tensor W diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Transform/ONNX/CMakeLists.txt index b243b17..8f1b045 100644 --- a/src/Transform/ONNX/CMakeLists.txt +++ b/src/Transform/ONNX/CMakeLists.txt @@ -26,27 +26,27 @@ add_dependencies(OMElideConstants OMONNXOps) target_link_libraries(OMElideConstants onnx) -set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td) +set(LLVM_TARGET_DEFINITIONS Rewrite.td) onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters) add_public_tablegen_target(OMONNXRewriteIncGen) -set(LLVM_TARGET_DEFINITIONS ONNXCombine.td) +set(LLVM_TARGET_DEFINITIONS Combine.td) onnx_mlir_tablegen(ONNXCombine.inc -gen-rewriters) add_public_tablegen_target(OMONNXCombineIncGen) -set(LLVM_TARGET_DEFINITIONS ONNXDecompose.td) +set(LLVM_TARGET_DEFINITIONS Decompose.td) onnx_mlir_tablegen(ONNXDecompose.inc -gen-rewriters) add_public_tablegen_target(OMONNXDecomposeIncGen) -set(LLVM_TARGET_DEFINITIONS ONNXConstProp.td) +set(LLVM_TARGET_DEFINITIONS ConstProp.td) onnx_mlir_tablegen(ONNXConstProp.inc -gen-rewriters) add_public_tablegen_target(OMONNXConstPropIncGen) add_library(OMONNXRewrite - ONNXRewrite.cpp - ONNXCombine.cpp - ONNXDecompose.cpp - ONNXConstProp.cpp) + Rewrite.cpp + Combine.cpp + Decompose.cpp + ConstProp.cpp) target_include_directories(OMONNXRewrite PRIVATE ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNF_MLIR_SRC_ROOT}) diff --git a/src/Transform/ONNX/ONNXCombine.cpp b/src/Transform/ONNX/Combine.cpp similarity index 100% rename from src/Transform/ONNX/ONNXCombine.cpp rename to src/Transform/ONNX/Combine.cpp diff --git a/src/Transform/ONNX/ONNXCombine.td b/src/Transform/ONNX/Combine.td similarity index 100% rename from src/Transform/ONNX/ONNXCombine.td rename to src/Transform/ONNX/Combine.td diff --git a/src/Transform/ONNX/ONNXConstProp.cpp b/src/Transform/ONNX/ConstProp.cpp similarity index 89% rename from src/Transform/ONNX/ONNXConstProp.cpp rename to src/Transform/ONNX/ConstProp.cpp index 2ff12bc..262cbcb 100644 --- a/src/Transform/ONNX/ONNXConstProp.cpp +++ b/src/Transform/ONNX/ConstProp.cpp @@ -25,16 +25,17 @@ using namespace mlir; namespace { -// ============================================================================= -// Instructions to add a constant operation. There is currently support for -// adding constant propagation for unary and binary athythmetic ops (binary ops -// support broadcast). To add an operation, you simply have to add a templated -// method on how to compute the result in terms of one or two inputs. Values -// comes as Attribtues, and return is also an Attribute. In that function, -// presumably you will need different methods to handle int / float / -// strings... Note that these methods cannot fail. It is your responsablitity to -// tests for which data type are supported in the rules directly. Specific type -// restrictions can be added in the DRR files. +//===----------------------------------------------------------------------===// +// Instructions to add a constant operation. +//===----------------------------------------------------------------------===// +// There is currently support for adding constant propagation for unary and +// binary athythmetic ops (binary ops support broadcast). To add an operation, +// you simply have to add a templated method on how to compute the result in +// terms of one or two inputs. Values comes as Attribtues, and return is also an +// Attribute. In that function, presumably you will need different methods to +// handle int / float / strings... Note that these methods cannot fail. It is +// your responsablitity to tests for which data type are supported in the rules +// directly. Specific type restrictions can be added in the DRR files. // The methods are: // @@ -42,13 +43,12 @@ namespace { // and they need to be tempalted wtih an ONNX Operation (presuably). // // Then you need to add rules on how to transform the patterns; look into -// ONNXConstProp.td for example. +// ConstProp.td for example. // -// ============================================================================= -// ============================================================================= +//===----------------------------------------------------------------------===// // Code to perform constant propagation for binary in presence of broadcast. -// ============================================================================= +//===----------------------------------------------------------------------===// // Template to generate binary operation results. It takes as inupt // the element type as well as the two element attributes for the @@ -223,9 +223,9 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter, return DenseElementsAttr::get(resType, resRef); } -// ============================================================================= +//===----------------------------------------------------------------------===// // Code to perform constant propagation for unary operation. -// ============================================================================= +//===----------------------------------------------------------------------===// template Attribute ComputeConstProppElementwiseUnary( @@ -295,15 +295,15 @@ DenseElementsAttr ConstPropElementwiseUnary( return DenseElementsAttr::get(resType, resRef); } -// ============================================================================= +//===----------------------------------------------------------------------===// // Pattern definition. -// ============================================================================= +//===----------------------------------------------------------------------===// #include "src/Transform/ONNX/ONNXConstProp.inc" -// ============================================================================= +//===----------------------------------------------------------------------===// // Code to manage the pass. -// ============================================================================= +//===----------------------------------------------------------------------===// struct ConstPropONNXToONNXPass : public PassWrapper { diff --git a/src/Transform/ONNX/ONNXConstProp.td b/src/Transform/ONNX/ConstProp.td similarity index 82% rename from src/Transform/ONNX/ONNXConstProp.td rename to src/Transform/ONNX/ConstProp.td index a628d32..649bc57 100644 --- a/src/Transform/ONNX/ONNXConstProp.td +++ b/src/Transform/ONNX/ConstProp.td @@ -16,13 +16,14 @@ include "src/Dialect/ONNX/ONNXOps.td" #endif // OP_BASE - -// ============================================================================= -// Instruction to add new constant operation rules. Minimally, you will have added -// operation in the ONNXConstProp.cpp to perform the element-wise single value -// handling of the new operator that you are dealing with. You will need to -// generate a call to the method that handle the tensor constant prop. Here -// is the call for a unary and binary operation. Adapt to your new operator: +//===----------------------------------------------------------------------===// +// Instruction to add new constant operation rules. +//===----------------------------------------------------------------------===// +// Minimally, you will have added operation in the ONNXConstProp.cpp to perform +// the element-wise single value handling of the new operator that you are dealing +// with. You will need to generate a call to the method that handle the tensor +// constant prop. Here is the call for a unary and binary operation. Adapt to your +// new operator: // // def CreateAddOfTwoConst : // NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; @@ -34,7 +35,6 @@ include "src/Dialect/ONNX/ONNXOps.td" // a new def name. // // Then you will need to add substitution rules, see examples below. -// ============================================================================= // Useful test definitions. @@ -63,8 +63,9 @@ def CreateNegOfConst : def CreateMulOfTwoConst : NativeCodeCall<"ConstPropElementwiseBinary($_builder, $0, $1, $2)">; -// ============================================================================= +//===----------------------------------------------------------------------===// // Patterns to enable opportunities with elementwise ADD operations. +//===----------------------------------------------------------------------===// // Use commutativity to normalize constants in the second position of Add. def AddConstCommutative1 : Pat< @@ -96,8 +97,9 @@ def AddConstProp : Pat< [(AttributeIsNull:$s1), (AttributeIsNull:$s2)]>; -// ============================================================================= +//===----------------------------------------------------------------------===// // Patterns to enable opportunities with elementwise SUB / NEG operations. +//===----------------------------------------------------------------------===// // Constant Propagation for Sub def SubConstProp : Pat< @@ -125,10 +127,11 @@ def SubConstToNeg : Pat< [(IsNotAConstant:$x), (AttributeIsNull:$s)]>; -// ============================================================================= +//===----------------------------------------------------------------------===// // Patterns to enable opportunities with elementwise MUL operations. // Exactly the same pattern as for the elementwise ADD operations. - +//===----------------------------------------------------------------------===// + // Use commutativity to normalize constants in the second position of Mul. def MulConstCommutative1 : Pat< // From mul(c, x). diff --git a/src/Transform/ONNX/ONNXDecompose.cpp b/src/Transform/ONNX/Decompose.cpp similarity index 100% rename from src/Transform/ONNX/ONNXDecompose.cpp rename to src/Transform/ONNX/Decompose.cpp diff --git a/src/Transform/ONNX/ONNXDecompose.td b/src/Transform/ONNX/Decompose.td similarity index 100% rename from src/Transform/ONNX/ONNXDecompose.td rename to src/Transform/ONNX/Decompose.td diff --git a/src/Transform/ONNX/ONNXRewrite.cpp b/src/Transform/ONNX/Rewrite.cpp similarity index 100% rename from src/Transform/ONNX/ONNXRewrite.cpp rename to src/Transform/ONNX/Rewrite.cpp diff --git a/src/Transform/ONNX/ONNXRewrite.td b/src/Transform/ONNX/Rewrite.td similarity index 100% rename from src/Transform/ONNX/ONNXRewrite.td rename to src/Transform/ONNX/Rewrite.td