Seq type (#199)
* base implementation * add example * change table gen * docs * small change for review Co-authored-by: Alexandre Eichenberger <alexe@us.ibm.com> Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
8e8f894574
commit
b4228fd288
|
@ -637,7 +637,7 @@ ONNX ConcatFromSequence operation
|
|||
|
||||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
#### Results:
|
||||
|
||||
|
@ -4833,7 +4833,7 @@ ONNX SequenceAt operation
|
|||
|
||||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values
|
||||
|
||||
#### Results:
|
||||
|
@ -4859,7 +4859,7 @@ ONNX SequenceConstruct operation
|
|||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
### `onnx.SequenceEmpty` (ONNXSequenceEmptyOp)
|
||||
|
||||
|
@ -4877,7 +4877,7 @@ ONNX SequenceEmpty operation
|
|||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`output` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`output` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
### `onnx.SequenceErase` (ONNXSequenceEraseOp)
|
||||
|
||||
|
@ -4892,14 +4892,14 @@ ONNX SequenceErase operation
|
|||
|
||||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type
|
||||
|
||||
#### Results:
|
||||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
### `onnx.SequenceInsert` (ONNXSequenceInsertOp)
|
||||
|
||||
|
@ -4915,7 +4915,7 @@ ONNX SequenceInsert operation
|
|||
|
||||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`tensor` | tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of stirng type values or tensor of 1-bit signless integer values or tensor of complex type with 32-bit float elements values or tensor of complex type with 64-bit float elements values or memref of any type values
|
||||
`position` | tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or memref of any type values or none type
|
||||
|
||||
|
@ -4923,7 +4923,7 @@ ONNX SequenceInsert operation
|
|||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
### `onnx.SequenceLength` (ONNXSequenceLengthOp)
|
||||
|
||||
|
@ -4935,7 +4935,7 @@ ONNX SequenceLength operation
|
|||
|
||||
| Operand | Description |
|
||||
| :-----: | ----------- |
|
||||
`input_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`input_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
#### Results:
|
||||
|
||||
|
@ -5298,7 +5298,7 @@ ONNX SplitToSequence operation
|
|||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`output_sequence` | tensor of tensor of 8-bit unsigned integer values values or tensor of tensor of 16-bit unsigned integer values values or tensor of tensor of 32-bit unsigned integer values values or tensor of tensor of 64-bit unsigned integer values values or tensor of tensor of 8-bit signless integer values values or tensor of tensor of 16-bit signless integer values values or tensor of tensor of 32-bit signless integer values values or tensor of tensor of 64-bit signless integer values values or tensor of tensor of 16-bit float values values or tensor of tensor of 32-bit float values values or tensor of tensor of 64-bit float values values or tensor of tensor of stirng type values values or tensor of tensor of 1-bit signless integer values values or tensor of tensor of complex type with 32-bit float elements values values or tensor of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
`output_sequence` | SeqType of tensor of 8-bit unsigned integer values values or SeqType of tensor of 16-bit unsigned integer values values or SeqType of tensor of 32-bit unsigned integer values values or SeqType of tensor of 64-bit unsigned integer values values or SeqType of tensor of 8-bit signless integer values values or SeqType of tensor of 16-bit signless integer values values or SeqType of tensor of 32-bit signless integer values values or SeqType of tensor of 64-bit signless integer values values or SeqType of tensor of 16-bit float values values or SeqType of tensor of 32-bit float values values or SeqType of tensor of 64-bit float values values or SeqType of tensor of stirng type values values or SeqType of tensor of 1-bit signless integer values values or SeqType of tensor of complex type with 32-bit float elements values values or SeqType of tensor of complex type with 64-bit float elements values values or memref of any type values
|
||||
|
||||
### `onnx.Sqrt` (ONNXSqrtOp)
|
||||
|
||||
|
@ -5964,5 +5964,5 @@ ONNX ZipMap operation
|
|||
|
||||
| Result | Description |
|
||||
| :----: | ----------- |
|
||||
`Z` | tensor of tuple with any combination of stirng type or 32-bit float values values or tensor of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values
|
||||
`Z` | SeqType of tuple with any combination of stirng type or 32-bit float values values or SeqType of tuple with any combination of 64-bit signless integer or 32-bit float values values or memref of any type values
|
||||
|
||||
|
|
|
@ -484,18 +484,48 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
|
|||
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
|
||||
>();
|
||||
addTypes<StringType>();
|
||||
addTypes<SeqType>();
|
||||
}
|
||||
|
||||
mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
|
||||
if (parser.parseKeyword("String"))
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
if (keyword == "String")
|
||||
return StringType::get(getContext());
|
||||
if (keyword == "Seq") {
|
||||
if (parser.parseLess())
|
||||
return Type();
|
||||
|
||||
SmallVector<mlir::Type, 1> elementTypes;
|
||||
do {
|
||||
llvm::SMLoc typeLoc = parser.getCurrentLocation();
|
||||
mlir::Type elementType;
|
||||
if (parser.parseType(elementType))
|
||||
return Type();
|
||||
|
||||
// TOFIX: type limitation for Seq? similar but different shape??
|
||||
elementTypes.push_back(elementType);
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
if (parser.parseGreater())
|
||||
return Type();
|
||||
return SeqType::get(elementTypes);
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXOpsDialect::printType(
|
||||
mlir::Type type, mlir::DialectAsmPrinter &printer) const {
|
||||
if (auto stringType = type.dyn_cast<StringType>()) {
|
||||
printer << "String";
|
||||
} else if (auto seqType = type.dyn_cast<SeqType>()) {
|
||||
printer << "Seq<";
|
||||
llvm::interleaveComma(seqType.getElementTypes(), printer);
|
||||
printer << '>';
|
||||
} else {
|
||||
llvm_unreachable("Unexpected onnxmlir type");
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
|
||||
|
@ -2625,9 +2655,56 @@ LogicalResult ONNXSliceOp::inferShapes() {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ONNX type related code
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace onnxmlir {
|
||||
namespace detail {
|
||||
struct SeqTypeStorage : public mlir::TypeStorage {
|
||||
using KeyTy = llvm::ArrayRef<mlir::Type>;
|
||||
|
||||
SeqTypeStorage(llvm::ArrayRef<mlir::Type> elementTypes)
|
||||
: elementTypes(elementTypes) {}
|
||||
|
||||
bool operator==(const KeyTy &key) const { return key == elementTypes; }
|
||||
static llvm::hash_code hasKey(const KeyTy &key) {
|
||||
return llvm::hash_value(key);
|
||||
}
|
||||
|
||||
static KeyTy getKey(llvm::ArrayRef<mlir::Type> elementTypes) {
|
||||
return KeyTy(elementTypes);
|
||||
}
|
||||
|
||||
static SeqTypeStorage *construct(
|
||||
mlir::TypeStorageAllocator &allocator, const KeyTy &key) {
|
||||
llvm::ArrayRef<mlir::Type> elementTypes = allocator.copyInto(key);
|
||||
return new (allocator.allocate<SeqTypeStorage>())
|
||||
SeqTypeStorage(elementTypes);
|
||||
}
|
||||
llvm::ArrayRef<mlir::Type> elementTypes;
|
||||
};
|
||||
} // end namespace detail
|
||||
} // end namespace onnxmlir
|
||||
} // end namespace mlir
|
||||
|
||||
SeqType SeqType::get(llvm::ArrayRef<mlir::Type> elementTypes) {
|
||||
assert(!elementTypes.empty() && "expected non-empty seq");
|
||||
mlir::MLIRContext *ctx = elementTypes.front().getContext();
|
||||
return Base::get(ctx, ONNXTypes::SEQ, elementTypes);
|
||||
}
|
||||
|
||||
llvm::ArrayRef<mlir::Type> SeqType::getElementTypes() {
|
||||
return getImpl()->elementTypes;
|
||||
}
|
||||
|
||||
mlir::Type SeqType::getElementType() { return getElementTypes()[0]; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
||||
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
|
||||
|
|
|
@ -79,6 +79,27 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
struct SeqTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
class SeqType
|
||||
: public mlir::Type::TypeBase<SeqType, mlir::Type, detail::SeqTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
static bool kindof(unsigned kind) { return kind == ONNXTypes::SEQ; }
|
||||
|
||||
static unsigned getTypeKind() { return ONNXTypes::SEQ; }
|
||||
|
||||
static SeqType get(llvm::ArrayRef<mlir::Type> elementTypes);
|
||||
|
||||
llvm::ArrayRef<mlir::Type> getElementTypes();
|
||||
|
||||
mlir::Type getElementType();
|
||||
|
||||
size_t getNumElementTypes() { return getElementTypes().size(); }
|
||||
};
|
||||
|
||||
} // end namespace onnxmlir
|
||||
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -19,6 +19,12 @@ include "mlir/IR/OpBase.td"
|
|||
|
||||
def StringType : Type<CPred<"$_self.isa<StringType>()">, "stirng type">;
|
||||
|
||||
def IsSeqTypePred : CPred<"$_self.isa<SeqType>()">;
|
||||
|
||||
class SeqOf<list<Type> allowedTypes> :
|
||||
ContainerType<AnyTypeOf<allowedTypes>, IsSeqTypePred,
|
||||
"$_self.cast<SeqType>().getElementType()", "SeqType">;
|
||||
|
||||
#ifdef SHAPE_INFERENCE_INTERFACE
|
||||
#else
|
||||
include "src/Interface/ShapeInferenceInterface.td"
|
||||
|
|
|
@ -610,7 +610,7 @@ def ONNXConcatFromSequenceOp:ONNX_Op<"ConcatFromSequence",
|
|||
"By default 'new_axis' is 0, the behavior is similar to numpy.concatenate."
|
||||
"When 'new_axis' is 1, the behavior is similar to numpy.stack."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
I64Attr:$axis,
|
||||
DefaultValuedAttr<I64Attr, "0">:$new_axis);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$concat_result);
|
||||
|
@ -4582,7 +4582,7 @@ def ONNXSequenceAtOp:ONNX_Op<"SequenceAt",
|
|||
"Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'."
|
||||
"Negative value means counting positions from the back."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef]>:$position);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor);
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -4606,7 +4606,7 @@ def ONNXSequenceConstructOp:ONNX_Op<"SequenceConstruct",
|
|||
"All tensors in 'inputs' must have the same data type."
|
||||
}];
|
||||
let arguments = (ins Variadic<AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>>:$inputs);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return -1;
|
||||
|
@ -4627,7 +4627,7 @@ def ONNXSequenceEmptyOp:ONNX_Op<"SequenceEmpty",
|
|||
"Construct an empty tensor sequence, with given data type."
|
||||
}];
|
||||
let arguments = (ins OptionalAttr<I64Attr>:$dtype);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return 0;
|
||||
|
@ -4650,9 +4650,9 @@ def ONNXSequenceEraseOp:ONNX_Op<"SequenceErase",
|
|||
"Negative value means counting positions from the back."
|
||||
"'position' is optional, by default it erases the last tensor from 'input_sequence'."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return 2;
|
||||
|
@ -4676,10 +4676,10 @@ def ONNXSequenceInsertOp:ONNX_Op<"SequenceInsert",
|
|||
"Negative value means counting positions from the back."
|
||||
"'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence,
|
||||
AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex<F32>]>, TensorOf<[Complex<F64>]>, AnyMemRef]>:$tensor,
|
||||
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$position);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return 3;
|
||||
|
@ -4699,7 +4699,7 @@ def ONNXSequenceLengthOp:ONNX_Op<"SequenceLength",
|
|||
let description = [{
|
||||
"Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'."
|
||||
}];
|
||||
let arguments = (ins AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence);
|
||||
let arguments = (ins AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$input_sequence);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[I64]>, AnyMemRef]>:$length);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
|
@ -5074,7 +5074,7 @@ def ONNXSplitToSequenceOp:ONNX_Op<"SplitToSequence",
|
|||
AnyTypeOf<[TensorOf<[I32]>, TensorOf<[I64]>, AnyMemRef, NoneType]>:$split,
|
||||
DefaultValuedAttr<I64Attr, "0">:$axis,
|
||||
DefaultValuedAttr<I64Attr, "1">:$keepdims);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TensorOf<[UI8]>]>, TensorOf<[TensorOf<[UI16]>]>, TensorOf<[TensorOf<[UI32]>]>, TensorOf<[TensorOf<[UI64]>]>, TensorOf<[TensorOf<[I8]>]>, TensorOf<[TensorOf<[I16]>]>, TensorOf<[TensorOf<[I32]>]>, TensorOf<[TensorOf<[I64]>]>, TensorOf<[TensorOf<[F16]>]>, TensorOf<[TensorOf<[F32]>]>, TensorOf<[TensorOf<[F64]>]>, TensorOf<[TensorOf<[StringType]>]>, TensorOf<[TensorOf<[I1]>]>, TensorOf<[TensorOf<[Complex<F32>]>]>, TensorOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TensorOf<[UI8]>]>, SeqOf<[TensorOf<[UI16]>]>, SeqOf<[TensorOf<[UI32]>]>, SeqOf<[TensorOf<[UI64]>]>, SeqOf<[TensorOf<[I8]>]>, SeqOf<[TensorOf<[I16]>]>, SeqOf<[TensorOf<[I32]>]>, SeqOf<[TensorOf<[I64]>]>, SeqOf<[TensorOf<[F16]>]>, SeqOf<[TensorOf<[F32]>]>, SeqOf<[TensorOf<[F64]>]>, SeqOf<[TensorOf<[StringType]>]>, SeqOf<[TensorOf<[I1]>]>, SeqOf<[TensorOf<[Complex<F32>]>]>, SeqOf<[TensorOf<[Complex<F64>]>]>, AnyMemRef]>:$output_sequence);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return 2;
|
||||
|
@ -6240,7 +6240,7 @@ def ONNXZipMapOp:ONNX_Op<"ZipMap",
|
|||
let arguments = (ins TensorOf<[F32]>:$X,
|
||||
OptionalAttr<I64ArrayAttr>:$classlabels_int64s,
|
||||
OptionalAttr<StrArrayAttr>:$classlabels_strings);
|
||||
let results = (outs AnyTypeOf<[TensorOf<[TupleOf<[StringType, F32]>]>, TensorOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z);
|
||||
let results = (outs AnyTypeOf<[SeqOf<[TupleOf<[StringType, F32]>]>, SeqOf<[TupleOf<[I64, F32]>]>, AnyMemRef]>:$Z);
|
||||
let extraClassDeclaration = [{
|
||||
static int getNumberOfOperands() {
|
||||
return 1;
|
||||
|
|
|
@ -14,3 +14,14 @@ func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.Stri
|
|||
// CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> {
|
||||
func @check_seq(%arg0: tensor<10x20xf32>, %arg1: tensor<5x20xf32>) -> tensor<*xf32> {
|
||||
%cst = "onnx.Constant"() {value = dense<[0]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%0 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>
|
||||
%1 = "onnx.SequenceAt"(%0, %cst) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32>
|
||||
return %1 : tensor<*xf32>
|
||||
// CHECK-NEXT: %0 = "onnx.Constant"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
// CHECK-NEXT: %1 = "onnx.SequenceConstruct"(%arg0, %arg1) : (tensor<10x20xf32>, tensor<5x20xf32>) -> !onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>
|
||||
// CHECK-NEXT: %2 = "onnx.SequenceAt"(%1, %0) : (!onnx.Seq<tensor<10x20xf32>, tensor<5x20xf32>>, tensor<1xi32>) -> tensor<*xf32>
|
||||
}
|
||||
|
||||
|
|
|
@ -728,7 +728,7 @@ def parse_type_str(allowedType):
|
|||
onnx_to_mlir_type_dict = { '(': '<[',
|
||||
')': ']>',
|
||||
'tensor' : 'TensorOf',
|
||||
'seq' : 'TensorOf',
|
||||
'seq' : 'SeqOf',
|
||||
'map' : 'TupleOf',
|
||||
'bool': 'I1',
|
||||
#'uint8' : 'AnyI8',
|
||||
|
|
Loading…
Reference in New Issue