From fe3279e721c245be2f46211161aec20d4e7934c7 Mon Sep 17 00:00:00 2001
From: Gheorghe-Teodor Bercea
Date: Tue, 10 Mar 2020 14:46:35 -0400
Subject: [PATCH] Initialize operation arguments with ONNX model constants (#8)

* Save current state.

* Include constant arguments in source.

* Emit constants for Reshape second argument.

* Clean-up code.

* Add changes to gen_doc.py file.

* Propagate constant tensor to Reshape second arg to infer shape.

* Update documentation.

* Eliminate constant tensor operations when lowering to KRNL dialect.

* Replace ConstantTensorOp with ConstantOp.

* Add comment to remove temporary Constant lowering code.

* Remove unused shape inference for Constant.

* Remove comment.

* Remove explicit constant elimination.

* Refactor code.
---
 doc/gen_doc.py                                |   1 +
 src/builder/CMakeLists.txt                    |   2 +
 src/builder/frontend_dialect_helper.cpp       | 185 ++++++++++++++++++
 src/builder/frontend_dialect_helper.hpp       | 101 ++++++++++
 src/builder/frontend_dialect_transformer.cpp  | 177 +++++++----------
 src/builder/frontend_dialect_transformer.hpp  |   9 +-
 src/builder/op_build_table.inc                |   2 +-
 .../onnx_to_krnl/tensor/reshape.cpp           |   3 +-
 src/dialect/onnx/onnx.td                      |   1 -
 src/dialect/onnx/onnx_ops.cpp                 |  24 ++-
 src/pass/onnx_rewrite.cpp                     |   4 +-
 11 files changed, 383 insertions(+), 126 deletions(-)
 create mode 100644 src/builder/frontend_dialect_helper.cpp
 create mode 100644 src/builder/frontend_dialect_helper.hpp

diff --git a/doc/gen_doc.py b/doc/gen_doc.py
index c654a69..5ccfafa 100644
--- a/doc/gen_doc.py
+++ b/doc/gen_doc.py
@@ -36,6 +36,7 @@ special_op_handler = dict([
     ("MaxPool", "ImportNodeMaxPool"),
     ("BatchNormalization", "ImportNodeBatchNormalization"),
     ("Pad", "ImportNodePad"),
+    ("Reshape", "ImportNodeReshape"),
     #("Transpose", "ImportNodeTranspose")
 ])
 
diff --git a/src/builder/CMakeLists.txt b/src/builder/CMakeLists.txt
index 6033e52..1d1a117 100644
--- a/src/builder/CMakeLists.txt
+++ b/src/builder/CMakeLists.txt
@@ -1,4 +1,6 @@
 add_library(builder
+        frontend_dialect_helper.cpp
+        frontend_dialect_helper.hpp
         frontend_dialect_transformer.cpp
         frontend_dialect_transformer.hpp
         op_build_table.inc
diff --git a/src/builder/frontend_dialect_helper.cpp b/src/builder/frontend_dialect_helper.cpp
new file mode 100644
index 0000000..1dbadf4
--- /dev/null
+++ b/src/builder/frontend_dialect_helper.cpp
@@ -0,0 +1,185 @@
+//===------------------- frontend_dialect_helper.cpp ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Helper methods for handling input ONNX models.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/builder/frontend_dialect_helper.hpp"
+
+namespace onnf {
+
+void replaceAll(std::string &str, const std::string &from,
+                const std::string &to) {
+  if (from.empty())
+    return;
+  size_t start_pos = 0;
+  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
+    str.replace(start_pos, from.length(), to);
+    start_pos += to.length(); // In case 'to' contains 'from', like replacing
+                              // 'x' with 'yx'
+  }
+}
+
+std::string legalize_name(std::string name) {
+  std::replace(name.begin(), name.end(), '/', '_');
+  std::replace(name.begin(), name.end(), '-', '_');
+  replaceAll(name, ":", "_colon_");
+  // If tensor name starts with a number, prepend n to make it a legal c++
+  // identifier.
+  if (name.size() > 0 && isdigit(name.at(0)))
+    name.insert(0, 1, 'n');
+  return name;
+}
+
+mlir::Value OnnxOnnfSymbolMapping::GetTensorByOnnxName(
+    const std::string &name) {
+  assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
+             onnx_name2onnf_tensor.end() &&
+         "Tensor not found");
+  return onnx_name2onnf_tensor.at(legalize_name(name));
+}
+
+void OnnxOnnfSymbolMapping::AddMapping(
+    const std::string &name, mlir::Value tensor) {
+  assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
+         "Tensor already exists.");
+  onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
+}
+
+bool OnnxOnnfSymbolMapping::ContainKey(std::string name) {
+  return onnx_name2onnf_tensor.count(name) != 0;
+}
+
+template <typename T>
+struct TransformValueToONNXData {
+  static const google::protobuf::RepeatedField<T> data(
+      onnx::TensorProto initializer) {
+    return google::protobuf::RepeatedField<T>();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<double> {
+  static const google::protobuf::RepeatedField<double> data(
+      onnx::TensorProto initializer) {
+    return initializer.double_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<float> {
+  static const google::protobuf::RepeatedField<float> data(
+      onnx::TensorProto initializer) {
+    return initializer.float_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<int32_t> {
+  static const google::protobuf::RepeatedField<int32_t> data(
+      onnx::TensorProto initializer) {
+    return initializer.int32_data();
+  }
+};
+
+template <>
+struct TransformValueToONNXData<int64_t> {
+  static const google::protobuf::RepeatedField<int64_t> data(
+      onnx::TensorProto initializer) {
+    return initializer.int64_data();
+  }
+};
+
+// Helper method for constructing an array attribute from a model input.
+template <typename T>
+static T *CreateArrayAttribute(onnx::TensorProto initializer, int *size) {
+  if (initializer.raw_data().size()) {
+    // copy & take care of endianness
+    std::vector<char> byteInitializer;
+    std::copy(initializer.raw_data().begin(), initializer.raw_data().end(),
+              back_inserter(byteInitializer));
+    *size = initializer.raw_data().size() / sizeof(T);
+    return reinterpret_cast<T *>(&byteInitializer[0]);
+  }
+
+  // copy, no need to take care of endianness
+  auto data = TransformValueToONNXData<T>::data(initializer);
+  *size = data.size();
+  return &data[0];
+}
+
+void InitializedTensorMapping::AddMapping(
+    std::string name, onnx::TensorProto tensor) {
+  assert(nameToInitializedTensor.count(name) == 0 &&
+         "Tensor initializer already mapped.");
+  nameToInitializedTensor.emplace(name, tensor);
+}
+
+bool InitializedTensorMapping::ContainKey(std::string name) {
+  return nameToInitializedTensor.count(name) != 0;
+}
+
+mlir::Value InitializedTensorMapping::EmitInitializerForInputTensor(
+    mlir::Location loc, mlir::OpBuilder &builder, std::string name) {
+  // Initializer for input.
+  onnx::TensorProto initializer = GetInitializedTensor(name);
+
+  // Emit ConstantOp and record the mapping between the input and
+  // the constant value.
+  mlir::ArrayAttr constantArrayAttribute;
+  mlir::Type elementType;
+  int length;
+  switch (initializer.data_type()) {
+  case (onnx::TensorProto::FLOAT): {
+    float *typeArray = CreateArrayAttribute<float>(initializer, &length);
+    std::vector<float> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<float> array(typeArray, length);
+    constantArrayAttribute = builder.getF32ArrayAttr(array);
+    elementType = builder.getF32Type();
+    break;
+  }
+  case (onnx::TensorProto::INT32): {
+    int32_t *typeArray = CreateArrayAttribute<int32_t>(initializer, &length);
+    std::vector<int32_t> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<int32_t> array(typeArray, length);
+    constantArrayAttribute = builder.getI32ArrayAttr(array);
+    elementType = builder.getIntegerType(32);
+    break;
+  }
+  case (onnx::TensorProto::INT64): {
+    int64_t *typeArray = CreateArrayAttribute<int64_t>(initializer, &length);
+    std::vector<int64_t> arrayAttrInitializer(typeArray, typeArray + length);
+    llvm::ArrayRef<int64_t> array(typeArray, length);
+    constantArrayAttribute = builder.getI64ArrayAttr(array);
+    elementType = builder.getIntegerType(64);
+    break;
+  }
+  }
+
+  // Create empty sparse_value attribute.
+  llvm::ArrayRef<int64_t> array;
+  auto sparseValueAttribute = builder.getI64ArrayAttr(array);
+
+  // Create value attribute.
+  llvm::ArrayRef<int64_t> tensorDims(initializer.dims().data(),
+                                     initializer.dims().size());
+  mlir::Type tensorType =
+      mlir::RankedTensorType::get(tensorDims, elementType);
+
+  return builder.create<mlir::ONNXConstantOp>(
+      loc, tensorType, sparseValueAttribute,
+      constantArrayAttribute);
+}
+
+} // namespace onnf
\ No newline at end of file
diff --git a/src/builder/frontend_dialect_helper.hpp b/src/builder/frontend_dialect_helper.hpp
new file mode 100644
index 0000000..f47c685
--- /dev/null
+++ b/src/builder/frontend_dialect_helper.hpp
@@ -0,0 +1,101 @@
+//===------------------- frontend_dialect_helper.hpp ----------------------===//
+//
+// Copyright 2019 The IBM Research Authors.
+//
+// =============================================================================
+//
+// Helper methods for handling input ONNX models.
+//
+//===----------------------------------------------------------------------===//
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "mlir/Analysis/Verifier.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "src/dialect/onnx/onnx_ops.hpp"
+#include "onnx/onnx_pb.h"
+
+namespace onnf {
+
+void replaceAll(std::string &str, const std::string &from,
+                const std::string &to);
+
+std::string legalize_name(std::string name);
+
+struct OnnxOnnfSymbolMapping {
+  /*!
+   * Get MLIR tensor by onnx tensor name.
+   * @param name onnx tensor name.
+   * @return onnf tensor corresponding to `name`.
+   */
+  mlir::Value GetTensorByOnnxName(const std::string &name);
+
+  /*!
+   * Add a new mapping from onnx tensor name to MLIR symbol.
+   * @param name onnx tensor name.
+   * @param tensor MLIR Value pointer.
+   */
+  void AddMapping(const std::string &name, mlir::Value tensor);
+
+  bool ContainKey(std::string name);
+
+private:
+  /*!
+   * Mapping from onnx tensor names to MLIR tensor.
+   */
+  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
+};
+
+struct InitializedTensorMapping {
+  // Add new entry.
+  void AddMapping(std::string name, onnx::TensorProto tensor);
+
+  // Check if input is initialized. Not all inputs are; some require input
+  // from the user and are not stored inside the ONNX model itself.
+  bool ContainKey(std::string name);
+
+  // Emit a constant argument (an initialized argument) as a ConstantOp.
+  // This method allows operations to use the constant data contained
+  // in an ONNX model as they are being compiled, by emitting such a
+  // constant operation on demand.
+  //
+  // This enables the propagation of shape information passed in as an
+  // argument to operations such as Reshape, and enables other
+  // optimizations such as constant folding.
+  mlir::Value EmitInitializerForInputTensor(mlir::Location loc,
+      mlir::OpBuilder &builder, std::string name);
+
+  // Get initialized tensor.
+  onnx::TensorProto &GetInitializedTensor(std::string name) {
+    assert(nameToInitializedTensor.find(name) !=
+               nameToInitializedTensor.end() &&
+           "Tensor initializer not found");
+    return nameToInitializedTensor.at(name);
+  }
+
+private:
+  // Mapping from ONNX tensor name to InitializedTensor.
+  std::map<std::string, onnx::TensorProto> nameToInitializedTensor;
+};
+
+} // namespace onnf
\ No newline at end of file
diff --git a/src/builder/frontend_dialect_transformer.cpp b/src/builder/frontend_dialect_transformer.cpp
index 0efca22..a7a6d09 100644
--- a/src/builder/frontend_dialect_transformer.cpp
+++ b/src/builder/frontend_dialect_transformer.cpp
@@ -14,96 +14,20 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <fstream>
-#include <functional>
-#include <iostream>
-#include <map>
-#include <numeric>
-
 // Using backported variant.
 // bstd = backported standard library.
 #include <mpark/variant.hpp>
 namespace bstd = mpark;
 
-#include "mlir/Analysis/Verifier.h"
-#include "mlir/Dialect/StandardOps/Ops.h"
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/Function.h"
-#include "mlir/IR/Location.h"
-#include "mlir/IR/MLIRContext.h"
-#include "mlir/IR/Module.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/IR/Types.h"
-
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopedHashTable.h"
-#include "llvm/Support/raw_ostream.h"
-
-#include "src/dialect/onnx/onnx_ops.hpp"
-
 #include "frontend_dialect_transformer.hpp"
 
 namespace onnf {
 namespace {
 
-void replaceAll(std::string &str, const std::string &from,
-                const std::string &to) {
-  if (from.empty())
-    return;
-  size_t start_pos = 0;
-  while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
-    str.replace(start_pos, from.length(), to);
-    start_pos += to.length(); // In case 'to' contains 'from', like replacing
-                              // 'x' with 'yx'
-  }
-}
-
-std::string legalize_name(std::string name) {
-  std::replace(name.begin(), name.end(), '/', '_');
-  std::replace(name.begin(), name.end(), '-', '_');
-  replaceAll(name, ":", "_colon_");
-  // If tensor name starts with a number, prepend n to make it a legal c++
-  // identifier.
-  if (name.size() > 0 && isdigit(name.at(0)))
-    name.insert(0, 1, 'n');
-  return name;
-}
-
-struct OnnxOnnfSymbolMapping {
-  /*!
-   * Get MLIR tensor by onnx tensor name.
-   * @param name onnx tensor name.
-   * @return onnf tensor corresponding to `name`.
-   */
-  mlir::Value GetTensorByOnnxName(const std::string &name) {
-    assert(onnx_name2onnf_tensor.find(legalize_name(name)) !=
-               onnx_name2onnf_tensor.end() &&
-           "Tensor not found");
-    return onnx_name2onnf_tensor.at(legalize_name(name));
-  }
-
-  /*!
-   * Add a new mapping from onnx tensor name to MLIR symbol.
-   * @param name onnx tensor name.
-   * @param tensor MLIR Value pointer.
-   */
-  void AddMapping(const std::string &name, mlir::Value tensor) {
-    assert(onnx_name2onnf_tensor.count(legalize_name(name)) == 0 &&
-           "Tensor already exists.");
-    onnx_name2onnf_tensor.emplace(legalize_name(name), tensor);
-  }
-
-  bool ContainKey(std::string name) {
-    return onnx_name2onnf_tensor.count(name) != 0;
-  }
-
-private:
-  /*!
-   * mapping from onnx tensor names to MLIR tensor.
-   */
-  std::map<std::string, mlir::Value> onnx_name2onnf_tensor;
-};
+/*!
+ * The list of tensors initialized by the ONNX model.
+ */
+InitializedTensorMapping initializedTensors;
 
 class FrontendGenImpl {
 public:
@@ -167,8 +91,7 @@ private:
    * @param input onnx input tensor ValueInfoProto.
    * @param arg_types list of mlir types representing types of graph input.
    */
-  void ImportInputTensorType(const onnx::ValueInfoProto &input,
-      llvm::SmallVector<mlir::Type, 4> &arg_types) {
+  mlir::Type ImportInputTensorType(const onnx::ValueInfoProto &input) {
     std::vector<int64_t> dims;
     auto shape_proto = input.type().tensor_type().shape();
     auto input_tensor_legalized_name = legalize_name(input.name());
@@ -193,8 +116,7 @@ private:
         (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
     mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
     llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
-    arg_types.emplace_back(
-        mlir::RankedTensorType::get(tensor_dims, elementType));
+    return mlir::RankedTensorType::get(tensor_dims, elementType);
   }
 
   /*!
@@ -320,16 +242,11 @@ private:
   }
 
   template <typename T>
-  void buildOperation(const onnx::NodeProto &node, int expectedNumOperands = -1,
-      int expectedNumResults = -1) {
+  void buildOutputAndOperation(const onnx::NodeProto &node,
+      std::vector<mlir::Value> inputs, int expectedNumOperands,
+      int expectedNumResults) {
     bool variadicIn = expectedNumOperands == -1;
     bool variadicOut = expectedNumResults == -1;
-    std::vector<mlir::Value> inputs;
-    for (const auto &item : node.input()) {
-      if (frontend_symbols_.ContainKey(legalize_name(item))) {
-        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
-      }
-    }
 
     if (!variadicIn)
       for (auto i = inputs.size(); i < expectedNumOperands; i++)
@@ -351,6 +268,37 @@ private:
     }
   }
 
+  template <typename T>
+  void buildOperation(const onnx::NodeProto &node,
+                      int expectedNumOperands = -1,
+                      int expectedNumResults = -1) {
+    std::vector<mlir::Value> inputs;
+    for (const auto &item : node.input())
+      if (frontend_symbols_.ContainKey(legalize_name(item)))
+        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
+
+    buildOutputAndOperation<T>(node, inputs, expectedNumOperands,
+                               expectedNumResults);
+  }
+
+  void ImportNodeReshape(onnx::NodeProto node, int nIn, int nOut) {
+    std::vector<mlir::Value> inputs;
+    std::string item;
+    for (int i = 0; i < node.input().size(); ++i) {
+      item = node.input()[i];
+      // For the second argument, check if there exists an initializer.
+      if (i == 1 && initializedTensors.ContainKey(legalize_name(item))) {
+        inputs.push_back(
+            initializedTensors.EmitInitializerForInputTensor(
+                UnknownLoc(), builder_, legalize_name(item)));
+      } else if (frontend_symbols_.ContainKey(legalize_name(item))) {
+        inputs.push_back(frontend_symbols_.GetTensorByOnnxName(item));
+      }
+    }
+
+    buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut);
+  }
+
   /*!
    * Special handle for Conv operations.
    * c++ does not allow template specialization inside a class scope
@@ -452,38 +400,52 @@ private:
 
   void ImportGraph(const onnx::GraphProto &graph,
                    const std::string &name = "main_graph") {
+    // Maintain a mapping between the parameter and its initializer.
+    for (auto initializer : graph.initializer()) {
+      auto name = initializer.name();
+      initializedTensors.AddMapping(legalize_name(name), initializer);
+    }
+
     // create a function for the graph
     // TODO:
     //  * get name and type for the function.
     //  * maintain a list of the defined graph
     llvm::SmallVector<mlir::Type, 4> arg_types;
 
-    // Import the input tensor types.
-    for (const auto &input : graph.input()) {
-      ImportInputTensorType(input, arg_types);
-    }
+    // Import the input tensor types that are not constant.
+    for (const auto &input : graph.input())
+      arg_types.emplace_back(ImportInputTensorType(input));
 
-    // TODO: import the initializer
+    // Create the main function.
     auto funcType = builder_.getFunctionType(arg_types, {});
     auto mainFunc =
         mlir::FuncOp::create(UnknownLoc(), name, funcType, /* attrs = */ {});
+
+    // Emit the entry point operation which specifies the number of user
+    // inputs and outputs.
     auto entryPoint = mlir::ONNXEntryPointOp::create(
-        UnknownLoc(), mainFunc, /*numInputs=*/graph.input().size(),
+        UnknownLoc(), mainFunc,
+        /*numInputs=*/graph.input().size() - graph.initializer().size(),
         /*numOutputs=*/graph.output().size());
+
+    // Get the entry block inside the main function and set the insertion
+    // point to it.
     auto &entryBlock = *mainFunc.addEntryBlock();
     builder_.setInsertionPointToStart(&entryBlock);
 
     module_.push_back(mainFunc);
     module_.push_back(entryPoint);
 
-    for (auto it : llvm::zip(graph.input(), entryBlock.getArguments())) {
-      ImportInputTensorSymbol(std::get<0>(it), std::get<1>(it));
-    }
+    // Map graph inputs to entry block arguments.
+    for (int i = 0; i < graph.input().size(); ++i)
+      ImportInputTensorSymbol(
+          graph.input()[i], entryBlock.getArguments()[i]);
+
+    // Create a NoneTyped constant to be used for optional operation inputs
+    // which are not used.
+    none_ = builder_.create<mlir::ConstantOp>(UnknownLoc(),
+        builder_.getUnitAttr());
 
-    // Create a NoneTyped constant.
-    none_ =
-        builder_.create<mlir::ConstantOp>(UnknownLoc(), builder_.getUnitAttr());
     // Import nodes in the graph.
     for (const auto &item : graph.node()) {
       ImportNode(item);
@@ -509,13 +471,6 @@ private:
 
 namespace onnf {
 
-mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model) {
-  mlir::MLIRContext context;
-  FrontendGenImpl myONNXGen(context);
-  auto module = myONNXGen.ImportONNXModel(model);
-  return module;
-}
-
 void ImportFrontendModelFile(std::string model_fname,
                              mlir::MLIRContext &context,
                              mlir::OwningModuleRef &module) {
diff --git a/src/builder/frontend_dialect_transformer.hpp b/src/builder/frontend_dialect_transformer.hpp
index 8544087..234415a 100644
--- a/src/builder/frontend_dialect_transformer.hpp
+++ b/src/builder/frontend_dialect_transformer.hpp
@@ -18,6 +18,8 @@
 
 #include "onnx/onnx_pb.h"
 
+#include "src/builder/frontend_dialect_helper.hpp"
+
 namespace mlir {
 class MLIRContext;
 class OwningModuleRef;
@@ -28,13 +30,6 @@ class OwningModuleRef;
 //===----------------------------------------------------------------------===//
 
 namespace onnf {
-/*!
- * Import an ONNX model into ONNF's ONNX Dialect.
- * @param model onnx model.
- * @return MLIR::module generated for the ONNX model.
- */
-mlir::OwningModuleRef ImportFrontendModel(onnx::ModelProto model);
-
 /*!
  * Import an ONNX model file into ONNF's ONNX Dialect.
  * @param model_fname file name pointing to the onnx model protobuf.
diff --git a/src/builder/op_build_table.inc b/src/builder/op_build_table.inc
index 41a910f..32328e8 100644
--- a/src/builder/op_build_table.inc
+++ b/src/builder/op_build_table.inc
@@ -224,7 +224,7 @@ if (opName == "ReduceSumSquare")
 if (opName == "Relu")
   return buildOperation<mlir::ONNXReluOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1);
 if (opName == "Reshape")
-  return buildOperation<mlir::ONNXReshapeOp>(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
+  return ImportNodeReshape(node, /* expected_num_operands = */ 2, /* expected_num_results = */ 1);
 if (opName == "Resize")
   return buildOperation<mlir::ONNXResizeOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1);
 if (opName == "ReverseSequence")
diff --git a/src/conversion/onnx_to_krnl/tensor/reshape.cpp b/src/conversion/onnx_to_krnl/tensor/reshape.cpp
index 9e99f2d..5c2ea11 100644
--- a/src/conversion/onnx_to_krnl/tensor/reshape.cpp
+++ b/src/conversion/onnx_to_krnl/tensor/reshape.cpp
@@ -61,7 +61,8 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
         rewriter.getIntegerType(64), getMemRefEltSizeInBytes(memRefType));
     SmallVector<Value, 4> DimInfo;
     for (int i = 0; i < memRefShape.size(); ++i) {
-      Value index = emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
+      Value index =
+          emitConstantOp(rewriter, loc, rewriter.getIndexType(), i);
       // Load index from array of indices.
       Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
       // If a dimension is zero, the actual dimension value is taken from the
diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index fae2806..68dbbf6 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -202,5 +202,4 @@ def ONNXPadConstantValuePadOp : ONNX_Op<"PadConstantValuePad",
     "FloatAttr constant_value, StringAttr mode">];
 }
 
-
 #endif // ONNX_OPS
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index d6b86b0..18768b7 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -656,9 +656,27 @@ void ONNXReshapeOp::inferShapes() {
   if (outputRank < 0)
     emitError("Shape tensor must have constant shape");
 
-  SmallVector<int64_t, 2> dims;
-  for (int i = 0; i < outputRank; ++i)
-    dims.emplace_back(-1);
+  // Check if the second argument of ReshapeOp is a constant.
+  // Get the operation that defines the second argument. If this operation
+  // is a `Constant` operation, the shape of this `Reshape` operation
+  // resides in the `value` attribute of the `Constant` operation.
+  auto *secondArgDefiningOp = (*getODSOperands(1).begin()).getDefiningOp();
+  auto constantOp =
+      dyn_cast_or_null<mlir::ONNXConstantOp>(secondArgDefiningOp);
+
+  SmallVector<int64_t, 2> dims(outputRank, -1);
+  if (constantOp) {
+    ArrayAttr valueAttribute = constantOp.valueAttr().dyn_cast<ArrayAttr>();
+
+    if (!valueAttribute)
+      emitError("ArrayAttr expected");
+
+    if (valueAttribute.getValue().size() != outputRank)
+      emitError("Constant value must have same rank as output");
+
+    for (int i = 0; i < outputRank; ++i)
+      dims[i] = valueAttribute.getValue()[i].cast<IntegerAttr>().getInt();
+  }
 
   getResult().setType(
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
diff --git a/src/pass/onnx_rewrite.cpp b/src/pass/onnx_rewrite.cpp
index 79bf712..8cac4e4 100644
--- a/src/pass/onnx_rewrite.cpp
+++ b/src/pass/onnx_rewrite.cpp
@@ -91,7 +91,7 @@ struct SplitConvOpPattern : public RewritePattern {
                        1, context) {}
 
   PatternMatchResult matchAndRewrite(Operation *op,
-                                     PatternRewriter &rewriter) const override {
+      PatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
 
     // If convolution does not use padding then no rewrite is required.
@@ -173,7 +173,7 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<MaxPoolSingleOutOpPaddingPattern>(context);
 }
-/// on the ONNXReduceSumSquareOp.
+/// on the ONNXConvNoBiasOp.
 void ONNXConvNoBiasOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   results.insert<SplitConvOpPattern>(context);
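
Illustration of the intended effect (a sketch, not output captured from the compiler; the function name, tensor names, shapes, and exact attribute printing are hypothetical). For a model whose Reshape shape operand is stored as an initializer, the importer now materializes the initializer as an onnx.Constant, so shape inference can produce a ranked result type:

    func @main_graph(%arg0: tensor<2x8xf32>) -> tensor<4x4xf32> {
      // The [4, 4] initializer is emitted on demand as a constant feeding
      // the second (shape) operand of Reshape.
      %0 = "onnx.Constant"() {sparse_value = [], value = [4, 4]} : () -> tensor<2xi64>
      // With the constant shape visible, ONNXReshapeOp::inferShapes() yields
      // tensor<4x4xf32> instead of a result with unknown (-1) dimensions.
      %1 = "onnx.Reshape"(%arg0, %0) : (tensor<2x8xf32>, tensor<2xi64>) -> tensor<4x4xf32>
      return %1 : tensor<4x4xf32>
    }

Note also that because initializers no longer count as user inputs, the entry point's numInputs reflects only the true user-supplied arguments (graph.input().size() - graph.initializer().size()).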