* base implementation

* add example

* change table gen

* docs

* small change for review

Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com>
Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
chentong319 2020-07-31 08:05:59 -04:00 committed by GitHub
parent 8e8f894574
commit b4228fd288
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 141 additions and 26 deletions

View File

@ -637,7 +637,7 @@ ONNX ConcatFromSequence operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
#### Results: #### Results:
@ -4833,7 +4833,7 @@ ONNX SequenceAt operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values
#### Results: #### Results:
@ -4859,7 +4859,7 @@ ONNX SequenceConstruct operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp) ### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
@ -4877,7 +4877,7 @@ ONNX SequenceEmpty operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `output` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
### `onnx.SequenceErase` (ONNXSequenceEraseOp) ### `onnx.SequenceErase` (ONNXSequenceEraseOp)
@ -4892,14 +4892,14 @@ ONNX SequenceErase operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type
#### Results: #### Results:
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
### `onnx.SequenceInsert` (ONNXSequenceInsertOp) ### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
@ -4915,7 +4915,7 @@ ONNX SequenceInsert operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
`tensor` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values `tensor` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type
@ -4923,7 +4923,7 @@ ONNX SequenceInsert operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
### `onnx.SequenceLength` (ONNXSequenceLengthOp) ### `onnx.SequenceLength` (ONNXSequenceLengthOp)
@ -4935,7 +4935,7 @@ ONNX SequenceLength operation
| Operand | Description | | Operand | Description |
| :-----: | ----------- | | :-----: | ----------- |
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
#### Results: #### Results:
@ -5298,7 +5298,7 @@ ONNX SplitToSequence operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
### `onnx.Sqrt` (ONNXSqrtOp) ### `onnx.Sqrt` (ONNXSqrtOp)
@ -5964,5 +5964,5 @@ ONNX ZipMap operation
| Result | Description | | Result | Description |
| :----: | ----------- | | :----: | ----------- |
`Z` | tensor of tuple with any combination of stirng type or 32-bit float values values or tensor of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values `Z` | SeqType of tuple with any combination of stirng type or 32-bit float values values or SeqType of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values

View File

@ -484,18 +484,48 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
#include "src/Dialect/ONNX/ONNXOps.cpp.inc" #include "src/Dialect/ONNX/ONNXOps.cpp.inc"
>(); >();
addTypes<StringType>(); addTypes<StringType>();
addTypes<SeqType>();
} }
mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const { mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
if (parser.parseKeyword("String")) StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type(); return Type();
return StringType::get(getContext()); if (keyword == "String")
return StringType::get(getContext());
if (keyword == "Seq") {
if (parser.parseLess())
return Type();
SmallVector<mlir::Type, 1> elementTypes;
do {
llvm::SMLoc typeLoc = parser.getCurrentLocation();
mlir::Type elementType;
if (parser.parseType(elementType))
return Type();
// TOFIX: type limitation for Seq? similar but different shape??
elementTypes.push_back(elementType);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseGreater())
return Type();
return SeqType::get(elementTypes);
}
} }
void ONNXOpsDialect::printType( void ONNXOpsDialect::printType(
mlir::Type type, mlir::DialectAsmPrinter &printer) const { mlir::Type type, mlir::DialectAsmPrinter &printer) const {
printer << "String"; if (auto stringType = type.dyn_cast<StringType>()) {
printer << "String";
} else if (auto seqType = type.dyn_cast<SeqType>()) {
printer << "Seq<";
llvm::interleaveComma(seqType.getElementTypes(), printer);
printer << '>';
} else {
llvm_unreachable("Unexpected onnxmlir type");
}
} }
void ONNXEntryPointOp::build(mlir::OpBuilder &builder, void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
@ -2625,9 +2655,56 @@ LogicalResult ONNXSliceOp::inferShapes() {
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// ONNX type related code
//===----------------------------------------------------------------------===//
namespace mlir {
namespace onnxmlir {
namespace detail {
struct SeqTypeStorage : public mlir::TypeStorage {
using KeyTy = llvm::ArrayRef<mlir::Type>;
SeqTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
: elementTypes(elementTypes) {}
bool operator==(const KeyTy &key) const { return key == elementTypes; }
static llvm::hash_code hasKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
return KeyTy(elementTypes);
}
static SeqTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
return new (allocator.allocate<SeqTypeStorage>())
SeqTypeStorage(elementTypes);
}
llvm::ArrayRef<mlir::Type> elementTypes;
};
} // end namespace detail
} // end namespace onnxmlir
} // end namespace mlir
SeqType SeqType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
assert(!elementTypes.empty() && "expected non-empty seq");
mlir::MLIRContext *ctx = elementTypes.front().getContext();
return Base::get(ctx, ONNXTypes::SEQ, elementTypes);
}
llvm::ArrayRef<mlir::Type> SeqType::getElementTypes() {
return getImpl()->elementTypes;
}
mlir::Type SeqType::getElementType() { return getElementTypes()[0]; }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TableGen'd op method definitions // TableGen'd op method definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.cpp.inc" #include "src/Dialect/ONNX/ONNXOps.cpp.inc"

