From e2e1fbfd3b9d27a1d1469a60bb27c0713e05f593 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 4 Jun 2020 22:05:04 +0900 Subject: [PATCH] Add type inference for CastOp (#156) * Add type inference for CastOp * Share type translation between op builder and onnx importer * clang-format * Format emitted code * Remove unnecessary dependencies --- src/Builder/FrontendDialectTransformer.cpp | 39 ++-------------------- src/CMakeLists.txt | 1 + src/Conversion/ONNXToKrnl/CMakeLists.txt | 2 ++ src/Dialect/ONNX/CMakeLists.txt | 2 ++ src/Dialect/ONNX/ONNXOps.cpp | 1 - src/Dialect/ONNX/ONNXOps.hpp | 2 ++ src/Dialect/ONNX/ONNXOps.td.inc | 10 +++++- src/Dialect/ONNX/ONNXOpsHelper.cpp | 37 ++++++++++++++++++++ src/Dialect/ONNX/ONNXOpsHelper.hpp | 6 ++++ src/Transform/CMakeLists.txt | 2 ++ src/Transform/ONNX/CMakeLists.txt | 6 ++++ utils/gen_onnx_mlir.py | 7 +++- 12 files changed, 75 insertions(+), 40 deletions(-) diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 2ab6353..78a12a0 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -54,42 +54,6 @@ private: mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); } - // Convert type to MLIR type. - // A complete list of types can be found in: - // /third_party/onnx/onnx/onnx.pb.h - mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) { - switch (onnxType) { - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: - return builder_.getF16Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: - return builder_.getF32Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: - return builder_.getF64Type(); - case onnx::TensorProto_DataType::TensorProto_DataType_INT8: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: - return builder_.getIntegerType(/*width=*/8); - case onnx::TensorProto_DataType::TensorProto_DataType_INT16: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: - return builder_.getIntegerType(/*width=*/16); - case onnx::TensorProto_DataType::TensorProto_DataType_INT32: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: - return builder_.getIntegerType(/*width=*/32); - case onnx::TensorProto_DataType::TensorProto_DataType_INT64: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: - return builder_.getIntegerType(/*width=*/64); - case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: - return builder_.getI1Type(); - - case onnx::TensorProto_DataType::TensorProto_DataType_STRING: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: - case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: - case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: - default: - assert(false && "Unsupported data type encountered."); - return nullptr; - } - } - /*! * Import an onnx input tensor type by determining and recording its type * in a list of input tensor mlir types. 
@@ -119,7 +83,8 @@ private:
 auto elementOnnxType =
 (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
- mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
+ mlir::Type elementType =
+ convertONNXTypeToMLIRType(builder_, elementOnnxType);
 llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 return mlir::RankedTensorType::get(tensor_dims, elementType);
 }
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index febd746..ca16cdd 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -39,6 +39,7 @@ endif()
 # targets, or only use it for libraries that have no further dependencies
 # (except system libraries such as libc).
 target_link_libraries(onnx-mlir
+ onnx
 OMBuilder
 OMKrnlOps
 OMONNXOps
diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt
index a972031..aa2859c 100644
--- a/src/Conversion/ONNXToKrnl/CMakeLists.txt
+++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt
@@ -21,6 +21,8 @@ add_library(OMONNXToKrnl
 Tensor/Constant.cpp
 Tensor/Concat.cpp
 ConvertONNXToKrnl.cpp)
+target_link_libraries(OMONNXToKrnl
+ onnx)
 target_include_directories(OMONNXToKrnl
 PRIVATE
 ${ONNX_MLIR_SRC_ROOT}
diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt
index 8295a58..cf9acfa 100644
--- a/src/Dialect/ONNX/CMakeLists.txt
+++ b/src/Dialect/ONNX/CMakeLists.txt
@@ -14,6 +14,8 @@ target_include_directories(OMONNXOps
 ${ONNX_MLIR_SRC_ROOT}
 ${ONNX_MLIR_BIN_ROOT}
 ${ONNX_MLIR_SRC_ROOT})
+target_link_libraries(OMONNXOps
+ onnx)
 add_dependencies(OMONNXOps OMONNXOpsIncGen)
 # Linking dependencies:
 add_dependencies(OMONNXOps
diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index c6efad0..d2ea7c6 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -21,7 +21,6 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "ONNXOps.hpp"
-#include "ONNXOpsHelper.hpp"
 using namespace mlir;
 using namespace mlir::OpTrait::util;
diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp
index 2d9e871..5c4200d 100644
--- a/src/Dialect/ONNX/ONNXOps.hpp
+++ b/src/Dialect/ONNX/ONNXOps.hpp
@@ -23,6 +23,8 @@
 #include "src/Interface/ResultTypeInferenceOpInterface.hpp"
 #include "src/Interface/ShapeInferenceInterface.hpp"
+#include "ONNXOpsHelper.hpp"
+
 namespace mlir {
 class ONNXOpsDialect : public Dialect {
diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index e6a361e..03823de 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -397,7 +397,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
 }
 def ONNXCastOp:ONNX_Op<"Cast",
- [NoSideEffect]> {
+ [NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> {
 let summary = "ONNX Cast operation";
 let description = [{
 "The operator casts the elements of a given input tensor to a data type"
@@ -433,6 +433,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
 static std::vector<int> getTypeMap() { return {-1}; }
+ std::vector<mlir::Type> resultTypeInference() {
+ std::vector<mlir::Type> resultTypes;
+ auto toAttr = to().getSExtValue();
+ auto builder = mlir::OpBuilder(getContext());
+ resultTypes.push_back(mlir::UnrankedTensorType::get(
+ convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));
+ return resultTypes;
+ }
 }];
}
diff --git a/src/Dialect/ONNX/ONNXOpsHelper.cpp b/src/Dialect/ONNX/ONNXOpsHelper.cpp
index 347336a..f7f8b9c 100644
--- a/src/Dialect/ONNX/ONNXOpsHelper.cpp
+++ b/src/Dialect/ONNX/ONNXOpsHelper.cpp
@@ -40,3 +40,40 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
 return AffineMap::get(1, 4, 
{dimExp}); } + +// Convert type to MLIR type. +// A complete list of types can be found in: +// /third_party/onnx/onnx/onnx.pb.h +mlir::Type convertONNXTypeToMLIRType( + mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) { + switch (onnxType) { + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16: + return builder_.getF16Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT: + return builder_.getF32Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: + return builder_.getF64Type(); + case onnx::TensorProto_DataType::TensorProto_DataType_INT8: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT8: + return builder_.getIntegerType(/*width=*/8); + case onnx::TensorProto_DataType::TensorProto_DataType_INT16: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT16: + return builder_.getIntegerType(/*width=*/16); + case onnx::TensorProto_DataType::TensorProto_DataType_INT32: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT32: + return builder_.getIntegerType(/*width=*/32); + case onnx::TensorProto_DataType::TensorProto_DataType_INT64: + case onnx::TensorProto_DataType::TensorProto_DataType_UINT64: + return builder_.getIntegerType(/*width=*/64); + case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: + return builder_.getI1Type(); + + case onnx::TensorProto_DataType::TensorProto_DataType_STRING: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: + case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: + case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: + default: + assert(false && "Unsupported data type encountered."); + return nullptr; + } +} diff --git a/src/Dialect/ONNX/ONNXOpsHelper.hpp b/src/Dialect/ONNX/ONNXOpsHelper.hpp index 15d4f0e..2dff474 100644 --- a/src/Dialect/ONNX/ONNXOpsHelper.hpp +++ b/src/Dialect/ONNX/ONNXOpsHelper.hpp @@ -11,6 +11,9 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +#include "onnx/onnx_pb.h" using namespace mlir; @@ -31,3 +34,6 @@ AffineMap getIdentityDimMap(Builder &builder); // - s2: stride // - s3: dilation AffineMap getConvDimMap(Builder &builder, bool ceilMode); + +mlir::Type convertONNXTypeToMLIRType( + mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType); diff --git a/src/Transform/CMakeLists.txt b/src/Transform/CMakeLists.txt index ddee739..f3e2105 100644 --- a/src/Transform/CMakeLists.txt +++ b/src/Transform/CMakeLists.txt @@ -18,6 +18,8 @@ target_include_directories(OMKrnlToLLVM ${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_BIN_ROOT} ${ONNX_MLIR_SRC_ROOT}) +target_link_libraries(OMKrnlToLLVM + onnx) #Linking dependencies: add_dependencies(OMKrnlToLLVM diff --git a/src/Transform/ONNX/CMakeLists.txt b/src/Transform/ONNX/CMakeLists.txt index 9068910..e2243ff 100644 --- a/src/Transform/ONNX/CMakeLists.txt +++ b/src/Transform/ONNX/CMakeLists.txt @@ -7,6 +7,8 @@ target_include_directories(OMAttributePromotion # Linking dependencies: add_dependencies(OMAttributePromotion OMPromotableConstOperandsOpInterface) +target_link_libraries(OMAttributePromotion + onnx) add_library(OMElideConstants ElideConstants.cpp) @@ -16,6 +18,8 @@ target_include_directories(OMElideConstants add_dependencies(OMElideConstants OMONNXOps) +target_link_libraries(OMElideConstants + onnx) set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td) onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters) @@ -43,6 +47,8 @@ add_dependencies(OMONNXRewrite # Linking dependencies: 
add_dependencies(OMONNXRewrite
 OMONNXOps)
+target_link_libraries(OMONNXRewrite
+ onnx)
 add_library(OMShapeInference
 ShapeInferencePass.cpp)
 target_include_directories(OMShapeInference
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index cce1734..d393cd0 100644
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -279,7 +279,12 @@ OpsWithResultTypeInference = {
 resultTypes.push_back(attr.getType());
 } else if (auto attr = sparse_valueAttr()) {
 resultTypes.push_back(attr.getType());
- }'''
+ }''',
+ "Cast":
+ '''auto toAttr = to().getSExtValue();
+ auto builder = mlir::OpBuilder(getContext());
+ resultTypes.push_back(mlir::UnrankedTensorType::get(
+ convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
 }
 # Add an Op in this list if the Op needs result type deduction which is required
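
Note, not part of the patch: a minimal sketch of how a caller such as the ONNX importer might consume the interface this patch wires up for CastOp. It assumes the interface declared in src/Interface/ResultTypeInferenceOpInterface.hpp is named mlir::ResultTypeInferenceOpInterface and is queried via dyn_cast like any other MLIR op interface; the helper name refineResultTypes is hypothetical.

    #include <vector>
    #include "mlir/IR/Operation.h"
    #include "src/Interface/ResultTypeInferenceOpInterface.hpp"

    // Hypothetical helper: if an op implements the interface (as ONNXCastOp
    // now does), overwrite each result type with the inferred one. For a
    // Cast with to = 1 (FLOAT), the emitted resultTypeInference() yields
    // tensor<*xf32>, since it wraps the converted element type in an
    // UnrankedTensorType.
    static void refineResultTypes(mlir::Operation *op) {
      if (auto inferable =
              llvm::dyn_cast<mlir::ResultTypeInferenceOpInterface>(op)) {
        std::vector<mlir::Type> inferred = inferable.resultTypeInference();
        for (unsigned i = 0, e = op->getNumResults(); i < e; ++i)
          op->getResult(i).setType(inferred[i]);
      }
    }

Because the inferred type is unranked, a later shape inference pass can still refine it to a ranked tensor; the interface only pins down the element type dictated by the "to" attribute.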