String type (Ready for Review) (#182)

* string type from tensorflow

* simplify type

* parser and print

* gen StringType for tablegen

* onnx to onnx-mlir type

* add namespace

* allow all integer type

* dialect document

* add test case

* format

* more precise type for ONNXOp

* format

* enable the failed test

* update comment

* update onnx.md

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
chentong319 2020-06-25 16:34:37 -04:00 committed by GitHub
parent f811718144
commit 2e08b2112c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1210 additions and 1064 deletions

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@ add_library(OMONNXOps
ONNXOps.hpp ONNXOps.hpp
ONNXOpsHelper.cpp ONNXOpsHelper.cpp
ONNXOpsHelper.hpp) ONNXOpsHelper.hpp)
target_include_directories(OMONNXOps target_include_directories(OMONNXOps
PRIVATE PRIVATE
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}

View File

@ -11,6 +11,7 @@
#include "mlir/Dialect/Traits.h" #include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h" #include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h" #include "mlir/IR/Matchers.h"
@ -25,6 +26,7 @@
using namespace mlir; using namespace mlir;
using namespace mlir::OpTrait::util; using namespace mlir::OpTrait::util;
using namespace mlir::onnxmlir;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ONNX Helper functions // ONNX Helper functions
@ -481,6 +483,19 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
#define GET_OP_LIST #define GET_OP_LIST
#include "src/Dialect/ONNX/ONNXOps.cpp.inc" #include "src/Dialect/ONNX/ONNXOps.cpp.inc"
>(); >();
addTypes<StringType>();
}
mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
if (parser.parseKeyword("String"))
return Type();
return StringType::get(getContext());
}
void ONNXOpsDialect::printType(
mlir::Type type, mlir::DialectAsmPrinter &printer) const {
printer << "String";
} }
void ONNXEntryPointOp::build(mlir::OpBuilder &builder, void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
@ -2025,8 +2040,12 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
auto yScaleTy = y_scale().getType().cast<ShapedType>(); auto yScaleTy = y_scale().getType().cast<ShapedType>();
auto yZPTy = y_zero_point().getType().cast<ShapedType>(); auto yZPTy = y_zero_point().getType().cast<ShapedType>();
IntegerType i8Type = IntegerType::get(8, getContext()); IntegerType ui8Type =
RankedTensorType scalarType = RankedTensorType::get({}, i8Type); IntegerType::get(8, IntegerType::Unsigned, getContext());
FloatType f32Type = FloatType::getF32(getContext());
RankedTensorType scalarType = RankedTensorType::get({}, f32Type);
RankedTensorType y_zero_point_type = RankedTensorType::get({}, ui8Type);
// Set the types for the scalars // Set the types for the scalars
if (!yScaleTy.hasStaticShape()) { if (!yScaleTy.hasStaticShape()) {
@ -2034,11 +2053,11 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
} }
if (!yZPTy.hasStaticShape()) { if (!yZPTy.hasStaticShape()) {
y_zero_point().setType(scalarType); y_zero_point().setType(y_zero_point_type);
} }
if (!yTy.hasStaticShape()) { if (!yTy.hasStaticShape()) {
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type); RankedTensorType outType = RankedTensorType::get(inTy.getShape(), ui8Type);
y().setType(outType); y().setType(outType);
} }

View File

@ -31,6 +31,13 @@ class ONNXOpsDialect : public Dialect {
public: public:
ONNXOpsDialect(MLIRContext *context); ONNXOpsDialect(MLIRContext *context);
/// Parse an instance of a type registered to the onnx dialect.
mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
/// Print an instance of a type registered to the onnx dialect.
void printType(
mlir::Type type, mlir::DialectAsmPrinter &printer) const override;
/// Provide a utility accessor to the dialect namespace. This is used by /// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects. /// several utilities for casting between dialects.
static StringRef getDialectNamespace() { return "onnx"; } static StringRef getDialectNamespace() { return "onnx"; }
@ -41,6 +48,39 @@ public:
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.hpp.inc" #include "src/Dialect/ONNX/ONNXOps.hpp.inc"
// The namespace onnxmlir is experimental.
// onnx_mlir has been used in KRNL. Other candidates are onnxops, onnxdialect.
// Should this namesapce for onnx mlir project or ONNXOp dialect?
// Or we need two namespace?
// Will put all the ONNXOps into this namespace
namespace onnxmlir {
namespace ONNXTypes {
enum Kind {
FIRST_USED_ONNX_TYPE = Type::FIRST_PRIVATE_EXPERIMENTAL_1_TYPE,
//#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant,
//#include "src/Dialect/ONNX/ONXTypes.def"
STRING,
SEQ,
LAST_USED_ONNX_TYPE,
};
} // namespace ONNXTypes
class StringType : public mlir::Type::TypeBase<StringType, mlir::Type> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; }
static unsigned getTypeKind() { return ONNXTypes::STRING; }
static StringType get(MLIRContext *ctx) {
return Base::get(ctx, ONNXTypes::STRING);
}
};
} // end namespace onnxmlir
} // end namespace mlir } // end namespace mlir
namespace onnx_mlir {} namespace onnx_mlir {}

