diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 78a12a0..ecae7b0 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -157,6 +157,7 @@ private: result.addTypes(mlir::UnrankedTensorType::get(builder_.getF32Type())); } result.addOperands(inputs); + result.addAttributes(ImportNodeAttributes(node)); auto op = builder_.createOperation(result); for (int i = 0; i < node.output().size(); i++) { auto r = op->getResult(i); diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index d2ea7c6..367328b 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/PatternMatch.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" +#include "llvm/Support/FormatVariadic.h" #include "ONNXOps.hpp" @@ -436,6 +437,37 @@ static LogicalResult RNNShapeInference(T *op) { return success(); } +static void insertConvTransposeSpatialDim(SmallVectorImpl &outputDims, + ArrayRef xShape, Optional kernelShape, + Optional padsOpt, Optional stridesOpt, + Optional outputPadsOpt, Optional outputShapeOpt, + Optional dilationsOpt = llvm::None, bool ceilMode = false) { + auto xRank = xShape.size(); + auto spatialRank = ArrayAttrSize(kernelShape); + auto spatialOffset = xRank - spatialRank; + + int64_t dilationVal = 1; + int64_t outputPadsVal = 0; + // output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + + // ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i] + for (int i = 0; i < spatialRank; ++i) { + auto inputSize = xShape[spatialOffset + i]; + auto sumOfPads = + ArrayAttrIntVal(padsOpt, i) + ArrayAttrIntVal(padsOpt, spatialRank + i); + auto kernelSize = ArrayAttrIntVal(kernelShape, i); + if (dilationsOpt.hasValue()) + dilationVal = ArrayAttrIntVal(dilationsOpt, i); + auto strideVal = ArrayAttrIntVal(stridesOpt, i); + if (outputPadsOpt.hasValue()) + outputPadsVal = ArrayAttrIntVal(outputPadsOpt, i); + // Number of useful values: input plus pad - effective size of kernel (see + // processConvTypeParams comments to see how this value is derived). + int64_t res = strideVal * (inputSize - 1) + outputPadsVal + + ((kernelSize - 1) * dilationVal + 1) - sumOfPads; + outputDims.emplace_back(res); + } +} + //===----------------------------------------------------------------------===// // ONNXOpsDialect //===----------------------------------------------------------------------===// @@ -482,6 +514,24 @@ LogicalResult ONNXExpOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Atan +/// Infer the output shape of the ONNXAtanOp. This method is required by the +/// shape inference interface. +LogicalResult ONNXAtanOp::inferShapes() { + getResult().setType(getOperand().getType()); + return success(); +} + +//===----------------------------------------------------------------------===// +// Tan +/// Infer the output shape of the ONNXTanOp. This method is required by the +/// shape inference interface. +LogicalResult ONNXTanOp::inferShapes() { + getResult().setType(getOperand().getType()); + return success(); +} + //===----------------------------------------------------------------------===// // Tanh /// Infer the output shape of the ONNXTanhOp. This method is required by the @@ -491,6 +541,15 @@ LogicalResult ONNXTanhOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Sin +/// Infer the output shape of the ONNXSinOp. This method is required by the +/// shape inference interface. +LogicalResult ONNXSinOp::inferShapes() { + getResult().setType(getOperand().getType()); + return success(); +} + //===----------------------------------------------------------------------===// // Sinh /// Infer the output shape of the ONNXSinhOp. This method is required by the @@ -1316,6 +1375,138 @@ LogicalResult ONNXConvOp::inferShapes() { //===----------------------------------------------------------------------===// +// ConvTranspose + +// For this operation, we define the attributes once in the original Conv +// operation class. There is no need to redefine the attribute names for the +// other classes based on Conv. +// Conv attributes output: +// - auto_pad set to NOTSET; +// - dilations, strides: set to 1 if not defined by user; +// - kernelShape: inferred from weight matrix if not defined by user; +// - pads: set to proper value, 0 if not defined by user. + +LogicalResult ONNXConvTransposeOp::inferShapes() { + // Generic shape for data input X, weight tensor W, and optional bias B + // X: (N x C x D1 x D2 ... x Dn) + // W: (M x C/group x k1 x k2 x ... x kn) + // B: (M) Optional + + bool hasBias = !B().getType().isa(); + + // Cannot infer shape if no shape exists. + if (!X().getType().isa() || + !W().getType().isa() || + (hasBias && !B().getType().isa())) { + return emitError("Input tensor not ranked"); + } + + auto xTy = X().getType().cast(); + auto xShape = xTy.getShape(); + auto weightTy = W().getType().cast(); + auto weightShape = weightTy.getShape(); + auto builder = mlir::Builder(this->getContext()); + + // Lowest supported convolution is a one dimensional convolution. + if (xShape.size() < 3) { + return emitError("Data input shape must be at least (NxCxD1)"); + } + + // Check that shape of weight and data have same length. + if (xShape.size() != weightShape.size()) { + return emitError("Weight size not compatible with data size"); + } + + // Group is a required attribute and should have default value of 1. + int64_t group = ONNXConvTransposeOp::group().getSExtValue(); + + // Check if the attribute actually exists. If it does not then add it. + if (!groupAttr()) + groupAttr(builder.getI64IntegerAttr(group)); + + // Check that the X.shape[1] == (W.shape[0] * group) == C condition holds. + if (xShape[1] != -1 && weightShape[0] != -1 && + xShape[1] != (weightShape[0] * group)) { + return emitError("Channel dimension mismatch"); + } + + // Check the size of bias. + if (hasBias) { + auto bTx = B().getType().cast(); + auto bShape = bTx.getShape(); + if (bShape.size() != 1) { + return emitError("bias should be one dimensional"); + } + if (bShape[0] != weightShape[1]) { + return emitError( + "bias should have same dimensions as weight's second dimension"); + } + } + + // Note: the value of the group attribut only impacts the way the + // computation is carried out and not the actual output size. + + // Number of spatial dimensions. + auto spatialOffset = 2; + int32_t spatialRank = xShape.size() - spatialOffset; + + // Use kernel_shape attribute if present otherwise use size from weight + // argument. + auto kernelShape = kernel_shape(); + if (kernelShape.hasValue()) { + if (ArrayAttrSize(kernelShape) != spatialRank) { + return emitError( + "kernel_shape length incompatible with spatial dimensions"); + } + // Have the right number of values, check them. + for (int i = 0; i < spatialRank; ++i) + if (ArrayAttrIntVal(kernelShape, i) < 1) { + return emitError("bad kernel_shape value"); + } + } else { + // Deduce shape from weight input. + SmallVector defaultVals; + for (int i = 0; i < spatialRank; ++i) + defaultVals.emplace_back(weightShape[spatialOffset + i]); + // Convert to ArrayRef, then build attribute, then store attribute. + ArrayRef defaultRefs(defaultVals); + auto builder = mlir::Builder(getContext()); + kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs)); + kernelShape = kernel_shape(); + } + + // Process strides, dilations, and pads. + processConvTypeParams<>(this, X()); + auto dilationsOpt = dilations(); + auto stridesOpt = strides(); + auto padsOpt = pads(); + auto outputPads = output_padding(); + auto outputShape = output_shape(); + // TODO: handle the spatial dimension computation if output shape is specified + assert(!outputShape.hasValue() && "unhandled option in ConvTranspose"); + + // First two output dimensions consist of the number of batches and the + // number of kernels being applied. + SmallVector outputDims; + // Insert batch size. + outputDims.emplace_back(xShape[0]); + // Insert number of filters being applied (number of output channels). + outputDims.emplace_back(weightShape[1]); + // Compute and insert spatial dims. + insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt, + stridesOpt, outputPads, outputShape, dilationsOpt); + + // Set the output shape if it's not already set + if (!outputShape.hasValue()) { + output_shapeAttr(builder.getI64ArrayAttr(outputDims)); + } + + getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); + return success(); +} + +//===----------------------------------------------------------------------===// + // AveragePool // Infer shape attributes output: // - auto_pad set to NOTSET; @@ -1561,6 +1752,34 @@ LogicalResult ONNXUnsqueezeOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Cast + +LogicalResult ONNXCastOp::inferShapes() { + ShapedType inputType = input().getType().dyn_cast(); + if (!inputType) { + return emitError("Non-shaped input type"); + } + + auto getOutputType = [&inputType](Type elementType) -> Type { + if (inputType.hasRank()) { + return RankedTensorType::get(inputType.getShape(), elementType); + } + return UnrankedTensorType::get(elementType); + }; + + int64_t targetType = toAttr().getInt(); + OpBuilder builder(getContext()); + if (auto elementType = convertONNXTypeToMLIRType( + builder, static_cast(targetType))) { + getResult().setType(getOutputType(elementType)); + } else { + return emitOpError("Unable to get the element type for to = " + + std::to_string(targetType)); + } + return success(); +} + //===----------------------------------------------------------------------===// // Constant @@ -1583,7 +1802,7 @@ LogicalResult ONNXConstantOp::inferShapes() { LogicalResult ONNXConcatOp::inferShapes() { int inputNum = getNumOperands(); for (int i = 0; i < inputNum; ++i) { - if (!getOperand(i).getType().cast()) + if (!getOperand(i).getType().isa()) return emitError("Input tensor(s) not ranked"); } // Checking value of axis parameter. @@ -1713,6 +1932,219 @@ LogicalResult ONNXSplitOp::inferShapes() { return success(); } +//===----------------------------------------------------------------------===// +// Flatten + +LogicalResult ONNXFlattenOp::inferShapes() { + assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now"); + auto inTy = input().getType().dyn_cast(); + if (!inTy) { + return emitOpError("Input is a non-shaped type"); + } + auto outTy = output().getType().dyn_cast(); + if (!outTy) { + return emitOpError("Output is a non-shaped type"); + } + + // TODO(tjingrant): Seems like we can also fairly easily support the case + // where the batch dimension is dynamic + if (!outTy.hasStaticShape()) { + auto inShape = inTy.getShape(); + assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0"); + uint64_t outDim = 1; + for (auto it = inShape.begin() + 1; it < inShape.end(); it++) { + outDim *= *it; + } + + SmallVector dims; + // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html + dims.emplace_back(inShape[0]); + dims.emplace_back(outDim); + getResult().setType(RankedTensorType::get(dims, outTy.getElementType())); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// DynamicQuantizeLinear + +LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() { + auto inTy = x().getType().dyn_cast(); + if (!inTy || !inTy.hasStaticShape()) { + return emitOpError("Input is not a statically-shaped type"); + } + + auto yTy = y().getType().cast(); + auto yScaleTy = y_scale().getType().cast(); + auto yZPTy = y_zero_point().getType().cast(); + + IntegerType i8Type = IntegerType::get(8, getContext()); + RankedTensorType scalarType = RankedTensorType::get({}, i8Type); + + // Set the types for the scalars + if (!yScaleTy.hasStaticShape()) { + y_scale().setType(scalarType); + } + + if (!yZPTy.hasStaticShape()) { + y_zero_point().setType(scalarType); + } + + if (!yTy.hasStaticShape()) { + RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type); + y().setType(outType); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// QuantizeLinear + +LogicalResult ONNXQuantizeLinearOp::inferShapes() { + auto inTy = x().getType().dyn_cast(); + if (!inTy || !inTy.hasStaticShape()) { + return emitOpError("Input is not a statically-shaped type"); + } + + auto yTy = y().getType().cast(); + + if (!yTy.hasStaticShape()) { + // TODO: Unfortunately, we can't tell if this should be signed or unsigned + // here... + IntegerType i8Type = IntegerType::get(8, getContext()); + RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type); + y().setType(outType); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// DequantizeLinear + +LogicalResult ONNXDequantizeLinearOp::inferShapes() { + auto inTy = x().getType().dyn_cast(); + if (!inTy || !inTy.hasStaticShape()) { + return emitOpError("Input is not a statically-shaped type"); + } + + auto yTy = y().getType().cast(); + + if (!yTy.hasStaticShape()) { + FloatType f32 = FloatType::getF32(getContext()); + RankedTensorType outType = RankedTensorType::get(inTy.getShape(), f32); + y().setType(outType); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias) + +LogicalResult ONNXConvIntegerOp::inferShapes() { + // Generic shape for data input X, weight tensor W + // X: (N x C x D1 x D2 ... x Dn) + // W: (M x C/group x k1 x k2 x ... x kn) + + // Cannot infer shape if no shape exists. + if (!x().getType().isa() || + !w().getType().isa()) { + return emitOpError("Input tensor not ranked"); + } + + auto xTy = x().getType().cast(); + if (!xTy.getElementType().isInteger(8)) { + return emitOpError("Invalid input type"); + } + auto xShape = xTy.getShape(); + auto weightTy = w().getType().cast(); + if (!weightTy.getElementType().isInteger(8)) { + return emitOpError("Invalid input type"); + } + auto weightShape = weightTy.getShape(); + auto builder = mlir::Builder(this->getContext()); + + // Lowest supported convolution is a one dimensional convolution. + if (xShape.size() < 3) { + return emitOpError("Data input shape must be at least (NxCxD1)"); + } + + // Check that shape of weight and data have same length. + if (xShape.size() != weightShape.size()) { + return emitError("Weight size not compatible with data size"); + } + + // Group is a required attribute and should have default value of 1. + int64_t group = ONNXConvIntegerOp::group().getSExtValue(); + + // Check if the attribute actually exists. If it does not then add it. + if (!groupAttr()) + groupAttr(builder.getI64IntegerAttr(group)); + + // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. + if (xShape[1] != -1 && weightShape[1] != -1 && + xShape[1] != (weightShape[1] * group)) { + return emitOpError("Channel dimension mismatch"); + } + + // Note: the value of the group attribut only impacts the way the + // computation is carried out and not the actual output size. + + // Number of spatial dimensions. + auto spatialOffset = 2; + int32_t spatialRank = xShape.size() - spatialOffset; + + // Use kernel_shape attribute if present otherwise use size from weight + // argument. + auto kernelShape = kernel_shape(); + if (kernelShape.hasValue()) { + if (ArrayAttrSize(kernelShape) != spatialRank) { + return emitOpError( + "kernel_shape length incompatible with spatial dimensions"); + } + // Have the right number of values, check them. + for (int i = 0; i < spatialRank; ++i) + if (ArrayAttrIntVal(kernelShape, i) < 1) { + return emitError("bad kernel_shape value"); + } + } else { + // Deduce shape from weight input. + SmallVector defaultVals; + for (int i = 0; i < spatialRank; ++i) + defaultVals.emplace_back(weightShape[spatialOffset + i]); + // Convert to ArrayRef, then build attribute, then store attribute. + ArrayRef defaultRefs(defaultVals); + auto builder = mlir::Builder(getContext()); + kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs)); + kernelShape = kernel_shape(); + } + + // Process strides, dilations, and pads. + processConvTypeParams<>(this, x()); + auto dilationsOpt = dilations(); + auto stridesOpt = strides(); + auto padsOpt = pads(); + + // First two output dimensions consist of the number of batches and the + // number of kernels being applied. + SmallVector outputDims; + // Insert batch size. + outputDims.emplace_back(xShape[0]); + // Insert number of filters being applied (number of output channels). + outputDims.emplace_back(weightShape[0]); + // Compute and insert spatial dims. + insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt, + stridesOpt, dilationsOpt); + + // ONNX spec specifies the output type as an int32 + Type outputType = IntegerType::get(32, getContext()); + getResult().setType(RankedTensorType::get(outputDims, outputType)); + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 5ac237b..b082fd8 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -278,7 +278,7 @@ def ONNXAsinhOp:ONNX_Op<"Asinh", } def ONNXAtanOp:ONNX_Op<"Atan", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Atan operation"; let description = [{ "Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise." @@ -449,7 +449,7 @@ def ONNXBitShiftOp:ONNX_Op<"BitShift", } def ONNXCastOp:ONNX_Op<"Cast", - [NoSideEffect, OpInterface<"ResultTypeInferenceOpInterface">]> { + [NoSideEffect, DeclareOpInterfaceMethods, OpInterface<"ResultTypeInferenceOpInterface">]> { let summary = "ONNX Cast operation"; let description = [{ "The operator casts the elements of a given input tensor to a data type" @@ -715,7 +715,7 @@ def ONNXConvOp:ONNX_Op<"Conv", } def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvInteger operation"; let description = [{ "The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point," @@ -746,7 +746,7 @@ def ONNXConvIntegerOp:ONNX_Op<"ConvInteger", } def ONNXConvTransposeOp:ONNX_Op<"ConvTranspose", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX ConvTranspose operation"; let description = [{ "The convolution transpose operator consumes an input tensor and a filter," @@ -924,7 +924,7 @@ def ONNXDepthToSpaceOp:ONNX_Op<"DepthToSpace", } def ONNXDequantizeLinearOp:ONNX_Op<"DequantizeLinear", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX DequantizeLinear operation"; let description = [{ "The linear dequantization operator. It consumes a quantized tensor, a scale, a zero point to compute the full precision tensor." @@ -1053,7 +1053,7 @@ def ONNXDropoutOp:ONNX_Op<"Dropout", } def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX DynamicQuantizeLinear operation"; let description = [{ "A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data." @@ -1285,7 +1285,7 @@ def ONNXEyeLikeOp:ONNX_Op<"EyeLike", } def ONNXFlattenOp:ONNX_Op<"Flatten", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Flatten operation"; let description = [{ "Flattens the input tensor into a 2D matrix. If input tensor has shape" @@ -3327,7 +3327,7 @@ def ONNXQLinearMatMulOp:ONNX_Op<"QLinearMatMul", } def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX QuantizeLinear operation"; let description = [{ "The linear per-tensor/layer quantization operator. It consumes a high precision tensor, a scale, a zero point to compute the low precision / quantized tensor." @@ -4787,7 +4787,7 @@ def ONNXSignOp:ONNX_Op<"Sign", } def ONNXSinOp:ONNX_Op<"Sin", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Sin operation"; let description = [{ "Calculates the sine of the given input tensor, element-wise." @@ -5223,7 +5223,7 @@ def ONNXSumOp:ONNX_Op<"Sum", } def ONNXTanOp:ONNX_Op<"Tan", - [NoSideEffect]> { + [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "ONNX Tan operation"; let description = [{ "Calculates the tangent of the given input tensor, element-wise." diff --git a/src/Dialect/ONNX/ONNXOpsHelper.cpp b/src/Dialect/ONNX/ONNXOpsHelper.cpp index f7f8b9c..5961495 100644 --- a/src/Dialect/ONNX/ONNXOpsHelper.cpp +++ b/src/Dialect/ONNX/ONNXOpsHelper.cpp @@ -44,6 +44,7 @@ AffineMap getConvDimMap(Builder &builder, bool ceilMode) { // Convert type to MLIR type. // A complete list of types can be found in: // /third_party/onnx/onnx/onnx.pb.h +// TODO: Update Int*/Uint* to emit signed/unsigned MLIR types mlir::Type convertONNXTypeToMLIRType( mlir::OpBuilder &builder_, onnx::TensorProto_DataType onnxType) { switch (onnxType) { diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index b156afd..2e67d17 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -37,8 +37,7 @@ public: if (returnsDynamicShape(op)) { if (auto shape_op = dyn_cast(op)) { if (failed(shape_op.inferShapes())) { - op->emitError("unable to infer shape of operation without shape " - "inference method"); + op->emitError("shape inference failed"); return signalPassFailure(); } } else { @@ -79,7 +78,10 @@ public: // shaped outputs. All those operation need to implement the inferShape() // method. if (op->getName().getStringRef() != "onnx.Exp" && + op->getName().getStringRef() != "onnx.Atan" && + op->getName().getStringRef() != "onnx.Tan" && op->getName().getStringRef() != "onnx.Tanh" && + op->getName().getStringRef() != "onnx.Sin" && op->getName().getStringRef() != "onnx.Sinh" && op->getName().getStringRef() != "onnx.Cosh" && op->getName().getStringRef() != "onnx.Cos" && @@ -130,7 +132,14 @@ public: op->getName().getStringRef() != "onnx.RNN" && op->getName().getStringRef() != "onnx.LSTM" && op->getName().getStringRef() != "onnx.GRU" && - op->getName().getStringRef() != "onnx.Unsqueeze") + op->getName().getStringRef() != "onnx.Unsqueeze" && + op->getName().getStringRef() != "onnx.Cast" && + op->getName().getStringRef() != "onnx.ConvTranspose" && + op->getName().getStringRef() != "onnx.Flatten" && + op->getName().getStringRef() != "onnx.DynamicQuantizeLinear" && + op->getName().getStringRef() != "onnx.QuantizeLinear" && + op->getName().getStringRef() != "onnx.DequantizeLinear" && + op->getName().getStringRef() != "onnx.ConvInteger") return false; return llvm::any_of(op->getResultTypes(), [](Type result_type) { return !result_type.isa() && diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index ef2ba33..67a695a 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -589,6 +589,19 @@ func @test_reshape_3(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // ----- +//===----------------------------------------------------------------------===// +/// Test the flatten op inference. +//===----------------------------------------------------------------------===// + +func @test_flatten_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xf32> { + %1 = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_flatten_1 + // CHECK: [[RES:%.+]] = "onnx.Flatten"(%arg0) {axis = 1 : i64} : (tensor<5x2x3x4xf32>) -> tensor<5x24xf32> + // CHECK: return [[RES]] : tensor<5x24xf32> +} + //===----------------------------------------------------------------------===// /// Test the reshape op inference when concat are present. //===----------------------------------------------------------------------===// @@ -872,3 +885,210 @@ func @test_split_3(%arg0 : tensor<16x32x64xf32>) -> tensor<*xf32> { // CHECK: [[RES:%.+]]:2 = "onnx.Split"(%arg0) {axis = 1 : i64, split = [2, 30]} : (tensor<16x32x64xf32>) -> (tensor<16x2x64xf32>, tensor<16x30x64xf32>) // CHECK: return [[RES]]#0 : tensor<16x2x64xf32> } + +//===----------------------------------------------------------------------===// +/// Test the cast op inference. +//===----------------------------------------------------------------------===// + +func @test_cast_1(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf32> { + %1 = "onnx.Cast"(%arg0) {to = 1} : (tensor<2x3x4xf32>) -> tensor<*xf32> + "std.return"(%1) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_cast_1 + // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 1 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + // CHECK: return [[RES]] : tensor<2x3x4xf32> +} + +func @test_cast_2(%arg0 : tensor<2x3x4xf32>) -> tensor<*xui8> { + %1 = "onnx.Cast"(%arg0) {to = 2} : (tensor<2x3x4xf32>) -> tensor<*xui8> + "std.return"(%1) : (tensor<*xui8>) -> () + + // CHECK-LABEL: test_cast_2 + // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 2 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8> + // CHECK: return [[RES]] : tensor<2x3x4xi8> +} + +func @test_cast_3(%arg0 : tensor<2x3x4xf32>) -> tensor<*xsi8> { + %1 = "onnx.Cast"(%arg0) {to = 3} : (tensor<2x3x4xf32>) -> tensor<*xsi8> + "std.return"(%1) : (tensor<*xsi8>) -> () + + // CHECK-LABEL: test_cast_3 + // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 3 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8> + // CHECK: return [[RES]] : tensor<2x3x4xi8> +} + +func @test_cast_10(%arg0 : tensor<2x3x4xf32>) -> tensor<*xf16> { + %1 = "onnx.Cast"(%arg0) {to = 10} : (tensor<2x3x4xf32>) -> tensor<*xf16> + "std.return"(%1) : (tensor<*xf16>) -> () + + // CHECK-LABEL: test_cast_10 + // CHECK: [[RES:%.+]] = "onnx.Cast"(%arg0) {to = 10 : i64} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf16> + // CHECK: return [[RES]] : tensor<2x3x4xf16> +} + +//===----------------------------------------------------------------------===// +/// Test the quantization op inferences. +//===----------------------------------------------------------------------===// + +func @test_dyn_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>) -> tensor<*xi8> { + %1:3 = "onnx.DynamicQuantizeLinear"(%arg0) {} : (tensor<5x2x3x4xf32>) -> (tensor<*xi8>, tensor<*xi8>, tensor<*xi8>) + "std.return"(%1#0) {} : (tensor<*xi8>) -> () + + // CHECK-LABEL: test_dyn_quantize_linear_1 + // CHECK: [[RES:%.+]], {{.*}}, {{.*}} = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<5x2x3x4xf32>) -> (tensor<5x2x3x4xi8>, tensor, tensor) + // CHECK: return [[RES]] : tensor<5x2x3x4xi8> +} + +func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xi8> { + %1 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xf32>, tensor, tensor) -> tensor<*xi8> + "std.return"(%1) {} : (tensor<*xi8>) -> () + + // CHECK-LABEL: test_quantize_linear_1 + // CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xf32>, tensor, tensor) -> tensor<5x2x3x4xi8> + // CHECK: return [[RES]] : tensor<5x2x3x4xi8> +} + +func @test_dequantize_linear_1(%arg0 : tensor<5x2x3x4xi8>, %arg1 : tensor, %arg2 : tensor) -> tensor<*xf32> { + %1 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {} : (tensor<5x2x3x4xi8>, tensor, tensor) -> tensor<*xf32> + "std.return"(%1) {} : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_dequantize_linear_1 + // CHECK: [[RES:%.+]] = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (tensor<5x2x3x4xi8>, tensor, tensor) -> tensor<5x2x3x4xf32> + // CHECK: return [[RES]] : tensor<5x2x3x4xf32> +} + +//===----------------------------------------------------------------------===// +/// Test shape inference for ConvInteger operation and all its attributes. +//===----------------------------------------------------------------------===// + +/// Default and required attributes for 1-D convolution. + +func @test_convinteger_0(%arg0 : tensor<1x2x32xi8>, %arg1 : tensor<5x2x6xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_0 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xi8>, tensor<5x2x6xi8>, tensor, tensor) -> tensor<1x5x27xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x27xi32> +} + +/// Default and required attributes. + +func @test_convinteger_1(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_1 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x27x58xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xi32> +} + +/// kernel_shape attribute. + +func @test_convinteger_2(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_2 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x25x56xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xi32> +} + +/// pads attribute. +/// Use pads to make output size equal to input size by adding K - 1 to the result. + +func @test_convinteger_3(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_3 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<1x5x32x64xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> +} + +/// auto_pad set to SAME_UPPER and SAME_LOWER. + +func @test_convinteger_4(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_4 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<1x5x32x64xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> +} + +func @test_convinteger_5(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_5 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<1x5x32x64xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> +} + +/// auto_pad set to VALID. + +func @test_convinteger_6(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x10xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_6 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x10xi8>, tensor, tensor) -> tensor<1x5x27x55xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xi32> +} + +/// With strides attribute. + +func @test_convinteger_7(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_7 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x14x20xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xi32> +} + +/// auto_pad set to SAME_UPPER with strides attribute. +/// The auto_pad will pas as if stride is equal to 1. + +func @test_convinteger_8(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_8 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x16x22xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xi32> +} + +/// dilations attribute. + +func @test_convinteger_9(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_9 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x22x46xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xi32> +} + +/// dilations attribute with stride. + +func @test_convinteger_10(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_10 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x11x23xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xi32> +} + +/// dilations attribute with auto_pad set to SAME_UPPER. + +func @test_convinteger_11(%arg0 : tensor<1x2x32x64xi8>, %arg1 : tensor<5x2x6x7xi8>, %arg2 : tensor, %arg3 : tensor) -> tensor<*xi32> { + %0 = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<*xi32> + "std.return"(%0) : (tensor<*xi32>) -> () + + // CHECK-LABEL: test_convinteger_11 + // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvInteger"(%arg0, %arg1, %arg2, %arg3) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xi8>, tensor<5x2x6x7xi8>, tensor, tensor) -> tensor<1x5x32x64xi32> + // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xi32> +} diff --git a/utils/CMakeLists.txt b/utils/CMakeLists.txt index 5a3a85f..9a74721 100644 --- a/utils/CMakeLists.txt +++ b/utils/CMakeLists.txt @@ -52,6 +52,6 @@ add_custom_target(OMONNXCheckVersion COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --check-operation-version) add_custom_target(OMMLONNXCheckVersion - COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py + COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/gen_onnx_mlir.py --check-operation-version --domain="ONNX_ML") - + diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 6c3b810..5e30d25 100644 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -36,7 +36,7 @@ parser.add_argument("--check-operation-version", " newer version of operation compared with version stored in version_dicts", action="store_true", default=False) -parser.add_argument("--domain", +parser.add_argument("--domain", help="specify domain, ONNX or ONNX_ML", default = "ONNX") @@ -249,13 +249,14 @@ special_op_handler = dict([ # Operations supporting shape inference. OpsWithShapeInference = [ - 'Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', 'Add', 'Mul', 'Div', - 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', 'Gemm', - 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', + 'Exp', 'Atan', 'Tan', 'Tanh', 'Sin', 'Sinh', 'Cosh', 'Sigmoid', 'Relu', + 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor', 'Sum', 'Max', 'Min', 'MatMul', + 'Gemm', 'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', 'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', 'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv', 'Concat', 'Neg', 'RNN', - 'LSTM', 'GRU', 'Split', 'Pad' + 'LSTM', 'GRU', 'Split', 'Pad', 'Cast', 'ConvTranspose', 'Flatten', + 'DynamicQuantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'ConvInteger', ] # Operations supporting canonicalization. @@ -286,7 +287,7 @@ OpsWithResultTypeInference = { resultTypes.push_back(mlir::UnrankedTensorType::get( convertONNXTypeToMLIRType(builder, static_cast(toAttr))));''' } - + # Add an Op in this list if the Op needs result type deduction which is required # when writing declarative rewriting rules. Deduced type is always # an UnrankedTensorType whose element type is the same as the first operand's @@ -306,7 +307,7 @@ custom_builder_broadcast_ops_list = ['Add', 'And', 'Div', 'Equal', 'Greater', custom_builder_ops_list = custom_builder_unranked_ops_list + custom_builder_broadcast_ops_list #a dictionary to add any special definition for an operation -custom_definition_misc = dict([ ('Constant', +custom_definition_misc = dict([ ('Constant', ''' let builders = [ OpBuilder<"OpBuilder &builder, OperationState &state, Attribute sparse_value, Attribute value", [{ if (value) { @@ -325,7 +326,7 @@ onnx_types = ( 'bool', 'int8', 'int16', 'int32', 'int64', 'unkown', 'float16', 'float', 'double', 'complex64', 'complex128' ) -tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64', +tblgen_types = ('I1', 'I8', 'I16', 'I32', 'I64', 'BF16', 'F16', 'F32', 'F64', 'Complex', 'Complex' ) @@ -424,7 +425,7 @@ def get_tblgen_type_index(type_str): #the possible data structures are tensor, map and seq(tensor()) #TOFIX: currently, only tensor structure is supported -def get_data_structure_element(allowed_type_str): +def get_data_structure_element(allowed_type_str): if allowed_type_str.startswith('tensor') : element = allowed_type_str.replace('tensor(', '', 1).replace(')', '', 1) return ('tensor', element) @@ -454,9 +455,9 @@ def get_allowed_elem_types(schema, input): return None if not t in allowed_type_list : allowed_tyoe_list = allowed_type_list.append(t) - + return allowed_type_list - + return None @@ -610,9 +611,9 @@ def get_output_type_mapping(schema): #unknown output type mapping.append(str(-1)) - + return mapping - + def get_numberof_inout(s, indent, schema): expected_num_operands = get_numberof_list(schema.inputs) indent = inc_indent(indent) @@ -672,8 +673,8 @@ def get_type_inference_func(s, indent, type_inference_code): indent = dec_indent(indent) return s - - + + def gen_op_def(schema): indent = inc_indent() @@ -810,7 +811,7 @@ def gen_op_def(schema): # generate input/output number s = get_numberof_inout(s, indent, schema) - # generate ProtableConst + # generate ProtableConst if schema.name in OpsWithPromotableConstOperands: s = get_promotable_const_operands_func( s, indent, OpsWithPromotableConstOperands[schema.name]) @@ -914,7 +915,7 @@ def build_operator_schemas(): schema.since_version, schema.name)) elif schema.since_version > version_dict[schema.name]: print("Check-operation-version: Operation {} has a newer version {}"+ - "(old version {})".format( schema.name, + "(old version {})".format( schema.name, schema.since_version, version_dict[schema.name])) else: # Generate operation according to the version in version_dict.