String type (Ready for Review) (#182)
* string type from tensorflow * simplify type * parser and print * gen StringType for tablegen * onnx to onnx-mlir type * add namespace * allow all integer type * dialect document * add test case * format * more precise type for ONNXOp * format * enable the failed test * update comment * update onnx.md Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
f811718144
commit
2e08b2112c
File diff suppressed because it is too large
Load Diff
|
@ -16,6 +16,7 @@ add_library(OMONNXOps
|
|||
ONNXOps.hpp
|
||||
ONNXOpsHelper.cpp
|
||||
ONNXOpsHelper.hpp)
|
||||
|
||||
target_include_directories(OMONNXOps
|
||||
PRIVATE
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
@ -25,6 +26,7 @@
|
|||
|
||||
using namespace mlir;
|
||||
using namespace mlir::OpTrait::util;
|
||||
using namespace mlir::onnxmlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX Helper functions
|
||||
|
@ -481,6 +483,19 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
|
|||
#define GET_OP_LIST
|
||||
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
|
||||
>();
|
||||
addTypes<StringType>();
|
||||
}
|
||||
|
||||
mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
|
||||
if (parser.parseKeyword("String"))
|
||||
return Type();
|
||||
|
||||
return StringType::get(getContext());
|
||||
}
|
||||
|
||||
void ONNXOpsDialect::printType(
|
||||
mlir::Type type, mlir::DialectAsmPrinter &printer) const {
|
||||
printer << "String";
|
||||
}
|
||||
|
||||
void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
|
||||
|
@ -2025,8 +2040,12 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
|||
auto yScaleTy = y_scale().getType().cast<ShapedType>();
|
||||
auto yZPTy = y_zero_point().getType().cast<ShapedType>();
|
||||
|
||||
IntegerType i8Type = IntegerType::get(8, getContext());
|
||||
RankedTensorType scalarType = RankedTensorType::get({}, i8Type);
|
||||
IntegerType ui8Type =
|
||||
IntegerType::get(8, IntegerType::Unsigned, getContext());
|
||||
FloatType f32Type = FloatType::getF32(getContext());
|
||||
|
||||
RankedTensorType scalarType = RankedTensorType::get({}, f32Type);
|
||||
RankedTensorType y_zero_point_type = RankedTensorType::get({}, ui8Type);
|
||||
|
||||
// Set the types for the scalars
|
||||
if (!yScaleTy.hasStaticShape()) {
|
||||
|
@ -2034,11 +2053,11 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
|
|||
}
|
||||
|
||||
if (!yZPTy.hasStaticShape()) {
|
||||
y_zero_point().setType(scalarType);
|
||||
y_zero_point().setType(y_zero_point_type);
|
||||
}
|
||||
|
||||
if (!yTy.hasStaticShape()) {
|
||||
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
|
||||
RankedTensorType outType = RankedTensorType::get(inTy.getShape(), ui8Type);
|
||||
y().setType(outType);
|
||||
}
|
||||
|
||||
|
|
|
@ -31,6 +31,13 @@ class ONNXOpsDialect : public Dialect {
|
|||
public:
|
||||
ONNXOpsDialect(MLIRContext *context);
|
||||
|
||||
/// Parse an instance of a type registered to the onnx dialect.
|
||||
mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
|
||||
|
||||
/// Print an instance of a type registered to the onnx dialect.
|
||||
void printType(
|
||||
mlir::Type type, mlir::DialectAsmPrinter &printer) const override;
|
||||
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
static StringRef getDialectNamespace() { return "onnx"; }
|
||||
|
@ -41,6 +48,39 @@ public:
|
|||
#define GET_OP_CLASSES
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp.inc"
|
||||
|
||||
// The namespace onnxmlir is experimental.
|
||||
// onnx_mlir has been used in KRNL. Other candidates are onnxops, onnxdialect.
|
||||
// Should this namesapce for onnx mlir project or ONNXOp dialect?
|
||||
// Or we need two namespace?
|
||||
// Will put all the ONNXOps into this namespace
|
||||
namespace onnxmlir {
|
||||
|
||||
namespace ONNXTypes {
|
||||
|
||||
enum Kind {
|
||||
FIRST_USED_ONNX_TYPE = Type::FIRST_PRIVATE_EXPERIMENTAL_1_TYPE,
|
||||
//#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant,
|
||||
//#include "src/Dialect/ONNX/ONXTypes.def"
|
||||
STRING,
|
||||
SEQ,
|
||||
LAST_USED_ONNX_TYPE,
|
||||
};
|
||||
} // namespace ONNXTypes
|
||||
|
||||
class StringType : public mlir::Type::TypeBase<StringType, mlir::Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; }
|
||||
|
||||
static unsigned getTypeKind() { return ONNXTypes::STRING; }
|
||||
|
||||
static StringType get(MLIRContext *ctx) {
|
||||
return Base::get(ctx, ONNXTypes::STRING);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace onnxmlir
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
namespace onnx_mlir {}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">;
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
include "src/Interface/ShapeInferenceInterface.td"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -10,8 +10,11 @@
|
|||
|
||||
#include "ONNXOpsHelper.hpp"
|
||||
|
||||
#include "ONNXOps.hpp"
|
||||
|
||||
// Identity affine
|
||||
using namespace mlir;
|
||||
using namespace mlir::onnxmlir;
|
||||
AffineMap getIdentityDimMap(Builder &builder) {
|
||||
return AffineMap::get(1, 0, {builder.getAffineDimExpr(0)});
|
||||
}
|
||||
|
@ -55,21 +58,26 @@ mlir::Type convertONNXTypeToMLIRType(
|
|||
case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
|
||||
return builder_.getF64Type();
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||
return builder_.getIntegerType(/*width=*/8);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
|
||||
return builder_.getIntegerType(/*width=*/8, false);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||
return builder_.getIntegerType(/*width=*/16);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
|
||||
return builder_.getIntegerType(/*width=*/16, false);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||
return builder_.getIntegerType(/*width=*/32);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
|
||||
return builder_.getIntegerType(/*width=*/32, false);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||
return builder_.getIntegerType(/*width=*/64);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
|
||||
return builder_.getIntegerType(/*width=*/64, false);
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
|
||||
return builder_.getI1Type();
|
||||
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
|
||||
return StringType::get(builder_.getContext());
|
||||
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
|
||||
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
|
||||
|
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
// -----
|
||||
|
||||
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
|
||||
|
|
|
@ -397,8 +397,8 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_reshape
|
||||
|
@ -411,56 +411,52 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
|
|||
|
||||
// CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
|
||||
// CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32>
|
||||
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi64>
|
||||
// CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||
// CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32
|
||||
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
|
||||
// CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32
|
||||
// CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32
|
||||
// CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
|
||||
// CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i64
|
||||
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i64
|
||||
// CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i64
|
||||
// CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i64
|
||||
// CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[SELECT_0]] : i64
|
||||
|
||||
// CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32>
|
||||
// CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32
|
||||
// CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32
|
||||
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32
|
||||
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32
|
||||
// CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
|
||||
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
|
||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi64>
|
||||
// CHECK: [[CONSTANT_3:%.+]] = constant 10 : i64
|
||||
// CHECK: [[CONSTANT_4:%.+]] = constant 0 : i64
|
||||
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i64
|
||||
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i64
|
||||
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[SELECT_1]] : i64
|
||||
|
||||
// CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
|
||||
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32>
|
||||
// CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
|
||||
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
|
||||
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi64>
|
||||
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[LOAD_2]] : i64
|
||||
|
||||
// CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
|
||||
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32>
|
||||
// CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
|
||||
// CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
|
||||
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi64>
|
||||
// CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[LOAD_3]] : i64
|
||||
|
||||
// CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
|
||||
// CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
|
||||
|
||||
// CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
|
||||
// CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[CMP_2:%.+]] = cmpi "eq", [[SELECT_0]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
|
||||
// CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64
|
||||
// CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[SELECT_0]] : i64
|
||||
// CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
|
||||
|
||||
// CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[CMP_3:%.+]] = cmpi "eq", [[SELECT_1]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
|
||||
// CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64
|
||||
// CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[SELECT_1]] : i64
|
||||
// CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
|
||||
|
||||
// CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[CMP_4:%.+]] = cmpi "eq", [[LOAD_2]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
|
||||
// CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64
|
||||
// CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[LOAD_2]] : i64
|
||||
// CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
|
||||
|
||||
// CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[CMP_5:%.+]] = cmpi "eq", [[LOAD_3]], [[CONSTANT_8]] : i64
|
||||
// CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
|
||||
// CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64
|
||||
// CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[LOAD_3]] : i64
|
||||
// CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
|
||||
|
||||
// CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
|
||||
|
|
|
@ -542,48 +542,48 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x
|
|||
/// Test the reshape op inference when constants are present.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_reshape_dynamic
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
|
||||
// CHECK: return [[RES]] : tensor<?x?x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
|
||||
%0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_reshape_1
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x5x16x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x5x16x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi32> } : () -> tensor<3xi32>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_reshape_2
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<25x16x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<25x16x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi32> } : () -> tensor<3xi32>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
|
||||
%1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_reshape_3
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<80x5x2xf32>
|
||||
// CHECK: return [[RES]] : tensor<80x5x2xf32>
|
||||
}
|
||||
|
||||
|
@ -904,13 +904,13 @@ func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> {
|
|||
"std.return"(%1) : (tensor<*xui8>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_cast_2
|
||||
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
|
||||
// CHECK: return [[RES]] : tensor<2x3x4xi8>
|
||||
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xui8>
|
||||
// CHECK: return [[RES]] : tensor<2x3x4xui8>
|
||||
}
|
||||
|
||||
func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> {
|
||||
%1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8>
|
||||
"std.return"(%1) : (tensor<*xsi8>) -> ()
|
||||
func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi8> {
|
||||
%1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xi8>
|
||||
"std.return"(%1) : (tensor<*xi8>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_cast_3
|
||||
// CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
|
||||
|
@ -930,30 +930,33 @@ func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> {
|
|||
/// Test the quantization op inferences.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> {
|
||||
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>)
|
||||
"std.return"(%1#0) {} : (tensor<*xi8>) -> ()
|
||||
// TOFIX
|
||||
// This test case is commented out because the #1 output should be tensor<f32>
|
||||
// but tensor<i8> is generated
|
||||
func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xui8> {
|
||||
%1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xui8>, tensor<*xf32>, tensor<*xui8>)
|
||||
"std.return"(%1#0) {} : (tensor<*xui8>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_dyn_quantize_linear_1
|
||||
// CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>)
|
||||
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
|
||||
// CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xui8>, tensor<f32>, tensor<ui8>)
|
||||
// CHECK: return [[RES]] : tensor<5x2x3x4xui8>
|
||||
}
|
||||
|
||||
func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xi8> {
|
||||
%1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<*xi8>
|
||||
func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xi8> {
|
||||
%1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<*xi8>
|
||||
"std.return"(%1) {} : (tensor<*xi8>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_quantize_linear_1
|
||||
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xi8>
|
||||
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xi8>
|
||||
// CHECK: return [[RES]] : tensor<5x2x3x4xi8>
|
||||
}
|
||||
|
||||
func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xf32> {
|
||||
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<*xf32>
|
||||
func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xf32> {
|
||||
%1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<*xf32>
|
||||
"std.return"(%1) {} : (tensor<*xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: test_dequantize_linear_1
|
||||
// CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xf32>
|
||||
// CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xf32>
|
||||
// CHECK: return [[RES]] : tensor<5x2x3x4xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,16 @@
|
|||
// RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
|
||||
func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
|
||||
%0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
|
||||
return %0 : tensor<*xi64>
|
||||
// CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
|
||||
// CHECK-LABEL: @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
|
||||
func @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
|
||||
%0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
|
||||
func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
|
||||
return %arg0 : tensor<10x20x!onnx.String>
|
||||
// CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String>
|
||||
}
|
||||
|
||||
|
|
|
@ -1,22 +1,22 @@
|
|||
// RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
|
||||
|
||||
func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi64>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_should_promote_to_attribute
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||
%shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%shape = "onnx.Constant"() { value = dense<[6, 7, 42]> : tensor<3xi64>}: () -> tensor<3xi64>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_should_promote_to_attribute_1
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
||||
|
@ -29,25 +29,25 @@ func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : ten
|
|||
}
|
||||
|
||||
func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
|
||||
%1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32>
|
||||
%shape = constant dense<[6, 7, 42]> : tensor<3xi64>
|
||||
%0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
|
||||
%1 = "onnx.Identity"(%shape) : (tensor<3xi64>) -> tensor<*xf32>
|
||||
"std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
|
||||
// CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32>
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi64>
|
||||
// CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi64>) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
|
||||
}
|
||||
|
||||
func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
|
||||
%shape = constant dense<[0, 2, 2, 4]> : tensor<4xi32>
|
||||
%shape = constant dense<[0, 2, 2, 4]> : tensor<4xi64>
|
||||
%constant_value = constant dense<[0.]> : tensor<1xf32>
|
||||
%0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi32>, tensor<1xf32>)-> tensor<*xf32>
|
||||
%0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi64>, tensor<1xf32>)-> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
// CHECK-LABEL: test_should_promote_to_attribute1
|
||||
// CHECK-NEXT: [[NONE:%.+]] = constant unit
|
||||
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
|
||||
// CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
|
||||
}
|
||||
|
|
|
@ -320,10 +320,10 @@ custom_definition_misc = dict([ ('Constant',
|
|||
|
||||
onnx_types = (
|
||||
'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
|
||||
'float', 'double', 'complex64', 'complex128'
|
||||
'float', 'double', 'complex64', 'complex128', 'string'
|
||||
)
|
||||
tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
|
||||
'Complex<F32>', 'Complex<F64>'
|
||||
tblgen_types = ('AnyI1', 'AnyI8', 'AnyI16', 'AnyI32', 'AnyI64', 'BF16', 'F16', 'F32', 'F64',
|
||||
'Complex<F32>', 'Complex<F64>', 'StringType'
|
||||
)
|
||||
|
||||
MAX_NUM_TYPES=20
|
||||
|
@ -468,7 +468,7 @@ def dec_indent(indent):
|
|||
def join_args(args):
|
||||
return ", ".join(args)
|
||||
|
||||
def get_operands_or_results(schema, is_input):
|
||||
def get_operands_or_results(schema, type_str_dict, is_input):
|
||||
value_list = schema.inputs if is_input else schema.outputs
|
||||
if not value_list:
|
||||
return OrderedDict()
|
||||
|
@ -482,7 +482,10 @@ def get_operands_or_results(schema, is_input):
|
|||
|
||||
name_to_types = OrderedDict()
|
||||
for i, value in enumerate(value_list):
|
||||
structure, elem_types = get_allowed_elem_types(schema, value)
|
||||
types = get_onnx_mlir_types(schema, type_str_dict, value)
|
||||
|
||||
'''
|
||||
structure, elem_types = get_allowed_elem_types(schema, type_str_dict, value)
|
||||
|
||||
if structure == 'tensor' :
|
||||
if elem_types is None:
|
||||
|
@ -513,6 +516,7 @@ def get_operands_or_results(schema, is_input):
|
|||
types = list(map(lambda x: x.format(elem_types_str), types))
|
||||
else:
|
||||
types = ["AnyMemRef", "AnyTensor"]
|
||||
'''
|
||||
|
||||
# If operand is promotable to an attribute, then it must be
|
||||
# nullable in case it migrates to be an attribute.
|
||||
|
@ -693,7 +697,68 @@ def get_type_inference_func(s, indent, type_inference_code):
|
|||
indent = dec_indent(indent)
|
||||
return s
|
||||
|
||||
def parse_type_str(allowedType):
|
||||
# AnyI may be used for uint because the onnx_mlir is not generating uint output
|
||||
# This will be fixed later and UI will be replace AnyI
|
||||
onnx_to_mlir_type_dict = { '(': '<[',
|
||||
')': ']>',
|
||||
'tensor' : 'TensorOf',
|
||||
'seq' : 'TensorOf',
|
||||
'map' : 'TupleOf',
|
||||
'bool': 'I1',
|
||||
#'uint8' : 'AnyI8',
|
||||
#uint16' : 'AnyI16',
|
||||
#uint32' : 'AnyI32',
|
||||
#uint64' : 'AnyI64',
|
||||
'uint8' : 'UI8',
|
||||
'uint16' : 'UI16',
|
||||
'uint32' : 'UI32',
|
||||
'uint64' : 'UI64',
|
||||
'int8' : 'I8',
|
||||
'int16' : 'I16',
|
||||
'int32' : 'I32',
|
||||
'int64' : 'I64',
|
||||
'float16' : 'F16',
|
||||
'float' : 'F32',
|
||||
'double' : 'F64',
|
||||
'unkown' : 'BF16',
|
||||
'complex64' : 'Complex<F32>',
|
||||
'complex128' : 'Complex<F64>',
|
||||
'string' : 'StringType'}
|
||||
|
||||
for key, item in onnx_to_mlir_type_dict.items():
|
||||
allowedType = allowedType.replace(key, item)
|
||||
return allowedType
|
||||
|
||||
def parse_a_type_constraint(constraint):
|
||||
allowedTypes = constraint.allowed_type_strs
|
||||
mlirTypes = []
|
||||
for allowedType in allowedTypes:
|
||||
mlirType = parse_type_str(allowedType)
|
||||
mlirTypes.append(mlirType)
|
||||
# Remove redundant and sort.
|
||||
# However onnx keeps a consitently meaningful order
|
||||
# There is no redundancy as long as each onnx type is mapped uniquely
|
||||
# mlirTypes = sorted(list(set(mlirTypes)))
|
||||
return mlirTypes
|
||||
|
||||
def parse_type_constraints(schema):
|
||||
type_str_dict = dict()
|
||||
for type_constraint in schema.type_constraints:
|
||||
type_str_dict[type_constraint.type_param_str] = parse_a_type_constraint(type_constraint)
|
||||
return type_str_dict
|
||||
|
||||
def get_onnx_mlir_types(schema, type_str_dict, input):
|
||||
if input.typeStr :
|
||||
if not input.typeStr in type_str_dict :
|
||||
# some arguments use type description directly
|
||||
# instead of constraint
|
||||
return [parse_type_str(input.typeStr)]
|
||||
else :
|
||||
return type_str_dict[input.typeStr]
|
||||
else :
|
||||
print('No typeStr ', schema.name)
|
||||
return []
|
||||
|
||||
def gen_op_def(schema):
|
||||
indent = inc_indent()
|
||||
|
@ -727,15 +792,20 @@ def gen_op_def(schema):
|
|||
s += indent + '"{}"\n'.format(escaped_line)
|
||||
s += indent + '}];\n'
|
||||
|
||||
|
||||
# handle the type constraint for input and output
|
||||
# parse type constraint into onnx-mlir type string list
|
||||
type_str_dict = parse_type_constraints(schema)
|
||||
|
||||
# Generate ins (consisting of operands and attributes).
|
||||
ins = get_operands_or_results(schema, is_input=True)
|
||||
ins = get_operands_or_results(schema, type_str_dict, is_input=True)
|
||||
ins.update(get_attrs(schema))
|
||||
ins_strs = ["{1}:${0}".format(*i) for i in ins.items()]
|
||||
s += indent + 'let arguments = (ins {});\n'.format(
|
||||
(',\n' + inc_indent(indent)).join(ins_strs))
|
||||
|
||||
# Generate outs (operation results).
|
||||
outs = get_operands_or_results(schema, is_input=False)
|
||||
outs = get_operands_or_results(schema, type_str_dict, is_input=False)
|
||||
outs_strs = ["{1}:${0}".format(*i) for i in outs.items()]
|
||||
s += indent + 'let results = (outs {});\n'.format(
|
||||
(',\n' + inc_indent(indent)).join(outs_strs))
|
||||
|
@ -756,7 +826,7 @@ def gen_op_def(schema):
|
|||
# Value, Y, Attribute A", [{}]>
|
||||
indent = inc_indent(indent)
|
||||
s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
|
||||
operands_dict = get_operands_or_results(schema, is_input=True)
|
||||
operands_dict = get_operands_or_results(schema, type_str_dict, is_input=True)
|
||||
for name, ty in operands_dict.items():
|
||||
s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty),
|
||||
name)
|
||||
|
|
Loading…
Reference in New Issue