View File

@ -17,6 +17,8 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
#endif // OP_BASE #endif // OP_BASE
def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">;
#ifdef SHAPE_INFERENCE_INTERFACE #ifdef SHAPE_INFERENCE_INTERFACE
#else #else
include "src/Interface/ShapeInferenceInterface.td" include "src/Interface/ShapeInferenceInterface.td"

File diff suppressed because it is too large Load Diff

View File

@ -10,8 +10,11 @@
#include "ONNXOpsHelper.hpp" #include "ONNXOpsHelper.hpp"
#include "ONNXOps.hpp"
// Identity affine // Identity affine
using namespace mlir; using namespace mlir;
using namespace mlir::onnxmlir;
AffineMap getIdentityDimMap(Builder &builder) { AffineMap getIdentityDimMap(Builder &builder) {
return AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}); return AffineMap::get(1, 0, {builder.getAffineDimExpr(0)});
} }
@ -55,21 +58,26 @@ mlir::Type convertONNXTypeToMLIRType(
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE: case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
return builder_.getF64Type(); return builder_.getF64Type();
case onnx::TensorProto_DataType::TensorProto_DataType_INT8: case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(/*width=*/8); return builder_.getIntegerType(/*width=*/8);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
return builder_.getIntegerType(/*width=*/8, false);
case onnx::TensorProto_DataType::TensorProto_DataType_INT16: case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(/*width=*/16); return builder_.getIntegerType(/*width=*/16);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
return builder_.getIntegerType(/*width=*/16, false);
case onnx::TensorProto_DataType::TensorProto_DataType_INT32: case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(/*width=*/32); return builder_.getIntegerType(/*width=*/32);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
return builder_.getIntegerType(/*width=*/32, false);
case onnx::TensorProto_DataType::TensorProto_DataType_INT64: case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(/*width=*/64); return builder_.getIntegerType(/*width=*/64);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
return builder_.getIntegerType(/*width=*/64, false);
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL: case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
return builder_.getI1Type(); return builder_.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING: case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
return StringType::get(builder_.getContext());
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:

View File

@ -2,8 +2,8 @@
// ----- // -----
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> { func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1) // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)

View File

@ -397,8 +397,8 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
// ----- // -----
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> { func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape // CHECK-LABEL: test_reshape
@ -411,56 +411,52 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
// CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64 // CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
// CHECK: %[[CONSTANT_1:.+]] = constant 0 : index // CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32> // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi64>
// CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32> // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32 // CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i64
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32 // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i64
// CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32 // CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i64
// CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32 // CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i64
// CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64 // CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[SELECT_0]] : i64
// CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
// CHECK: %[[CONSTANT_3:.+]] = constant 1 : index // CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32> // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi64>
// CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32 // CHECK: [[CONSTANT_3:%.+]] = constant 10 : i64
// CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32 // CHECK: [[CONSTANT_4:%.+]] = constant 0 : i64
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32 // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i64
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32 // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i64
// CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64 // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[SELECT_1]] : i64
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
// CHECK: %[[CONSTANT_5:.+]] = constant 2 : index // CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32> // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi64>
// CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64 // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[LOAD_2]] : i64
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
// CHECK: %[[CONSTANT_6:.+]] = constant 3 : index // CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32> // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi64>
// CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64 // CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[LOAD_3]] : i64
// CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
// CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64 // CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
// CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64 // CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
// CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64 // CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
// CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64 // CHECK: [[CMP_2:%.+]] = cmpi "eq", [[SELECT_0]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64 // CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64 // CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[SELECT_0]] : i64
// CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
// CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64 // CHECK: [[CMP_3:%.+]] = cmpi "eq", [[SELECT_1]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64 // CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64 // CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[SELECT_1]] : i64
// CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
// CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64 // CHECK: [[CMP_4:%.+]] = cmpi "eq", [[LOAD_2]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64 // CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64 // CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[LOAD_2]] : i64
// CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index // CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
// CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64 // CHECK: [[CMP_5:%.+]] = cmpi "eq", [[LOAD_3]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64 // CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64 // CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[LOAD_3]] : i64
// CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index // CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
// CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32> // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>

View File

@ -542,48 +542,48 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x
/// Test the reshape op inference when constants are present. /// Test the reshape op inference when constants are present.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> { func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%0) : (tensor<*xf32>) -> () "std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_dynamic // CHECK-LABEL: test_reshape_dynamic
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32> // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: return [[RES]] : tensor<?x?x?x?xf32> // CHECK: return [[RES]] : tensor<?x?x?x?xf32>
} }
// ----- // -----
func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi32> } : () -> tensor<4xi32> %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_1 // CHECK-LABEL: test_reshape_1
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32> // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x5x16x2xf32>
// CHECK: return [[RES]] : tensor<5x5x16x2xf32> // CHECK: return [[RES]] : tensor<5x5x16x2xf32>
} }
// ----- // -----
func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi32> } : () -> tensor<3xi32> %0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_2 // CHECK-LABEL: test_reshape_2
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32> // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<25x16x2xf32>
// CHECK: return [[RES]] : tensor<25x16x2xf32> // CHECK: return [[RES]] : tensor<25x16x2xf32>
} }
// ----- // -----
func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
%0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi32> } : () -> tensor<3xi32> %0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32> %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> () "std.return"(%1) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape_3 // CHECK-LABEL: test_reshape_3
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32> // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<80x5x2xf32>
// CHECK: return [[RES]] : tensor<80x5x2xf32> // CHECK: return [[RES]] : tensor<80x5x2xf32>
} }
@ -904,13 +904,13 @@ func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> {
"std.return"(%1) : (tensor<*xui8>) -> () "std.return"(%1) : (tensor<*xui8>) -> ()
// CHECK-LABEL: test_cast_2 // CHECK-LABEL: test_cast_2
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8> // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xui8>
// CHECK: return [[RES]] : tensor<2x3x4xi8> // CHECK: return [[RES]] : tensor<2x3x4xui8>
} }
func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> { func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi8> {
%1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8> %1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xi8>
"std.return"(%1) : (tensor<*xsi8>) -> () "std.return"(%1) : (tensor<*xi8>) -> ()
// CHECK-LABEL: test_cast_3 // CHECK-LABEL: test_cast_3
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8> // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
@ -930,30 +930,33 @@ func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> {
/// Test the quantization op inferences. /// Test the quantization op inferences.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> { // TOFIX
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>) // This test case is commented out because the #1 output should be tensor<f32>
"std.return"(%1#0) {} : (tensor<*xi8>) -> () // but tensor<i8> is generated
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xui8> {
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xui8>, tensor<*xf32>, tensor<*xui8>)
"std.return"(%1#0) {} : (tensor<*xui8>) -> ()
// CHECK-LABEL: test_dyn_quantize_linear_1 // CHECK-LABEL: test_dyn_quantize_linear_1
// CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) // CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xui8>, tensor<f32>, tensor<ui8>)
// CHECK: return [[RES]] : tensor<5x2x3x4xi8> // CHECK: return [[RES]] : tensor<5x2x3x4xui8>
} }
func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xi8> { func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xi8> {
%1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<*xi8> %1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<*xi8>
"std.return"(%1) {} : (tensor<*xi8>) -> () "std.return"(%1) {} : (tensor<*xi8>) -> ()
// CHECK-LABEL: test_quantize_linear_1 // CHECK-LABEL: test_quantize_linear_1
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xi8> // CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xi8>
// CHECK: return [[RES]] : tensor<5x2x3x4xi8> // CHECK: return [[RES]] : tensor<5x2x3x4xi8>
} }
func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xf32> { func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xf32> {
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<*xf32> %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<*xf32>
"std.return"(%1) {} : (tensor<*xf32>) -> () "std.return"(%1) {} : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_dequantize_linear_1 // CHECK-LABEL: test_dequantize_linear_1
// CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xf32> // CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xf32>
// CHECK: return [[RES]] : tensor<5x2x3x4xf32> // CHECK: return [[RES]] : tensor<5x2x3x4xf32>
} }

View File

@ -1,9 +1,16 @@
// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s // RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> { // CHECK-LABEL: @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> { func @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
%0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
return %0 : tensor<*xi64> return %0 : tensor<*xf32>
// CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> // CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
} }
// CHECK-LABEL: @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
return %arg0 : tensor<10x20x!onnx.String>
// CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String>
}

