diff --git a/src/builder/sgir.cpp b/src/builder/sgir.cpp index c44250d..0ebdf85 100644 --- a/src/builder/sgir.cpp +++ b/src/builder/sgir.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include "mlir/Analysis/Verifier.h" #include "mlir/Dialect/StandardOps/Ops.h" @@ -30,6 +31,29 @@ namespace onnf { namespace { +void replaceAll( + std::string& str, const std::string& from, const std::string& to) { + if (from.empty()) + return; + size_t start_pos = 0; + while ((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // In case 'to' contains 'from', like replacing + // 'x' with 'yx' + } +} + +std::string legalize_name(std::string name) { + std::replace(name.begin(), name.end(), '/', '_'); + std::replace(name.begin(), name.end(), '-', '_'); + replaceAll(name, ":", "_colon_"); + // If tensor name starts with a number, prepend n to make it a legal c++ + // identifier. + if (name.size() > 0 && isdigit(name.at(0))) + name.insert(0, 1, 'n'); + return name; +} + struct OnnxOnnfSymbolMapping { /*! * Get MLIR tensor by onnx tensor name.