Seq type (#199)
* base implementation * add example * change table gen * docs * small change for review Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com> Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
		
							parent
							
								
									8e8f894574
								
							
						
					
					
						commit
						b4228fd288
					
				|  | @ -637,7 +637,7 @@ ONNX ConcatFromSequence operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -4833,7 +4833,7 @@ ONNX SequenceAt operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values | ||||
| 
 | ||||
| #### Results: | ||||
|  | @ -4859,7 +4859,7 @@ ONNX SequenceConstruct operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| ### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp) | ||||
| 
 | ||||
|  | @ -4877,7 +4877,7 @@ ONNX SequenceEmpty operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `output` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| ### `onnx.SequenceErase` (ONNXSequenceEraseOp) | ||||
| 
 | ||||
|  | @ -4892,14 +4892,14 @@ ONNX SequenceErase operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| ### `onnx.SequenceInsert` (ONNXSequenceInsertOp) | ||||
| 
 | ||||
|  | @ -4915,7 +4915,7 @@ ONNX SequenceInsert operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `tensor` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values | ||||
| `position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type | ||||
| 
 | ||||
|  | @ -4923,7 +4923,7 @@ ONNX SequenceInsert operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| ### `onnx.SequenceLength` (ONNXSequenceLengthOp) | ||||
| 
 | ||||
|  | @ -4935,7 +4935,7 @@ ONNX SequenceLength operation | |||
| 
 | ||||
| | Operand | Description | | ||||
| | :-----: | ----------- | | ||||
| `input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| #### Results: | ||||
| 
 | ||||
|  | @ -5298,7 +5298,7 @@ ONNX SplitToSequence operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| `output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values | ||||
| 
 | ||||
| ### `onnx.Sqrt` (ONNXSqrtOp) | ||||
| 
 | ||||
|  | @ -5964,5 +5964,5 @@ ONNX ZipMap operation | |||
| 
 | ||||
| | Result | Description | | ||||
| | :----: | ----------- | | ||||
| `Z` | tensor of tuple with any combination of stirng type or 32-bit float values values or tensor of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values | ||||
| `Z` | SeqType of tuple with any combination of stirng type or 32-bit float values values or SeqType of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values | ||||
| 
 | ||||
|  |  | |||
|  | @ -484,18 +484,48 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx) | |||
| #include "src/Dialect/ONNX/ONNXOps.cpp.inc" | ||||
|       >(); | ||||
|   addTypes<StringType>(); | ||||
|   addTypes<SeqType>(); | ||||
| } | ||||
| 
 | ||||
| mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const { | ||||
|   if (parser.parseKeyword("String")) | ||||
|   StringRef keyword; | ||||
|   if (parser.parseKeyword(&keyword)) | ||||
|     return Type(); | ||||
| 
 | ||||
|   return StringType::get(getContext()); | ||||
|   if (keyword == "String") | ||||
|     return StringType::get(getContext()); | ||||
|   if (keyword == "Seq") { | ||||
|     if (parser.parseLess()) | ||||
|       return Type(); | ||||
| 
 | ||||
|     SmallVector<mlir::Type, 1> elementTypes; | ||||
|     do { | ||||
|       llvm::SMLoc typeLoc = parser.getCurrentLocation(); | ||||
|       mlir::Type elementType; | ||||
|       if (parser.parseType(elementType)) | ||||
|         return Type(); | ||||
| 
 | ||||
|       // TOFIX: type limitation for Seq? similar but different shape??
 | ||||
|       elementTypes.push_back(elementType); | ||||
|     } while (succeeded(parser.parseOptionalComma())); | ||||
| 
 | ||||
|     if (parser.parseGreater()) | ||||
|       return Type(); | ||||
|     return SeqType::get(elementTypes); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| void ONNXOpsDialect::printType( | ||||
|     mlir::Type type, mlir::DialectAsmPrinter &printer) const { | ||||
|   printer << "String"; | ||||
|   if (auto stringType = type.dyn_cast<StringType>()) { | ||||
|     printer << "String"; | ||||
|   } else if (auto seqType = type.dyn_cast<SeqType>()) { | ||||
|     printer << "Seq<"; | ||||
|     llvm::interleaveComma(seqType.getElementTypes(), printer); | ||||
|     printer << '>'; | ||||
|   } else { | ||||
|     llvm_unreachable("Unexpected onnxmlir type"); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| void ONNXEntryPointOp::build(mlir::OpBuilder &builder, | ||||
|  | @ -2625,9 +2655,56 @@ LogicalResult ONNXSliceOp::inferShapes() { | |||
|   return success(); | ||||
| } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // ONNX type related code
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| namespace mlir { | ||||
| namespace onnxmlir { | ||||
| namespace detail { | ||||
| struct SeqTypeStorage : public mlir::TypeStorage { | ||||
|   using KeyTy = llvm::ArrayRef<mlir::Type>; | ||||
| 
 | ||||
|   SeqTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes) | ||||
|       : elementTypes(elementTypes) {} | ||||
| 
 | ||||
|   bool operator==(const KeyTy &key) const { return key == elementTypes; } | ||||
|   static llvm::hash_code hasKey(const KeyTy &key) { | ||||
|     return llvm::hash_value(key); | ||||
|   } | ||||
| 
 | ||||
|   static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) { | ||||
|     return KeyTy(elementTypes); | ||||
|   } | ||||
| 
 | ||||
|   static SeqTypeStorage *construct( | ||||
|       mlir::TypeStorageAllocator &allocator, const KeyTy &key) { | ||||
|     llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key); | ||||
|     return new (allocator.allocate<SeqTypeStorage>()) | ||||
|         SeqTypeStorage(elementTypes); | ||||
|   } | ||||
|   llvm::ArrayRef<mlir::Type> elementTypes; | ||||
| }; | ||||
| } // end namespace detail
 | ||||
| } // end namespace onnxmlir
 | ||||
| } // end namespace mlir
 | ||||
| 
 | ||||
| SeqType SeqType::get(llvm::ArrayRef<mlir::Type> elementTypes) { | ||||
|   assert(!elementTypes.empty() && "expected non-empty seq"); | ||||
|   mlir::MLIRContext *ctx = elementTypes.front().getContext(); | ||||
|   return Base::get(ctx, ONNXTypes::SEQ, elementTypes); | ||||
| } | ||||
| 
 | ||||
| llvm::ArrayRef<mlir::Type> SeqType::getElementTypes() { | ||||
|   return getImpl()->elementTypes; | ||||
| } | ||||
| 
 | ||||
| mlir::Type SeqType::getElementType() { return getElementTypes()[0]; } | ||||
| 
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| // TableGen'd op method definitions
 | ||||
| //===----------------------------------------------------------------------===//
 | ||||
| 
 | ||||
| #define GET_OP_CLASSES | ||||
| 
 | ||||
| #include "src/Dialect/ONNX/ONNXOps.cpp.inc" | ||||
|  |  | |||
|  | @ -79,6 +79,27 @@ public: | |||
|   } | ||||
| }; | ||||
| 
 | ||||
| namespace detail { | ||||
| struct SeqTypeStorage; | ||||
| } // namespace detail
 | ||||
| 
 | ||||
| class SeqType | ||||
|     : public mlir::Type::TypeBase<SeqType, mlir::Type, detail::SeqTypeStorage> { | ||||
| public: | ||||
|   using Base::Base; | ||||
|   static bool kindof(unsigned kind) { return kind == ONNXTypes::SEQ; } | ||||
| 
 | ||||
|   static unsigned getTypeKind() { return ONNXTypes::SEQ; } | ||||
| 
 | ||||
|   static SeqType get(llvm::ArrayRef<mlir::Type> elementTypes); | ||||
| 
 | ||||
|   llvm::ArrayRef<mlir::Type> getElementTypes(); | ||||
| 
 | ||||
|   mlir::Type getElementType(); | ||||
| 
 | ||||
|   size_t getNumElementTypes() { return getElementTypes().size(); } | ||||
| }; | ||||
| 
 | ||||
| } // end namespace onnxmlir
 | ||||
| 
 | ||||
| } // end namespace mlir
 | ||||
|  |  | |||
|  | @ -19,6 +19,12 @@ include "mlir/IR/OpBase.td" | |||
| 
 | ||||
| def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">; | ||||
| 
 | ||||
| def IsSeqTypePred : CPred<"$_self.isa<SeqType>()">; | ||||
| 
 | ||||
| class SeqOf<list<Type> allowedTypes> :  | ||||
|   ContainerType<AnyTypeOf<allowedTypes>, IsSeqTypePred, | ||||
|                 "$_self.cast<SeqType>().getElementType()",  "SeqType">; | ||||
| 
 | ||||
| #ifdef SHAPE_INFERENCE_INTERFACE | ||||
| #else | ||||
| include "src/Interface/ShapeInferenceInterface.td" | ||||
|  |  | |||
|  | @ -610,7 +610,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence", | |||
|   "By default 'new_axis' is 0, the behavior is similar to numpy.concatenate." | ||||
|   "When 'new_axis' is 1, the behavior is similar to numpy.stack." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|     I64Attr:$axis, | ||||
|     DefaultValuedAttr<I64Attr, "0">:$new_axis); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$concat_result); | ||||
|  | @ -4582,7 +4582,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt", | |||
|   "Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'." | ||||
|   "Negative value means counting positions from the back." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|     AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef]>:$position); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor); | ||||
|   let extraClassDeclaration = [{ | ||||
|  | @ -4606,7 +4606,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct", | |||
|   "All tensors in 'inputs' must have the same data type." | ||||
|   }]; | ||||
|   let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>>:$inputs); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return -1; | ||||
|  | @ -4627,7 +4627,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty", | |||
|   "Construct an empty tensor sequence, with given data type." | ||||
|   }]; | ||||
|   let arguments = (ins OptionalAttr<I64Attr>:$dtype); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 0; | ||||
|  | @ -4650,9 +4650,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase", | |||
|   "Negative value means counting positions from the back." | ||||
|   "'position' is optional, by default it erases the last tensor from 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|     AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 2; | ||||
|  | @ -4676,10 +4676,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert", | |||
|   "Negative value means counting positions from the back." | ||||
|   "'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|   let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence, | ||||
|     AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor, | ||||
|     AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 3; | ||||
|  | @ -4699,7 +4699,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength", | |||
|   let description = [{ | ||||
|   "Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'." | ||||
|   }]; | ||||
|   let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence); | ||||
|   let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$length); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|  | @ -5074,7 +5074,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence", | |||
|     AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$split, | ||||
|     DefaultValuedAttr<I64Attr, "0">:$axis, | ||||
|     DefaultValuedAttr<I64Attr, "1">:$keepdims); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 2; | ||||
|  | @ -6240,7 +6240,7 @@ def ONNXZipMapOp:ONNX_Op<"ZipMap", | |||
|   let arguments = (ins TensorOf<[F32]>:$X, | ||||
|     OptionalAttr<I64ArrayAttr>:$classlabels_int64s, | ||||
|     OptionalAttr<StrArrayAttr>:$classlabels_strings); | ||||
|   let results = (outs AnyTypeOf<[TensorOf<[TupleOf<[StringType, F32]>]>, TensorOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z); | ||||
|   let results = (outs AnyTypeOf<[SeqOf<[TupleOf<[StringType, F32]>]>, SeqOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z); | ||||
|   let extraClassDeclaration = [{ | ||||
|     static int getNumberOfOperands() { | ||||
|       return 1; | ||||
|  |  | |||
|  | @ -14,3 +14,14 @@ func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.Stri | |||
|   // CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String> | ||||
| } | ||||
| 
 | ||||
| // CHECK-LABEL: @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> { | ||||
| func @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> { | ||||
|   %cst = "onnx.Constant"() {value = dense<[0]> : tensor<1xi32>} : () -> tensor<1xi32> | ||||
|   %0 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>> | ||||
|   %1 = "onnx.SequenceAt"(%0, %cst) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32> | ||||
|   return %1 : tensor<*xf32> | ||||
|   // CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> | ||||
|   // CHECK-NEXT: %1 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>> | ||||
|   // CHECK-NEXT: %2 = "onnx.SequenceAt"(%1, %0) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32> | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -728,7 +728,7 @@ def parse_type_str(allowedType): | |||
|     onnx_to_mlir_type_dict = { '(': '<[', | ||||
|         ')': ']>', | ||||
|         'tensor' : 'TensorOf', | ||||
|         'seq' : 'TensorOf', | ||||
|         'seq' : 'SeqOf', | ||||
|         'map' : 'TupleOf', | ||||
|         'bool': 'I1', | ||||
|         #'uint8' : 'AnyI8', | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue