Add type inference for CastOp (#156)

* Add type inference for CastOp

* Share type translation between op builder and onnx importer

* clang-format

* Format emitted code

* Remove unnecessary dependencies
Author: Tung D. Le, 2020-06-04 22:05:04 +09:00 (committed by GitHub)
Parent: 2a1fe9e1e7
Commit: e2e1fbfd3b
12 changed files with 75 additions and 40 deletions


@@ -54,42 +54,6 @@ private:
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
// Convert type to MLIR type.
// A complete list of types can be found in:
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return builder_.getF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
return builder_.getF32Type();
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(/*width=*/8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(/*width=*/16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(/*width=*/32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(/*width=*/64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
default:
assert(false && "Unsupported data type encountered.");
return nullptr;
}
}
/*!
* Import an onnx input tensor type by determining and recording its type
* in a list of input tensor mlir types.
@@ -119,7 +83,8 @@ private:
auto elementOnnxType =
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
mlir::Type elementType =
convertONNXTypeToMLIRType(builder_, elementOnnxType);
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
return mlir::RankedTensorType::get(tensor_dims, elementType);
}


@@ -39,6 +39,7 @@ endif()
# targets, or only use it for libraries that have no further dependencies
# (except system libraries such as libc).
target_link_libraries(onnx-mlir
onnx
OMBuilder
OMKrnlOps
OMONNXOps


@@ -21,6 +21,8 @@ add_library(OMONNXToKrnl
Tensor/Constant.cpp
Tensor/Concat.cpp
ConvertONNXToKrnl.cpp)
target_link_libraries(OMONNXToKrnl
onnx)
target_include_directories(OMONNXToKrnl
PRIVATE
${ONNX_MLIR_SRC_ROOT}


@@ -14,6 +14,8 @@ target_include_directories(OMONNXOps
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMONNXOps
onnx)
add_dependencies(OMONNXOps OMONNXOpsIncGen)
# Linking dependencies:
add_dependencies(OMONNXOps


@@ -21,7 +21,6 @@
#include "llvm/ADT/SmallBitVector.h"
#include "ONNXOps.hpp"
#include "ONNXOpsHelper.hpp"
using namespace mlir;
using namespace mlir::OpTrait::util;


@@ -23,6 +23,8 @@
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
#include "src/Interface/ShapeInferenceInterface.hpp"
#include "ONNXOpsHelper.hpp"
namespace mlir {
class ONNXOpsDialect : public Dialect {


@@ -397,7 +397,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
}
def ONNXCastOp:ONNX_Op<"Cast",
[NoSideEffect]> {
[NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> {
let summary = "ONNX Cast operation";
let description = [{
"The operator casts the elements of a given input tensor to a data type"
@@ -433,6 +433,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
static std::vector<int> getTypeMap() {
return {-1};
}
std::vector<mlir::Type> resultTypeInference() {
std::vector<mlir::Type> resultTypes;
auto toAttr = to().getSExtValue();
auto builder = mlir::OpBuilder(getContext());
resultTypes.push_back(mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));
return resultTypes;
}
}];
}
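
For orientation, here is a minimal sketch (not part of this commit) of how the ResultTypeInferenceOpInterface hook added to ONNXCastOp above could be consumed by a caller such as the importer or an op builder: create the op with a placeholder result type, then retype its result from resultTypeInference(). The generated build() signature, the getI64IntegerAttr() attribute construction, and the getResult().setType() retyping step are assumptions made for illustration.

#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "ONNXOps.hpp"

// Sketch only: the exact ONNXCastOp builder arguments and the mlir namespace
// placement of the generated op class are assumed.
mlir::Value buildCastWithInferredType(mlir::OpBuilder &builder,
                                      mlir::Location loc, mlir::Value input,
                                      int64_t onnxTypeCode) {
  // Start from an unranked placeholder; the caller does not yet know the
  // element type of the result.
  mlir::Type placeholder =
      mlir::UnrankedTensorType::get(builder.getF32Type());
  auto castOp = builder.create<mlir::ONNXCastOp>(
      loc, placeholder, input, builder.getI64IntegerAttr(onnxTypeCode));
  // resultTypeInference() maps the 'to' attribute through
  // convertONNXTypeToMLIRType and returns one type per result.
  std::vector<mlir::Type> inferred = castOp.resultTypeInference();
  castOp.getResult().setType(inferred[0]);
  return castOp.getResult();
}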


@@ -40,3 +40,40 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
return AffineMap::get(1, 4, {dimExp});
}
// Convert type to MLIR type.
// A complete list of types can be found in:
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
mlir::Type convertONNXTypeToMLIRType(
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
switch (onnxType) {
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
return builder_.getF16Type();
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
return builder_.getF32Type();
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(/*width=*/8);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(/*width=*/16);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(/*width=*/32);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(/*width=*/64);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
default:
assert(false && "Unsupported data type encountered.");
return nullptr;
}
}
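
With the translation now shared through ONNXOpsHelper, the importer and the generated op code resolve element types identically. A brief illustration of calling the helper (variable names are hypothetical; the enum values come from onnx.pb.h as noted above):

// Illustration only: one helper serves both call sites.
mlir::MLIRContext context;
mlir::OpBuilder builder(&context);
// As used by the importer for a tensor declared as FLOAT -> MLIR f32.
mlir::Type f32Ty = convertONNXTypeToMLIRType(
    builder, onnx::TensorProto_DataType::TensorProto_DataType_FLOAT);
// As used by ONNXCastOp::resultTypeInference() when 'to' is INT64 -> MLIR i64.
mlir::Type i64Ty = convertONNXTypeToMLIRType(
    builder, onnx::TensorProto_DataType::TensorProto_DataType_INT64);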


@@ -11,6 +11,9 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "onnx/onnx_pb.h"
using namespace mlir;
@@ -31,3 +34,6 @@ AffineMap getIdentityDimMap(Builder &builder);
// - s2: stride
// - s3: dilation
AffineMap getConvDimMap(Builder &builder, bool ceilMode);
mlir::Type convertONNXTypeToMLIRType(
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType);


@@ -18,6 +18,8 @@ target_include_directories(OMKrnlToLLVM
${ONNX_MLIR_SRC_ROOT}
${ONNX_MLIR_BIN_ROOT}
${ONNX_MLIR_SRC_ROOT})
target_link_libraries(OMKrnlToLLVM
onnx)
#Linking dependencies:
add_dependencies(OMKrnlToLLVM


@@ -7,6 +7,8 @@ target_include_directories(OMAttributePromotion
# Linking dependencies:
add_dependencies(OMAttributePromotion
OMPromotableConstOperandsOpInterface)
target_link_libraries(OMAttributePromotion
onnx)
add_library(OMElideConstants
ElideConstants.cpp)
@@ -16,6 +18,8 @@ target_include_directories(OMElideConstants
add_dependencies(OMElideConstants
OMONNXOps)
target_link_libraries(OMElideConstants
onnx)
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
@@ -43,6 +47,8 @@ add_dependencies(OMONNXRewrite
# Linking dependencies:
add_dependencies(OMONNXRewrite
OMONNXOps)
target_link_libraries(OMONNXRewrite
onnx)
add_library(OMShapeInference ShapeInferencePass.cpp)
target_include_directories(OMShapeInference


@@ -279,7 +279,12 @@ OpsWithResultTypeInference = {
resultTypes.push_back(attr.getType());
} else if (auto attr = sparse_valueAttr()) {
resultTypes.push_back(attr.getType());
}'''
}''',
"Cast":
'''auto toAttr = to().getSExtValue();
auto builder = mlir::OpBuilder(getContext());
resultTypes.push_back(mlir::UnrankedTensorType::get(
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
}
# Add an Op in this list if the Op needs result type deduction which is required