View File

@ -79,6 +79,27 @@ public:
} }
}; };
namespace detail {
struct SeqTypeStorage;
} // namespace detail
class SeqType
: public mlir::Type::TypeBase<SeqType, mlir::Type, detail::SeqTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == ONNXTypes::SEQ; }
static unsigned getTypeKind() { return ONNXTypes::SEQ; }
static SeqType get(llvm::ArrayRef<mlir::Type> elementTypes);
llvm::ArrayRef<mlir::Type> getElementTypes();
mlir::Type getElementType();
size_t getNumElementTypes() { return getElementTypes().size(); }
};
} // end namespace onnxmlir } // end namespace onnxmlir
} // end namespace mlir } // end namespace mlir

View File

@ -19,6 +19,12 @@ include "mlir/IR/OpBase.td"
def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">; def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">;
def IsSeqTypePred : CPred<"$_self.isa<SeqType>()">;
class SeqOf<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, IsSeqTypePred,
"$_self.cast<SeqType>().getElementType()", "SeqType">;
#ifdef SHAPE_INFERENCE_INTERFACE #ifdef SHAPE_INFERENCE_INTERFACE
#else #else
include "src/Interface/ShapeInferenceInterface.td" include "src/Interface/ShapeInferenceInterface.td"

View File

@ -610,7 +610,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate." "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
"When 'new_axis' is 1, the behavior is similar to numpy.stack." "When 'new_axis' is 1, the behavior is similar to numpy.stack."
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
I64Attr:$axis, I64Attr:$axis,
DefaultValuedAttr<I64Attr, "0">:$new_axis); DefaultValuedAttr<I64Attr, "0">:$new_axis);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$concat_result); let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$concat_result);
@ -4582,7 +4582,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt",
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef]>:$position); AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef]>:$position);
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor); let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
@ -4606,7 +4606,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
"All tensors in 'inputs' must have the same data type." "All tensors in 'inputs' must have the same data type."
}]; }];
let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>>:$inputs); let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>>:$inputs);
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return -1; return -1;
@ -4627,7 +4627,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
"Construct an empty tensor sequence, with given data type." "Construct an empty tensor sequence, with given data type."
}]; }];
let arguments = (ins OptionalAttr<I64Attr>:$dtype); let arguments = (ins OptionalAttr<I64Attr>:$dtype);
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output); let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 0; return 0;
@ -4650,9 +4650,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
"'position' is optional, by default it erases the last tensor from 'input_sequence'." "'position' is optional, by default it erases the last tensor from 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position); AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position);
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 2; return 2;
@ -4676,10 +4676,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
"Negative value means counting positions from the back." "Negative value means counting positions from the back."
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'." "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor, AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor,
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position); AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position);
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 3; return 3;
@ -4699,7 +4699,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
let description = [{ let description = [{
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
}]; }];
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence); let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence);
let results = (outs AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$length); let results = (outs AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$length);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
@ -5074,7 +5074,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$split, AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$split,
DefaultValuedAttr<I64Attr, "0">:$axis, DefaultValuedAttr<I64Attr, "0">:$axis,
DefaultValuedAttr<I64Attr, "1">:$keepdims); DefaultValuedAttr<I64Attr, "1">:$keepdims);
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 2; return 2;
@ -6240,7 +6240,7 @@ def ONNXZipMapOp:ONNX_Op<"ZipMap",
let arguments = (ins TensorOf<[F32]>:$X, let arguments = (ins TensorOf<[F32]>:$X,
OptionalAttr<I64ArrayAttr>:$classlabels_int64s, OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
OptionalAttr<StrArrayAttr>:$classlabels_strings); OptionalAttr<StrArrayAttr>:$classlabels_strings);
let results = (outs AnyTypeOf<[TensorOf<[TupleOf<[StringType, F32]>]>, TensorOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z); let results = (outs AnyTypeOf<[SeqOf<[TupleOf<[StringType, F32]>]>, SeqOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
static int getNumberOfOperands() { static int getNumberOfOperands() {
return 1; return 1;

View File

@ -14,3 +14,14 @@ func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.Stri
// CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String> // CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String>
} }
// CHECK-LABEL: @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> {
func @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> {
%cst = "onnx.Constant"() {value = dense<[0]> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>
%1 = "onnx.SequenceAt"(%0, %cst) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
// CHECK-NEXT: %1 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>
// CHECK-NEXT: %2 = "onnx.SequenceAt"(%1, %0) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32>
}

View File

@ -728,7 +728,7 @@ def parse_type_str(allowedType):
onnx_to_mlir_type_dict = { '(': '<[', onnx_to_mlir_type_dict = { '(': '<[',
')': ']>', ')': ']>',
'tensor' : 'TensorOf', 'tensor' : 'TensorOf',
'seq' : 'TensorOf', 'seq' : 'SeqOf',
'map' : 'TupleOf', 'map' : 'TupleOf',
'bool': 'I1', 'bool': 'I1',
#'uint8' : 'AnyI8', #'uint8' : 'AnyI8',