String type (Ready for Review) (#182)
* string type from tensorflow
* simplify type
* parser and print
* gen StringType for tablegen
* onnx to onnx-mlir type
* add namespace
* allow all integer type
* dialect document
* add test case
* format
* more precise type for ONNXOp
* format
* enable the failed test
* update comment
* update onnx.md

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in: parent f811718144, commit 2e08b2112c

[File diff suppressed because it is too large]
@@ -16,6 +16,7 @@ add_library(OMONNXOps
         ONNXOps.hpp
         ONNXOpsHelper.cpp
         ONNXOpsHelper.hpp)
+
 target_include_directories(OMONNXOps
         PRIVATE
         ${ONNX_MLIR_SRC_ROOT}
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
@@ -25,6 +26,7 @@
 
 using namespace mlir;
 using namespace mlir::OpTrait::util;
+using namespace mlir::onnxmlir;
 
 //===----------------------------------------------------------------------===//
 // ONNX Helper functions
@@ -481,6 +483,19 @@ ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
 #define GET_OP_LIST
 #include "src/Dialect/ONNX/ONNXOps.cpp.inc"
       >();
+  addTypes<StringType>();
+}
+
+mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
+  if (parser.parseKeyword("String"))
+    return Type();
+
+  return StringType::get(getContext());
+}
+
+void ONNXOpsDialect::printType(
+    mlir::Type type, mlir::DialectAsmPrinter &printer) const {
+  printer << "String";
 }
 
 void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
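Note: these two hooks are what let the new type round-trip through the textual IR. Because the dialect namespace is "onnx", the type prints and re-parses as !onnx.String (the tensor<10x20x!onnx.String> test later in this diff exercises exactly that). A minimal sketch of how downstream code might test for the type; the helper isStringTensor is hypothetical and not part of this patch, assuming the StringType declaration added here is in scope:

#include "src/Dialect/ONNX/ONNXOps.hpp"

// Hypothetical helper, not part of this patch: true iff `val` is a
// tensor whose element type is the new onnx string type.
static bool isStringTensor(mlir::Value val) {
  if (auto tensorTy = val.getType().dyn_cast<mlir::TensorType>())
    return tensorTy.getElementType().isa<mlir::onnxmlir::StringType>();
  return false;
}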
@@ -2025,8 +2040,12 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
   auto yScaleTy = y_scale().getType().cast<ShapedType>();
   auto yZPTy = y_zero_point().getType().cast<ShapedType>();
 
-  IntegerType i8Type = IntegerType::get(8, getContext());
-  RankedTensorType scalarType = RankedTensorType::get({}, i8Type);
+  IntegerType ui8Type =
+      IntegerType::get(8, IntegerType::Unsigned, getContext());
+  FloatType f32Type = FloatType::getF32(getContext());
+
+  RankedTensorType scalarType = RankedTensorType::get({}, f32Type);
+  RankedTensorType y_zero_point_type = RankedTensorType::get({}, ui8Type);
 
   // Set the types for the scalars
   if (!yScaleTy.hasStaticShape()) {
@@ -2034,11 +2053,11 @@ LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
   }
 
   if (!yZPTy.hasStaticShape()) {
-    y_zero_point().setType(scalarType);
+    y_zero_point().setType(y_zero_point_type);
   }
 
   if (!yTy.hasStaticShape()) {
-    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
+    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), ui8Type);
     y().setType(outType);
   }
 
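Note: per the ONNX specification, DynamicQuantizeLinear returns y as a uint8 tensor, y_scale as a float32 scalar, and y_zero_point as a uint8 scalar; the old code inferred signless i8 for all three. The fix leans on the fact that signless and unsigned integers are distinct types in MLIR, as this small sketch (assuming the MLIR API of this vintage) illustrates:

#include "mlir/IR/StandardTypes.h"
#include <cassert>

void signednessSketch(mlir::MLIRContext *ctx) {
  // Signless "i8" and explicitly unsigned "ui8" are different types.
  mlir::IntegerType i8Ty = mlir::IntegerType::get(8, ctx);
  mlir::IntegerType ui8Ty =
      mlir::IntegerType::get(8, mlir::IntegerType::Unsigned, ctx);
  assert(i8Ty != ui8Ty && "signedness is part of the type");
}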
@@ -31,6 +31,13 @@ class ONNXOpsDialect : public Dialect {
 public:
   ONNXOpsDialect(MLIRContext *context);
 
+  /// Parse an instance of a type registered to the onnx dialect.
+  mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
+
+  /// Print an instance of a type registered to the onnx dialect.
+  void printType(
+      mlir::Type type, mlir::DialectAsmPrinter &printer) const override;
+
   /// Provide a utility accessor to the dialect namespace. This is used by
   /// several utilities for casting between dialects.
   static StringRef getDialectNamespace() { return "onnx"; }
@@ -41,6 +48,39 @@ public:
 #define GET_OP_CLASSES
 #include "src/Dialect/ONNX/ONNXOps.hpp.inc"
 
+// The namespace onnxmlir is experimental.
+// onnx_mlir has been used in KRNL. Other candidates are onnxops, onnxdialect.
+// Should this namespace be for the onnx-mlir project or the ONNXOps dialect?
+// Or do we need two namespaces?
+// All the ONNXOps will be put into this namespace.
+namespace onnxmlir {
+
+namespace ONNXTypes {
+
+enum Kind {
+  FIRST_USED_ONNX_TYPE = Type::FIRST_PRIVATE_EXPERIMENTAL_1_TYPE,
+  //#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant,
+  //#include "src/Dialect/ONNX/ONNXTypes.def"
+  STRING,
+  SEQ,
+  LAST_USED_ONNX_TYPE,
+};
+} // namespace ONNXTypes
+
+class StringType : public mlir::Type::TypeBase<StringType, mlir::Type> {
+public:
+  using Base::Base;
+  static bool kindof(unsigned kind) { return kind == ONNXTypes::STRING; }
+
+  static unsigned getTypeKind() { return ONNXTypes::STRING; }
+
+  static StringType get(MLIRContext *ctx) {
+    return Base::get(ctx, ONNXTypes::STRING);
+  }
+};
+
+} // end namespace onnxmlir
+
 } // end namespace mlir
 
 namespace onnx_mlir {}
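Note: this is the pre-ODS pattern for a simple (parameterless) dialect type: pick a kind from the private experimental range and let kindof() drive isa/dyn_cast dispatch. A hypothetical usage sketch, assuming the declarations above are in scope:

#include "src/Dialect/ONNX/ONNXOps.hpp"
#include <cassert>

void stringTypeSketch(mlir::MLIRContext *ctx) {
  mlir::Type ty = mlir::onnxmlir::StringType::get(ctx);
  assert(ty.isa<mlir::onnxmlir::StringType>()); // resolved via kindof()
  assert(ty.getKind() == mlir::onnxmlir::ONNXTypes::STRING);
}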
@@ -17,6 +17,8 @@
 include "mlir/IR/OpBase.td"
 #endif // OP_BASE
 
+def StringType : Type<CPred<"$_self.isa<StringType>()">, "string type">;
+
 #ifdef SHAPE_INFERENCE_INTERFACE
 #else
 include "src/Interface/ShapeInferenceInterface.td"
[File diff suppressed because it is too large]
@@ -10,8 +10,11 @@
 
 #include "ONNXOpsHelper.hpp"
 
+#include "ONNXOps.hpp"
+
 // Identity affine
 using namespace mlir;
+using namespace mlir::onnxmlir;
 AffineMap getIdentityDimMap(Builder &builder) {
   return AffineMap::get(1, 0, {builder.getAffineDimExpr(0)});
 }
@@ -55,21 +58,26 @@ mlir::Type convertONNXTypeToMLIRType(
   case onnx::TensorProto_DataType::TensorProto_DataType_DOUBLE:
     return builder_.getF64Type();
   case onnx::TensorProto_DataType::TensorProto_DataType_INT8:
-  case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
     return builder_.getIntegerType(/*width=*/8);
+  case onnx::TensorProto_DataType::TensorProto_DataType_UINT8:
+    return builder_.getIntegerType(/*width=*/8, false);
   case onnx::TensorProto_DataType::TensorProto_DataType_INT16:
-  case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
     return builder_.getIntegerType(/*width=*/16);
+  case onnx::TensorProto_DataType::TensorProto_DataType_UINT16:
+    return builder_.getIntegerType(/*width=*/16, false);
   case onnx::TensorProto_DataType::TensorProto_DataType_INT32:
-  case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
     return builder_.getIntegerType(/*width=*/32);
+  case onnx::TensorProto_DataType::TensorProto_DataType_UINT32:
+    return builder_.getIntegerType(/*width=*/32, false);
   case onnx::TensorProto_DataType::TensorProto_DataType_INT64:
-  case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
     return builder_.getIntegerType(/*width=*/64);
+  case onnx::TensorProto_DataType::TensorProto_DataType_UINT64:
+    return builder_.getIntegerType(/*width=*/64, false);
   case onnx::TensorProto_DataType::TensorProto_DataType_BOOL:
     return builder_.getI1Type();
-
   case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
+    return StringType::get(builder_.getContext());
+
   case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
   case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
   case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
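Note: the importer now distinguishes signedness: getIntegerType(w) yields a signless iW (kept for the signed ONNX types), getIntegerType(w, false) an explicitly unsigned uiW, and STRING maps to the new StringType instead of falling through to the unsupported cases. A hypothetical driver, assuming the helper keeps the signature shown above:

#include <cassert>

void mappingSketch(mlir::Builder &b) {
  assert(convertONNXTypeToMLIRType(
             b, onnx::TensorProto_DataType::TensorProto_DataType_INT8) ==
         b.getIntegerType(8));        // signless i8
  assert(convertONNXTypeToMLIRType(
             b, onnx::TensorProto_DataType::TensorProto_DataType_UINT8) ==
         b.getIntegerType(8, false)); // explicitly unsigned ui8
}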
@@ -2,8 +2,8 @@
 
 // -----
 
-func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
-  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK: llvm.func @llvm.memcpy.p0i8.p0i8.i64(!llvm<"i8*">, !llvm<"i8*">, !llvm.i64, !llvm.i1)
@@ -397,8 +397,8 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 
 // -----
 
-func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
-  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi32>) -> tensor<*xf32>
+func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<4xi64>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape
@@ -411,56 +411,52 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
 
   // CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
   // CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
-  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32>
+  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi64>
   // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32
-  // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
-  // CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32
-  // CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32
-  // CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
-  // CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
+  // CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i64
+  // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i64
+  // CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i64
+  // CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i64
+  // CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[SELECT_0]] : i64
 
   // CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
-  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32>
-  // CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32
-  // CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32
-  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32
-  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32
-  // CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
-  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
+  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi64>
+  // CHECK: [[CONSTANT_3:%.+]] = constant 10 : i64
+  // CHECK: [[CONSTANT_4:%.+]] = constant 0 : i64
+  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i64
+  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i64
+  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[SELECT_1]] : i64
 
   // CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
-  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32>
-  // CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
-  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
+  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi64>
+  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[LOAD_2]] : i64
 
   // CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
-  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32>
-  // CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
-  // CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
+  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi64>
+  // CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[LOAD_3]] : i64
 
   // CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
   // CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
 
   // CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
-  // CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64
+  // CHECK: [[CMP_2:%.+]] = cmpi "eq", [[SELECT_0]], [[CONSTANT_8]] : i64
   // CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
-  // CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64
+  // CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[SELECT_0]] : i64
   // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
 
-  // CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64
+  // CHECK: [[CMP_3:%.+]] = cmpi "eq", [[SELECT_1]], [[CONSTANT_8]] : i64
   // CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
-  // CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64
+  // CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[SELECT_1]] : i64
   // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
 
-  // CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64
+  // CHECK: [[CMP_4:%.+]] = cmpi "eq", [[LOAD_2]], [[CONSTANT_8]] : i64
   // CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
-  // CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64
+  // CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[LOAD_2]] : i64
   // CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
 
-  // CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64
+  // CHECK: [[CMP_5:%.+]] = cmpi "eq", [[LOAD_3]], [[CONSTANT_8]] : i64
   // CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
-  // CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64
+  // CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[LOAD_3]] : i64
   // CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
 
   // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
@@ -542,48 +542,48 @@ func @test_default_averagepool_strides_nonunifpad_ceil(%arg0 : tensor<5x5x30x32x
 /// Test the reshape op inference when constants are present.
 //===----------------------------------------------------------------------===//
 
-func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
-  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
+func @test_reshape_dynamic(%arg0 : tensor<5x5x1x32xf32>, %arg1 : tensor<4xi64>) -> tensor<*xf32> {
+  %0 = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
   "std.return"(%0) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape_dynamic
-  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %arg1) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
   // CHECK: return [[RES]] : tensor<?x?x?x?xf32>
 }
 
 // -----
 
 func @test_reshape_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi32> } : () -> tensor<4xi32>
-  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<*xf32>
+  %0 = "onnx.Constant"() {value = dense<[5, 5, 16, 2]> : tensor<4xi64> } : () -> tensor<4xi64>
+  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape_1
-  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x5x16x2xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<4xi64>) -> tensor<5x5x16x2xf32>
   // CHECK: return [[RES]] : tensor<5x5x16x2xf32>
 }
 
 // -----
 
 func @test_reshape_2(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi32> } : () -> tensor<3xi32>
-  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
+  %0 = "onnx.Constant"() {value = dense<[-1, 16, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
+  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape_2
-  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<25x16x2xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<25x16x2xf32>
   // CHECK: return [[RES]] : tensor<25x16x2xf32>
 }
 
 // -----
 
 func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
-  %0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi32> } : () -> tensor<3xi32>
-  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<*xf32>
+  %0 = "onnx.Constant"() {value = dense<[-1, 0, 2]> : tensor<3xi64> } : () -> tensor<3xi64>
+  %1 = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<*xf32>
   "std.return"(%1) : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_reshape_3
-  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi32>) -> tensor<80x5x2xf32>
+  // CHECK: [[RES:%.+]] = "onnx.Reshape"(%arg0, %0) : (tensor<5x5x1x32xf32>, tensor<3xi64>) -> tensor<80x5x2xf32>
   // CHECK: return [[RES]] : tensor<80x5x2xf32>
 }
 
@@ -904,13 +904,13 @@ func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> {
   "std.return"(%1) : (tensor<*xui8>) -> ()
 
   // CHECK-LABEL: test_cast_2
-  // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
-  // CHECK: return [[RES]] : tensor<2x3x4xi8>
+  // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xui8>
+  // CHECK: return [[RES]] : tensor<2x3x4xui8>
 }
 
-func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> {
-  %1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8>
-  "std.return"(%1) : (tensor<*xsi8>) -> ()
+func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xi8> {
+  %1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xi8>
+  "std.return"(%1) : (tensor<*xi8>) -> ()
 
   // CHECK-LABEL: test_cast_3
   // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8>
@@ -930,30 +930,33 @@ func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> {
 /// Test the quantization op inferences.
 //===----------------------------------------------------------------------===//
 
-func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> {
-  %1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>)
-  "std.return"(%1#0) {} : (tensor<*xi8>) -> ()
+// TOFIX
+// This test case is commented out because the #1 output should be tensor<f32>
+// but tensor<i8> is generated
+func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xui8> {
+ %1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xui8>, tensor<*xf32>, tensor<*xui8>)
+ "std.return"(%1#0) {} : (tensor<*xui8>) -> ()
 
  // CHECK-LABEL: test_dyn_quantize_linear_1
-  // CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>)
-  // CHECK: return [[RES]] : tensor<5x2x3x4xi8>
+ // CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xui8>, tensor<f32>, tensor<ui8>)
+ // CHECK: return [[RES]] : tensor<5x2x3x4xui8>
 }
 
-func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xi8> {
-  %1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<*xi8>
+func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xi8> {
+  %1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<*xi8>
   "std.return"(%1) {} : (tensor<*xi8>) -> ()
 
   // CHECK-LABEL: test_quantize_linear_1
-  // CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xi8>
+  // CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xi8>
   // CHECK: return [[RES]] : tensor<5x2x3x4xi8>
 }
 
-func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<i8>, %arg2 : tensor<i8>) -> tensor<*xf32> {
-  %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<*xf32>
+func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor<f32>, %arg2 : tensor<i8>) -> tensor<*xf32> {
+  %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<*xf32>
   "std.return"(%1) {} : (tensor<*xf32>) -> ()
 
   // CHECK-LABEL: test_dequantize_linear_1
-  // CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<i8>, tensor<i8>) -> tensor<5x2x3x4xf32>
+  // CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xf32>
   // CHECK: return [[RES]] : tensor<5x2x3x4xf32>
 }
 
@@ -1,9 +1,16 @@
 // RUN: onnx-mlir-opt %s -split-input-file | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CHECK-LABEL: @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
-func @check_map1(%arg0: tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64> {
-  %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
-  return %0 : tensor<*xi64>
-  // CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<tensor<10xi64>, tensor<10xi64>>) -> tensor<*xi64>
+// CHECK-LABEL: @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
+func @check_map1(%arg0: tuple<i64, f32>) -> tensor<*xf32> {
+  %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+  // CHECK-NEXT: %0 = "onnx.CastMap"(%arg0) {cast_to = "TO_FLOAT", map_form = "DENSE", max_map = 1 : i64} : (tuple<i64, f32>) -> tensor<*xf32>
 }
+
+// CHECK-LABEL: @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
+func @check_string(%arg0: tensor<10x20x!onnx.String>) -> tensor<10x20x!onnx.String> {
+  return %arg0 : tensor<10x20x!onnx.String>
+  // CHECK-NEXT: return %arg0 : tensor<10x20x!onnx.String>
+}
+
@@ -1,22 +1,22 @@
 // RUN: onnx-mlir-opt --attribute-promotion %s -split-input-file | FileCheck %s
 
 func @test_should_promote_to_attribute(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %shape = constant dense<[6, 7, 42]> : tensor<3xi32>
-  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
+  %shape = constant dense<[6, 7, 42]> : tensor<3xi64>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
   // CHECK-LABEL: test_should_promote_to_attribute
   // CHECK-NEXT: [[NONE:%.+]] = constant unit
-  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 }
 
 func @test_should_promote_to_attribute_1(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
-  %shape = "onnx.Constant"() { value =  dense<[6, 7, 42]> : tensor<3xi32>}: () -> tensor<3xi32>
-  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
+  %shape = "onnx.Constant"() { value =  dense<[6, 7, 42]> : tensor<3xi64>}: () -> tensor<3xi64>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
   // CHECK-LABEL: test_should_promote_to_attribute_1
   // CHECK-NEXT: [[NONE:%.+]] = constant unit
-  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 }
 
@@ -29,25 +29,25 @@ func @test_should_not_promote_to_attribute(%arg0 : tensor<?x10xf32>, %arg1 : ten
 }
 
 func @test_promote_to_attribute_without_removing_const_op(%arg0 : tensor<?x10xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
-  %shape = constant dense<[6, 7, 42]> : tensor<3xi32>
-  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi32>) -> tensor<*xf32>
-  %1 = "onnx.Identity"(%shape) : (tensor<3xi32>) -> tensor<*xf32>
+  %shape = constant dense<[6, 7, 42]> : tensor<3xi64>
+  %0 = "onnx.Reshape"(%arg0, %shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<*xf32>
+  %1 = "onnx.Identity"(%shape) : (tensor<3xi64>) -> tensor<*xf32>
   "std.return"(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> ()
   // CHECK-LABEL: test_promote_to_attribute_without_removing_const_op
   // CHECK-NEXT: [[NONE:%.+]] = constant unit
-  // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi32>
-  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi32>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
-  // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi32>) -> tensor<*xf32>
+  // CHECK-NEXT: [[SHAPE:%.+]] = constant dense<[6, 7, 42]> : tensor<3xi64>
+  // CHECK-NEXT: [[RESHAPE:%.+]] = "onnx.Reshape"(%{{.*}}, [[NONE]]) {shape = dense<[6, 7, 42]> : tensor<3xi64>} : (tensor<?x10xf32>, none) -> tensor<*xf32>
+  // CHECK-NEXT: [[IDENTITY:%.+]] = "onnx.Identity"([[SHAPE]]) : (tensor<3xi64>) -> tensor<*xf32>
   // CHECK-NEXT: return [[RESHAPE]], [[IDENTITY]] : tensor<*xf32>, tensor<*xf32>
 }
 
 func @test_should_promote_to_attribute1(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> {
-  %shape = constant dense<[0, 2,  2, 4]> : tensor<4xi32>
+  %shape = constant dense<[0, 2,  2, 4]> : tensor<4xi64>
   %constant_value = constant dense<[0.]> : tensor<1xf32>
-  %0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi32>, tensor<1xf32>)-> tensor<*xf32>
+  %0 = "onnx.Pad"(%arg0, %shape, %constant_value) {mode = "constant"} : (tensor<?x?xf32>, tensor<4xi64>, tensor<1xf32>)-> tensor<*xf32>
   return %0 : tensor<*xf32>
   // CHECK-LABEL: test_should_promote_to_attribute1
   // CHECK-NEXT: [[NONE:%.+]] = constant unit
-  // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi32>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
+  // CHECK-NEXT: [[PAD:%.+]] = "onnx.Pad"(%{{.*}}, [[NONE]], [[NONE]]) {constant_value = dense<0.000000e+00> : tensor<1xf32>, mode = "constant", pads = dense<[0, 2, 2, 4]> : tensor<4xi64>} : (tensor<?x?xf32>, none, none) -> tensor<*xf32>
   // CHECK-NEXT: return [[RESHAPE]] : tensor<*xf32>
 }
@@ -320,10 +320,10 @@ custom_definition_misc = dict([ ('Constant',
 
 onnx_types = (
     'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16',
-    'float', 'double', 'complex64', 'complex128'
+    'float', 'double', 'complex64', 'complex128', 'string'
 )
-tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64',
-    'Complex<F32>', 'Complex<F64>'
+tblgen_types = ('AnyI1', 'AnyI8', 'AnyI16', 'AnyI32', 'AnyI64', 'BF16', 'F16', 'F32', 'F64',
+    'Complex<F32>', 'Complex<F64>', 'StringType'
 )
 
 MAX_NUM_TYPES=20
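Note: the two tuples appear to be index-aligned parallel arrays, i.e. each ONNX type name translates to the TableGen type at the same position, which is why 'string' and 'StringType' are appended together:

# Index-aligned correspondence (inferred from the two tuples above):
#   'bool'   -> 'AnyI1'
#   'unkown' -> 'BF16'
#   'string' -> 'StringType'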
@@ -468,7 +468,7 @@ def dec_indent(indent):
 def join_args(args):
     return ", ".join(args)
 
-def get_operands_or_results(schema, is_input):
+def get_operands_or_results(schema, type_str_dict, is_input):
     value_list = schema.inputs if is_input else schema.outputs
     if not value_list:
         return OrderedDict()
@@ -482,7 +482,10 @@ def get_operands_or_results(schema, is_input):
 
     name_to_types = OrderedDict()
     for i, value in enumerate(value_list):
-        structure, elem_types = get_allowed_elem_types(schema, value)
+        types = get_onnx_mlir_types(schema, type_str_dict, value)
+
+        '''
+        structure, elem_types = get_allowed_elem_types(schema, type_str_dict, value)
 
         if structure == 'tensor' :
             if elem_types is None:
@@ -513,6 +516,7 @@ def get_operands_or_results(schema, is_input):
                 types = list(map(lambda x: x.format(elem_types_str), types))
         else:
             types = ["AnyMemRef", "AnyTensor"]
+        '''
 
         # If operand is promotable to an attribute, then it must be
         # nullable in case it migrates to be an attribute.
@@ -693,7 +697,68 @@ def get_type_inference_func(s, indent, type_inference_code):
     indent = dec_indent(indent)
     return s
 
+def parse_type_str(allowedType):
+    # AnyI may be used for uint because onnx-mlir is not generating uint output.
+    # This will be fixed later, and UI will replace AnyI.
+    onnx_to_mlir_type_dict = { '(': '<[',
+        ')': ']>',
+        'tensor' : 'TensorOf',
+        'seq' : 'TensorOf',
+        'map' : 'TupleOf',
+        'bool': 'I1',
+        #'uint8' : 'AnyI8',
+        #'uint16' : 'AnyI16',
+        #'uint32' : 'AnyI32',
+        #'uint64' : 'AnyI64',
+        'uint8' : 'UI8',
+        'uint16' : 'UI16',
+        'uint32' : 'UI32',
+        'uint64' : 'UI64',
+        'int8' : 'I8',
+        'int16' : 'I16',
+        'int32' : 'I32',
+        'int64' : 'I64',
+        'float16' : 'F16',
+        'float' : 'F32',
+        'double' : 'F64',
+        'unkown' : 'BF16',
+        'complex64' : 'Complex<F32>',
+        'complex128' : 'Complex<F64>',
+        'string' : 'StringType'}
+
+    for key, item in onnx_to_mlir_type_dict.items():
+        allowedType = allowedType.replace(key, item)
+    return allowedType
+
+def parse_a_type_constraint(constraint):
+    allowedTypes = constraint.allowed_type_strs
+    mlirTypes = []
+    for allowedType in allowedTypes:
+        mlirType = parse_type_str(allowedType)
+        mlirTypes.append(mlirType)
+    # Remove redundancy and sort.
+    # However, onnx keeps a consistently meaningful order.
+    # There is no redundancy as long as each onnx type is mapped uniquely.
+    # mlirTypes = sorted(list(set(mlirTypes)))
+    return mlirTypes
+
+def parse_type_constraints(schema):
+    type_str_dict = dict()
+    for type_constraint in schema.type_constraints:
+        type_str_dict[type_constraint.type_param_str] = parse_a_type_constraint(type_constraint)
+    return type_str_dict
+
+def get_onnx_mlir_types(schema, type_str_dict, input):
+    if input.typeStr :
+        if not input.typeStr in type_str_dict :
+            # Some arguments use the type description directly
+            # instead of a constraint.
+            return [parse_type_str(input.typeStr)]
+        else :
+            return type_str_dict[input.typeStr]
+    else :
+        print('No typeStr ', schema.name)
+        return []
+
 def gen_op_def(schema):
     indent = inc_indent()
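Note: to make the string rewriting above concrete, here is roughly what parse_type_str yields for a few ONNX type strings (worked out by hand from the replacement dict; illustrative only):

# Hand-derived examples for parse_type_str (not generator output):
#   parse_type_str('tensor(float)')    -> 'TensorOf<[F32]>'
#   parse_type_str('tensor(string)')   -> 'TensorOf<[StringType]>'
#   parse_type_str('seq(int64)')       -> 'TensorOf<[I64]>'   # seq maps to tensor for now
#   parse_type_str('map(int64,float)') -> 'TupleOf<[I64,F32]>'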
@@ -727,15 +792,20 @@ def gen_op_def(schema):
             s += indent + '"{}"\n'.format(escaped_line)
     s += indent + '}];\n'
 
+
+    # Handle the type constraints for inputs and outputs:
+    # parse each type constraint into an onnx-mlir type string list.
+    type_str_dict = parse_type_constraints(schema)
+
     # Generate ins (consisting of operands and attributes).
-    ins = get_operands_or_results(schema, is_input=True)
+    ins = get_operands_or_results(schema, type_str_dict, is_input=True)
     ins.update(get_attrs(schema))
     ins_strs = ["{1}:${0}".format(*i) for i in ins.items()]
     s += indent + 'let arguments = (ins {});\n'.format(
         (',\n' + inc_indent(indent)).join(ins_strs))
 
     # Generate outs (operation results).
-    outs = get_operands_or_results(schema, is_input=False)
+    outs = get_operands_or_results(schema, type_str_dict, is_input=False)
     outs_strs = ["{1}:${0}".format(*i) for i in outs.items()]
     s += indent + 'let results = (outs {});\n'.format(
         (',\n' + inc_indent(indent)).join(outs_strs))
@@ -756,7 +826,7 @@ def gen_op_def(schema):
             #   Value, Y, Attribute A", [{}]>
             indent = inc_indent(indent)
             s += indent + 'OpBuilder<"OpBuilder &builder, OperationState &state'
-            operands_dict = get_operands_or_results(schema, is_input=True)
+            operands_dict = get_operands_or_results(schema, type_str_dict, is_input=True)
             for name, ty in operands_dict.items():
                 s += ', {} {}'.format(tblgen_operand_type_to_cpp_type(ty),
                                       name)