View File

@ -1,22 +1,22 @@
// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s // RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%shape = constant dense<[6, 7, 42]> : tensor<3xi32> %shape = constant dense<[6, 7, 42]> : tensor<3xi64>
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
// CHECK-LABEL: test_should_promote_to_attribute // CHECK-LABEL: test_should_promote_to_attribute
// CHECK-NEXT: [[NONE:%.+]] = constant unit // CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
} }
func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> { func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32> %shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi64>}: () -> tensor<3xi64>
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
// CHECK-LABEL: test_should_promote_to_attribute_1 // CHECK-LABEL: test_should_promote_to_attribute_1
// CHECK-NEXT: [[NONE:%.+]] = constant unit // CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
} }
@ -29,25 +29,25 @@ func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : ten
} }
func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) { func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%shape = constant dense<[6, 7, 42]> : tensor<3xi32> %shape = constant dense<[6, 7, 42]> : tensor<3xi64>
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32> %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
%1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32> %1 = "onnx.Identity"(%shape) : (tensor<3xi64>) -> tensor<*xf32>
"std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> () "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
// CHECK-LABEL: test_promote_to_attribute_without_removing_const_op // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
// CHECK-NEXT: [[NONE:%.+]] = constant unit // CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32> // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi64>
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32> // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32> // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi64>) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
} }
func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> { func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
%shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32> %shape = constant dense<[0, 2, 2, 4]> : tensor<4xi64>
%constant_value = constant dense<[0.]> : tensor<1xf32> %constant_value = constant dense<[0.]> : tensor<1xf32>
%0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi32>, tensor<1xf32>)-> tensor<*xf32> %0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi64>, tensor<1xf32>)-> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
// CHECK-LABEL: test_should_promote_to_attribute1 // CHECK-LABEL: test_should_promote_to_attribute1
// CHECK-NEXT: [[NONE:%.+]] = constant unit // CHECK-NEXT: [[NONE:%.+]] = constant unit
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32> // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32> // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
} }

