Add type inference for CastOp (#156)
* Add type inference for CastOp * Share type translation between op builder and onnx importer * clang-format * Format emitted code * Remove unnecessary dependencies
This commit is contained in:
parent
2a1fe9e1e7
commit
e2e1fbfd3b
|
@ -54,42 +54,6 @@ private:
|
|||
|
||||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||
|
||||
// Convert type to MLIR type.
|
||||
// A complete list of types can be found in:
|
||||
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||
mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
|
||||
switch (onnxType) {
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||
return builder_.getF16Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||
return builder_.getF32Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||
return builder_.getF64Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||
return builder_.getIntegerType(/*width=*/8);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||
return builder_.getIntegerType(/*width=*/16);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||
return builder_.getIntegerType(/*width=*/32);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||
return builder_.getIntegerType(/*width=*/64);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||
return builder_.getI1Type();
|
||||
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||
default:
|
||||
assert(false && "Unsupported data type encountered.");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* Import an onnx input tensor type by determining and recording its type
|
||||
* in a list of input tensor mlir types.
|
||||
|
@ -119,7 +83,8 @@ private:
|
|||
|
||||
auto elementOnnxType =
|
||||
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
|
||||
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
|
||||
mlir::Type elementType =
|
||||
convertONNXTypeToMLIRType(builder_, elementOnnxType);
|
||||
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
||||
return mlir::RankedTensorType::get(tensor_dims, elementType);
|
||||
}
|
||||
|
|
|
@ -39,6 +39,7 @@ endif()
|
|||
# targets, or only use it for libraries that have no further dependencies
|
||||
# (except system libraries such as libc).
|
||||
target_link_libraries(onnx-mlir
|
||||
onnx
|
||||
OMBuilder
|
||||
OMKrnlOps
|
||||
OMONNXOps
|
||||
|
|
|
@ -21,6 +21,8 @@ add_library(OMONNXToKrnl
|
|||
Tensor/Constant.cpp
|
||||
Tensor/Concat.cpp
|
||||
ConvertONNXToKrnl.cpp)
|
||||
target_link_libraries(OMONNXToKrnl
|
||||
onnx)
|
||||
target_include_directories(OMONNXToKrnl
|
||||
PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
|
|
|
@ -14,6 +14,8 @@ target_include_directories(OMONNXOps
|
|||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(OMONNXOps
|
||||
onnx)
|
||||
add_dependencies(OMONNXOps OMONNXOpsIncGen)
|
||||
# Linking dependencies:
|
||||
add_dependencies(OMONNXOps
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "llvm/ADT/SmallBitVector.h"
|
||||
|
||||
#include "ONNXOps.hpp"
|
||||
#include "ONNXOpsHelper.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::OpTrait::util;
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||
|
||||
#include "ONNXOpsHelper.hpp"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class ONNXOpsDialect : public Dialect {
|
||||
|
|
|
@ -397,7 +397,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
|
|||
}
|
||||
|
||||
def ONNXCastOp:ONNX_Op<"Cast",
|
||||
[NoSideEffect]> {
|
||||
[NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||
let summary = "ONNX Cast operation";
|
||||
let description = [{
|
||||
"The operator casts the elements of a given input tensor to a data type"
|
||||
|
@ -433,6 +433,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
|
|||
static std::vector<int> getTypeMap() {
|
||||
return {-1};
|
||||
}
|
||||
std::vector<mlir::Type> resultTypeInference() {
|
||||
std::vector<mlir::Type> resultTypes;
|
||||
auto toAttr = to().getSExtValue();
|
||||
auto builder = mlir::OpBuilder(getContext());
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));
|
||||
return resultTypes;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -40,3 +40,40 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
|
|||
|
||||
return AffineMap::get(1, 4, {dimExp});
|
||||
}
|
||||
|
||||
// Convert type to MLIR type.
|
||||
// A complete list of types can be found in:
|
||||
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||
mlir::Type convertONNXTypeToMLIRType(
|
||||
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
|
||||
switch (onnxType) {
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||
return builder_.getF16Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||
return builder_.getF32Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||
return builder_.getF64Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||
return builder_.getIntegerType(/*width=*/8);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||
return builder_.getIntegerType(/*width=*/16);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||
return builder_.getIntegerType(/*width=*/32);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||
return builder_.getIntegerType(/*width=*/64);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||
return builder_.getI1Type();
|
||||
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||
default:
|
||||
assert(false && "Unsupported data type encountered.");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,9 @@
|
|||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
#include "onnx/onnx_pb.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -31,3 +34,6 @@ AffineMap getIdentityDimMap(Builder &builder);
|
|||
// - s2: stride
|
||||
// - s3: dilation
|
||||
AffineMap getConvDimMap(Builder &builder, bool ceilMode);
|
||||
|
||||
mlir::Type convertONNXTypeToMLIRType(
|
||||
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType);
|
||||
|
|
|
@ -18,6 +18,8 @@ target_include_directories(OMKrnlToLLVM
|
|||
${ONNX_MLIR_SRC_ROOT}
|
||||
${ONNX_MLIR_BIN_ROOT}
|
||||
${ONNX_MLIR_SRC_ROOT})
|
||||
target_link_libraries(OMKrnlToLLVM
|
||||
onnx)
|
||||
|
||||
#Linking dependencies:
|
||||
add_dependencies(OMKrnlToLLVM
|
||||
|
|
|
@ -7,6 +7,8 @@ target_include_directories(OMAttributePromotion
|
|||
# Linking dependencies:
|
||||
add_dependencies(OMAttributePromotion
|
||||
OMPromotableConstOperandsOpInterface)
|
||||
target_link_libraries(OMAttributePromotion
|
||||
onnx)
|
||||
|
||||
add_library(OMElideConstants
|
||||
ElideConstants.cpp)
|
||||
|
@ -16,6 +18,8 @@ target_include_directories(OMElideConstants
|
|||
|
||||
add_dependencies(OMElideConstants
|
||||
OMONNXOps)
|
||||
target_link_libraries(OMElideConstants
|
||||
onnx)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||
|
@ -43,6 +47,8 @@ add_dependencies(OMONNXRewrite
|
|||
# Linking dependencies:
|
||||
add_dependencies(OMONNXRewrite
|
||||
OMONNXOps)
|
||||
target_link_libraries(OMONNXRewrite
|
||||
onnx)
|
||||
|
||||
add_library(OMShapeInference ShapeInferencePass.cpp)
|
||||
target_include_directories(OMShapeInference
|
||||
|
|
|
@ -279,7 +279,12 @@ OpsWithResultTypeInference = {
|
|||
resultTypes.push_back(attr.getType());
|
||||
} else if (auto attr = sparse_valueAttr()) {
|
||||
resultTypes.push_back(attr.getType());
|
||||
}'''
|
||||
}''',
|
||||
"Cast":
|
||||
'''auto toAttr = to().getSExtValue();
|
||||
auto builder = mlir::OpBuilder(getContext());
|
||||
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
|
||||
}
|
||||
|
||||
# Add an Op in this list if the Op needs result type deduction which is required
|
||||
|
|
Loading…
Reference in New Issue