Add type inference for CastOp (#156)
* Add type inference for CastOp * Share type translation between op builder and onnx importer * clang-format * Format emitted code * Remove unnecessary dependencies
This commit is contained in:
parent
2a1fe9e1e7
commit
e2e1fbfd3b
|
@ -54,42 +54,6 @@ private:
|
||||||
|
|
||||||
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
mlir::Location UnknownLoc() { return mlir::UnknownLoc::get(&context_); }
|
||||||
|
|
||||||
// Convert type to MLIR type.
|
|
||||||
// A complete list of types can be found in:
|
|
||||||
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
|
||||||
mlir::Type convertONNXTypeToMLIRType(onnx::TensorProto_DataType onnxType) {
|
|
||||||
switch (onnxType) {
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
|
||||||
return builder_.getF16Type();
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
|
||||||
return builder_.getF32Type();
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
|
||||||
return builder_.getF64Type();
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
|
||||||
return builder_.getIntegerType(/*width=*/8);
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
|
||||||
return builder_.getIntegerType(/*width=*/16);
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
|
||||||
return builder_.getIntegerType(/*width=*/32);
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
|
||||||
return builder_.getIntegerType(/*width=*/64);
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
|
||||||
return builder_.getI1Type();
|
|
||||||
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
|
||||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
|
||||||
default:
|
|
||||||
assert(false && "Unsupported data type encountered.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
* Import an onnx input tensor type by determining and recording its type
|
* Import an onnx input tensor type by determining and recording its type
|
||||||
* in a list of input tensor mlir types.
|
* in a list of input tensor mlir types.
|
||||||
|
@ -119,7 +83,8 @@ private:
|
||||||
|
|
||||||
auto elementOnnxType =
|
auto elementOnnxType =
|
||||||
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
|
(onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
|
||||||
mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
|
mlir::Type elementType =
|
||||||
|
convertONNXTypeToMLIRType(builder_, elementOnnxType);
|
||||||
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
|
||||||
return mlir::RankedTensorType::get(tensor_dims, elementType);
|
return mlir::RankedTensorType::get(tensor_dims, elementType);
|
||||||
}
|
}
|
||||||
|
|
|
@ -39,6 +39,7 @@ endif()
|
||||||
# targets, or only use it for libraries that have no further dependencies
|
# targets, or only use it for libraries that have no further dependencies
|
||||||
# (except system libraries such as libc).
|
# (except system libraries such as libc).
|
||||||
target_link_libraries(onnx-mlir
|
target_link_libraries(onnx-mlir
|
||||||
|
onnx
|
||||||
OMBuilder
|
OMBuilder
|
||||||
OMKrnlOps
|
OMKrnlOps
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
|
|
|
@ -21,6 +21,8 @@ add_library(OMONNXToKrnl
|
||||||
Tensor/Constant.cpp
|
Tensor/Constant.cpp
|
||||||
Tensor/Concat.cpp
|
Tensor/Concat.cpp
|
||||||
ConvertONNXToKrnl.cpp)
|
ConvertONNXToKrnl.cpp)
|
||||||
|
target_link_libraries(OMONNXToKrnl
|
||||||
|
onnx)
|
||||||
target_include_directories(OMONNXToKrnl
|
target_include_directories(OMONNXToKrnl
|
||||||
PRIVATE
|
PRIVATE
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
|
|
|
@ -14,6 +14,8 @@ target_include_directories(OMONNXOps
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
${ONNX_MLIR_BIN_ROOT}
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
target_link_libraries(OMONNXOps
|
||||||
|
onnx)
|
||||||
add_dependencies(OMONNXOps OMONNXOpsIncGen)
|
add_dependencies(OMONNXOps OMONNXOpsIncGen)
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMONNXOps
|
add_dependencies(OMONNXOps
|
||||||
|
|
|
@ -21,7 +21,6 @@
|
||||||
#include "llvm/ADT/SmallBitVector.h"
|
#include "llvm/ADT/SmallBitVector.h"
|
||||||
|
|
||||||
#include "ONNXOps.hpp"
|
#include "ONNXOps.hpp"
|
||||||
#include "ONNXOpsHelper.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::OpTrait::util;
|
using namespace mlir::OpTrait::util;
|
||||||
|
|
|
@ -23,6 +23,8 @@
|
||||||
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
#include "src/Interface/ResultTypeInferenceOpInterface.hpp"
|
||||||
#include "src/Interface/ShapeInferenceInterface.hpp"
|
#include "src/Interface/ShapeInferenceInterface.hpp"
|
||||||
|
|
||||||
|
#include "ONNXOpsHelper.hpp"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
class ONNXOpsDialect : public Dialect {
|
class ONNXOpsDialect : public Dialect {
|
||||||
|
|
|
@ -397,7 +397,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift",
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXCastOp:ONNX_Op<"Cast",
|
def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> {
|
||||||
let summary = "ONNX Cast operation";
|
let summary = "ONNX Cast operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The operator casts the elements of a given input tensor to a data type"
|
"The operator casts the elements of a given input tensor to a data type"
|
||||||
|
@ -433,6 +433,14 @@ def ONNXCastOp:ONNX_Op<"Cast",
|
||||||
static std::vector<int> getTypeMap() {
|
static std::vector<int> getTypeMap() {
|
||||||
return {-1};
|
return {-1};
|
||||||
}
|
}
|
||||||
|
std::vector<mlir::Type> resultTypeInference() {
|
||||||
|
std::vector<mlir::Type> resultTypes;
|
||||||
|
auto toAttr = to().getSExtValue();
|
||||||
|
auto builder = mlir::OpBuilder(getContext());
|
||||||
|
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||||
|
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));
|
||||||
|
return resultTypes;
|
||||||
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,3 +40,40 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) {
|
||||||
|
|
||||||
return AffineMap::get(1, 4, {dimExp});
|
return AffineMap::get(1, 4, {dimExp});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Convert type to MLIR type.
|
||||||
|
// A complete list of types can be found in:
|
||||||
|
// <onnx-mlir-build-folder>/third_party/onnx/onnx/onnx.pb.h
|
||||||
|
mlir::Type convertONNXTypeToMLIRType(
|
||||||
|
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) {
|
||||||
|
switch (onnxType) {
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT16:
|
||||||
|
return builder_.getF16Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_FLOAT:
|
||||||
|
return builder_.getF32Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||||
|
return builder_.getF64Type();
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||||
|
return builder_.getIntegerType(/*width=*/8);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||||
|
return builder_.getIntegerType(/*width=*/16);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||||
|
return builder_.getIntegerType(/*width=*/32);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||||
|
return builder_.getIntegerType(/*width=*/64);
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||||
|
return builder_.getI1Type();
|
||||||
|
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||||
|
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||||
|
default:
|
||||||
|
assert(false && "Unsupported data type encountered.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -11,6 +11,9 @@
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/AffineMap.h"
|
#include "mlir/IR/AffineMap.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
#include "onnx/onnx_pb.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
|
@ -31,3 +34,6 @@ AffineMap getIdentityDimMap(Builder &builder);
|
||||||
// - s2: stride
|
// - s2: stride
|
||||||
// - s3: dilation
|
// - s3: dilation
|
||||||
AffineMap getConvDimMap(Builder &builder, bool ceilMode);
|
AffineMap getConvDimMap(Builder &builder, bool ceilMode);
|
||||||
|
|
||||||
|
mlir::Type convertONNXTypeToMLIRType(
|
||||||
|
mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType);
|
||||||
|
|
|
@ -18,6 +18,8 @@ target_include_directories(OMKrnlToLLVM
|
||||||
${ONNX_MLIR_SRC_ROOT}
|
${ONNX_MLIR_SRC_ROOT}
|
||||||
${ONNX_MLIR_BIN_ROOT}
|
${ONNX_MLIR_BIN_ROOT}
|
||||||
${ONNX_MLIR_SRC_ROOT})
|
${ONNX_MLIR_SRC_ROOT})
|
||||||
|
target_link_libraries(OMKrnlToLLVM
|
||||||
|
onnx)
|
||||||
|
|
||||||
#Linking dependencies:
|
#Linking dependencies:
|
||||||
add_dependencies(OMKrnlToLLVM
|
add_dependencies(OMKrnlToLLVM
|
||||||
|
|
|
@ -7,6 +7,8 @@ target_include_directories(OMAttributePromotion
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMAttributePromotion
|
add_dependencies(OMAttributePromotion
|
||||||
OMPromotableConstOperandsOpInterface)
|
OMPromotableConstOperandsOpInterface)
|
||||||
|
target_link_libraries(OMAttributePromotion
|
||||||
|
onnx)
|
||||||
|
|
||||||
add_library(OMElideConstants
|
add_library(OMElideConstants
|
||||||
ElideConstants.cpp)
|
ElideConstants.cpp)
|
||||||
|
@ -16,6 +18,8 @@ target_include_directories(OMElideConstants
|
||||||
|
|
||||||
add_dependencies(OMElideConstants
|
add_dependencies(OMElideConstants
|
||||||
OMONNXOps)
|
OMONNXOps)
|
||||||
|
target_link_libraries(OMElideConstants
|
||||||
|
onnx)
|
||||||
|
|
||||||
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
set(LLVM_TARGET_DEFINITIONS ONNXRewrite.td)
|
||||||
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
onnx_mlir_tablegen(ONNXRewrite.inc -gen-rewriters)
|
||||||
|
@ -43,6 +47,8 @@ add_dependencies(OMONNXRewrite
|
||||||
# Linking dependencies:
|
# Linking dependencies:
|
||||||
add_dependencies(OMONNXRewrite
|
add_dependencies(OMONNXRewrite
|
||||||
OMONNXOps)
|
OMONNXOps)
|
||||||
|
target_link_libraries(OMONNXRewrite
|
||||||
|
onnx)
|
||||||
|
|
||||||
add_library(OMShapeInference ShapeInferencePass.cpp)
|
add_library(OMShapeInference ShapeInferencePass.cpp)
|
||||||
target_include_directories(OMShapeInference
|
target_include_directories(OMShapeInference
|
||||||
|
|
|
@ -279,7 +279,12 @@ OpsWithResultTypeInference = {
|
||||||
resultTypes.push_back(attr.getType());
|
resultTypes.push_back(attr.getType());
|
||||||
} else if (auto attr = sparse_valueAttr()) {
|
} else if (auto attr = sparse_valueAttr()) {
|
||||||
resultTypes.push_back(attr.getType());
|
resultTypes.push_back(attr.getType());
|
||||||
}'''
|
}''',
|
||||||
|
"Cast":
|
||||||
|
'''auto toAttr = to().getSExtValue();
|
||||||
|
auto builder = mlir::OpBuilder(getContext());
|
||||||
|
resultTypes.push_back(mlir::UnrankedTensorType::get(
|
||||||
|
convertONNXTypeToMLIRType(builder, static_cast<onnx::TensorProto_DataType>(toAttr))));'''
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add an Op in this list if the Op needs result type deduction which is required
|
# Add an Op in this list if the Op needs result type deduction which is required
|
||||||
|
|
Loading…
Reference in New Issue