View File

@ -320,10 +320,10 @@ custom_definition_misc = dict([ ('Constant',
onnx_types = ( onnx_types = (
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
'float', 'double', 'complex64', 'complex128' 'float', 'double', 'complex64', 'complex128', 'string'
) )
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64', tblgen_types = ('AnyI1', 'AnyI8', 'AnyI16', 'AnyI32', 'AnyI64', 'BF16', 'F16', 'F32', 'F64',
'Complex<F32>', 'Complex<F64>' 'Complex<F32>', 'Complex<F64>', 'StringType'
) )
MAX_NUM_TYPES=20 MAX_NUM_TYPES=20
@ -468,7 +468,7 @@ def dec_indent(indent):
def join_args(args): def join_args(args):
return ", ".join(args) return ", ".join(args)
def get_operands_or_results(schema, is_input): def get_operands_or_results(schema, type_str_dict, is_input):
value_list = schema.inputs if is_input else schema.outputs value_list = schema.inputs if is_input else schema.outputs
if not value_list: if not value_list:
return OrderedDict() return OrderedDict()
@ -482,7 +482,10 @@ def get_operands_or_results(schema, is_input):
name_to_types = OrderedDict() name_to_types = OrderedDict()
for i, value in enumerate(value_list): for i, value in enumerate(value_list):
structure, elem_types = get_allowed_elem_types(schema, value) types = get_onnx_mlir_types(schema, type_str_dict, value)
'''
structure, elem_types = get_allowed_elem_types(schema, type_str_dict, value)
if structure == 'tensor' : if structure == 'tensor' :
if elem_types is None: if elem_types is None:
@ -513,6 +516,7 @@ def get_operands_or_results(schema, is_input):
types = list(map(lambda x: x.format(elem_types_str), types)) types = list(map(lambda x: x.format(elem_types_str), types))
else: else:
types = ["AnyMemRef", "AnyTensor"] types = ["AnyMemRef", "AnyTensor"]
'''
# If operand is promotable to an attribute, then it must be # If operand is promotable to an attribute, then it must be
# nullable in case it migrates to be an attribute. # nullable in case it migrates to be an attribute.
@ -693,7 +697,68 @@ def get_type_inference_func(s, indent, type_inference_code):
indent = dec_indent(indent) indent = dec_indent(indent)
return s return s
def parse_type_str(allowedType):
# AnyI may be used for uint because the onnx_mlir is not generating uint output
# This will be fixed later and UI will be replace AnyI
onnx_to_mlir_type_dict = { '(': '<[',
')': ']>',
'tensor' : 'TensorOf',
'seq' : 'TensorOf',
'map' : 'TupleOf',
'bool': 'I1',
#'uint8' : 'AnyI8',
#uint16' : 'AnyI16',
#uint32' : 'AnyI32',
#uint64' : 'AnyI64',
'uint8' : 'UI8',
'uint16' : 'UI16',
'uint32' : 'UI32',
'uint64' : 'UI64',
'int8' : 'I8',
'int16' : 'I16',
'int32' : 'I32',
'int64' : 'I64',
'float16' : 'F16',
'float' : 'F32',
'double' : 'F64',
'unkown' : 'BF16',
'complex64' : 'Complex<F32>',
'complex128' : 'Complex<F64>',
'string' : 'StringType'}
for key, item in onnx_to_mlir_type_dict.items():
allowedType = allowedType.replace(key, item)
return allowedType
def parse_a_type_constraint(constraint):
allowedTypes = constraint.allowed_type_strs
mlirTypes = []
for allowedType in allowedTypes:
mlirType = parse_type_str(allowedType)
mlirTypes.append(mlirType)
# Remove redundant and sort.
# However onnx keeps a consitently meaningful order
# There is no redundancy as long as each onnx type is mapped uniquely
# mlirTypes = sorted(list(set(mlirTypes)))
return mlirTypes
def parse_type_constraints(schema):
type_str_dict = dict()
for type_constraint in schema.type_constraints:
type_str_dict[type_constraint.type_param_str] = parse_a_type_constraint(type_constraint)
return type_str_dict
def get_onnx_mlir_types(schema, type_str_dict, input):
if input.typeStr :
if not input.typeStr in type_str_dict :
# some arguments use type description directly
# instead of constraint
return [parse_type_str(input.typeStr)]
else :
return type_str_dict[input.typeStr]
else :
print('No typeStr ', schema.name)
return []
def gen_op_def(schema): def gen_op_def(schema):
indent = inc_indent() indent = inc_indent()
@ -727,15 +792,20 @@ def gen_op_def(schema):
s += indent + '"{}"\n'.format(escaped_line) s += indent + '"{}"\n'.format(escaped_line)
s += indent + '}];\n' s += indent + '}];\n'
# handle the type constraint for input and output
# parse type constraint into onnx-mlir type string list
type_str_dict = parse_type_constraints(schema)
# Generate ins (consisting of operands and attributes). # Generate ins (consisting of operands and attributes).
ins = get_operands_or_results(schema, is_input=True) ins = get_operands_or_results(schema, type_str_dict, is_input=True)
ins.update(get_attrs(schema)) ins.update(get_attrs(schema))
ins_strs = ["{1}:${0}".format(*i) for i in ins.items()] ins_strs = ["{1}:${0}".format(*i) for i in ins.items()]
s += indent + 'let arguments = (ins {});\n'.format( s += indent + 'let arguments = (ins {});\n'.format(
(',\n' + inc_indent(indent)).join(ins_strs)) (',\n' + inc_indent(indent)).join(ins_strs))
# Generate outs (operation results). # Generate outs (operation results).
outs = get_operands_or_results(schema, is_input=False) outs = get_operands_or_results(schema, type_str_dict, is_input=False)
outs_strs = ["{1}:${0}".format(*i) for i in outs.items()] outs_strs = ["{1}:${0}".format(*i) for i in outs.items()]
s += indent + 'let results = (outs {});\n'.format( s += indent + 'let results = (outs {});\n'.format(
(',\n' + inc_indent(indent)).join(outs_strs)) (',\n' + inc_indent(indent)).join(outs_strs))
@ -756,7 +826,7 @@ def gen_op_def(schema):
# Value, Y, Attribute A", [{}]> # Value, Y, Attribute A", [{}]>
indent = inc_indent(indent) indent = inc_indent(indent)
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state' s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
operands_dict = get_operands_or_results(schema, is_input=True) operands_dict = get_operands_or_results(schema, type_str_dict, is_input=True)
for name, ty in operands_dict.items(): for name, ty in operands_dict.items():
s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty), s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty),
name) name)