//===------------------ ONNXOps.cpp - ONNX Operations ---------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
//
// =============================================================================
//
// This file provides definition of ONNX dialect operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/FormatVariadic.h"

#include "ONNXOps.hpp"

using namespace mlir;
using namespace mlir::OpTrait::util;
using namespace mlir::onnxmlir;

//===----------------------------------------------------------------------===//
// ONNX Helper functions
//===----------------------------------------------------------------------===//

static size_t ArrayAttrSize(ArrayAttr a) { return a.size(); }

static size_t ArrayAttrSize(Optional<ArrayAttr> a) {
  return a.getValue().size();
}

static int64_t ArrayAttrIntVal(ArrayAttr a, int i) {
  return (a.getValue()[i]).cast<IntegerAttr>().getInt();
}

static int64_t ArrayAttrIntVal(Optional<ArrayAttr> a, int i) {
  return (a.getValue().getValue()[i]).cast<IntegerAttr>().getInt();
}

// Returns the ConstantOp which defines an MLIR Value or null.
static mlir::ONNXConstantOp getONNXConstantOp(Value value) {
  return dyn_cast_or_null<mlir::ONNXConstantOp>(value.getDefiningOp());
}

// This method substitutes any uses of dimensions and symbols (e.g.
// dim#0 with dimReplacements[0]) in an affine map, simplifies the modified
// affine map, and returns an integer constant.
int64_t AffineMapIntConstant(Builder &builder, AffineMap map,
    ArrayRef<int64_t> dimReplacements, ArrayRef<int64_t> symReplacements,
    unsigned numResultDims, unsigned numResultSyms) {
  // Prepare affine expressions.
  SmallVector<AffineExpr, 4> dimExprs, symExprs;
  for (int64_t dim : dimReplacements) {
    AffineExpr exp = builder.getAffineConstantExpr(dim);
    dimExprs.emplace_back(exp);
  }
  for (int64_t sym : symReplacements) {
    AffineExpr exp = builder.getAffineConstantExpr(sym);
    symExprs.emplace_back(exp);
  }
  // Replace all the affine map's arguments with real values and evaluate the
  // map.
  AffineMap replacedDimMap = map.replaceDimsAndSymbols(
      dimExprs, symExprs, numResultDims, numResultSyms);
  AffineMap simplifiedMap = simplifyAffineMap(replacedDimMap);
  return simplifiedMap.getSingleConstantResult();
}

//===----------------------------------------------------------------------===//
// Get reduction type
//===----------------------------------------------------------------------===//
RankedTensorType getReductionOutputType(RankedTensorType operandTy,
    Optional<ArrayAttr> axesAttrs, APInt keepdims) {
  int64_t rank = operandTy.getRank();

  SmallVector<int64_t, 4> axes;
  if (axesAttrs != llvm::None) {
    for (auto axisAttr : axesAttrs.getValue()) {
      int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (rank + axis);
      assert(axis >= -rank && axis <= rank - 1);
      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
        axes.emplace_back(axis);
    }
  } else {
    for (decltype(rank) i = 0; i < rank; ++i) {
      axes.emplace_back(i);
    }
  }

  // Mark reduction axes.
  SmallVector<bool, 4> isReductionAxis;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
      isReductionAxis.emplace_back(true);
    else
      isReductionAxis.emplace_back(false);
  }

  // KeepDims
  bool isKeepdims = (keepdims == 1) ? true : false;

  SmallVector<int64_t, 4> dims;
  for (decltype(rank) i = 0; i < rank; ++i) {
    if (isReductionAxis[i]) {
      if (isKeepdims)
        dims.emplace_back(1); // reduction dimension
    } else {
      dims.emplace_back(operandTy.getShape()[i]);
    }
  }

  return RankedTensorType::get(dims, operandTy.getElementType());
}

//===----------------------------------------------------------------------===//
// Support function that computes default values for dilations.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvDilationParam(
    T *op, Optional<ArrayAttr> kernelShape) {
  auto builder = mlir::Builder(op->getContext());
  auto kernelRank = ArrayAttrSize(kernelShape);

  auto dilationsOpt = op->dilations();
  if (dilationsOpt.hasValue()) {
    if (ArrayAttrSize(dilationsOpt) != kernelRank) {
      return op->emitError(
          "dilation rank is not the same as the spatial rank");
    }
    // Test values to be greater than 0.
    for (int i = 0; i < kernelRank; ++i) {
      if (ArrayAttrIntVal(dilationsOpt, i) < 1) {
        return op->emitError("dilation value must be nonzero positive");
      }
    }
  } else {
    // Default dilation is needed, all dimensions init with 1.
    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Support function that computes default values for strides.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvStrideParam(
    T *op, Optional<ArrayAttr> kernelShape) {
  auto builder = mlir::Builder(op->getContext());
  auto kernelRank = ArrayAttrSize(kernelShape);

  auto stridesOpt = op->strides();
  if (stridesOpt.hasValue()) {
    if (ArrayAttrSize(stridesOpt) != kernelRank)
      return op->emitError("strides rank is not the same as the spatial rank");
    // Check values to be greater than 0.
    for (int i = 0; i < kernelRank; ++i) {
      if (ArrayAttrIntVal(stridesOpt, i) < 1)
        return op->emitError("strides value must be nonzero positive");
    }
  } else {
    // Default stride is needed, all dimensions init with 1.
    SmallVector<int64_t, 4> defaultVals(kernelRank, 1);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    op->stridesAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Support function that computes default values for pads.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
    Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None) {
  auto builder = mlir::Builder(op->getContext());

  auto inputRank = inputShape.size();
  auto kernelRank = ArrayAttrSize(kernelShape);
  auto kernelOffset = inputRank - kernelRank;

  // Try to find padding, getting auto_pad attribute first.
  auto autoPad = op->auto_pad();
  // And then investigate the various different cases. Prefill pad values with
  // zeros, the most common case.
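  // Worked example of the SAME_UPPER / SAME_LOWER pad formula used in the
  // branch below (hypothetical values, for illustration only): with
  // inputSize = 5, kernelSize = 3, strideVal = 2 and dilationVal = 1, we get
  // outputSize = ceil(5 / 2) = 3 and
  // sumOfPad = (3 - 1) * 2 + ((3 - 1) * 1 + 1) - 5 = 2, i.e. one pad value on
  // each side of that spatial axis.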
  SmallVector<int64_t, 4> actualPads(2 * kernelRank, 0);
  bool updatedPad = false;
  if (autoPad == "NOTSET") {
    auto padsOpt = op->pads();
    if (padsOpt.hasValue()) {
      // Only option where pads are not updated. Pads consist of two entries
      // for each spatial axis.
      if (ArrayAttrSize(padsOpt) != 2 * kernelRank) {
        return op->emitError("pads rank is not twice the spatial rank");
      }
      // Check values, pads cannot be negative.
      for (int i = 0; i < 2 * kernelRank; ++i) {
        if (ArrayAttrIntVal(padsOpt, i) < 0) {
          return op->emitError("pads value must be nonnegative");
        }
      }
    } else {
      // We have NOTSET with no pads, they are assumed to be all zero.
      updatedPad = true;
    }
  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
    // Reload dilations and strides as they may have gotten default values.
    updatedPad = true;
    int64_t dilationVal = 1;
    for (int i = 0; i < kernelRank; ++i) {
      auto inputSize = inputShape[kernelOffset + i];
      auto kernelSize = ArrayAttrIntVal(kernelShape, i);
      if (dilationsOpt.hasValue())
        dilationVal = ArrayAttrIntVal(dilationsOpt, i);
      auto strideVal = ArrayAttrIntVal(stridesOpt, i);
      // Output size is input size divided by stride. When stride is 1, then
      // input and output are the same size, which is the usual case. When
      // stride is greater than 1, take the ceil to be sure to have each input
      // value used, as padding will be used to fill the gaps.
      int64_t outputSize = ceil((1.0 * inputSize) / (1.0 * strideVal));
      // The formula is from ONNX MaxPool, and can be explained as follows.
      // The pad is the difference between the number of values needed for the
      // computation and the number of input values. The number of values
      // needed for the computation is the effective size of the kernel plus
      // the distance covered by jumping to the next kernel position. The
      // number of jumps is (outputSize - 1), each of size strideVal. The
      // effective kernel size is kernelSize plus the dilation holes, namely
      // (kernelSize - 1) * (dilationVal - 1), which simplifies to
      // "(kernelSize - 1) * dilationVal + 1".
      auto sumOfPad = (outputSize - 1) * strideVal +
                      ((kernelSize - 1) * dilationVal + 1) - inputSize;
      // Pad values are assumed equal on both sides, at half the total value.
      actualPads[i] = actualPads[kernelRank + i] = sumOfPad / 2;
      // But if the total pad value is odd, we add 1 to the beginning or the
      // end depending on the autoPad value.
      if (sumOfPad % 2 != 0) {
        if (autoPad == "SAME_UPPER") {
          actualPads[kernelRank + i] += 1;
        } else {
          actualPads[i] += 1;
        }
      }
    }
  } else if (autoPad == "VALID") {
    // No pad, default value was set to zero, we are all set.
    updatedPad = true;
  } else {
    return op->emitError("auto_pad of unknown / unsupported value");
  }
  // Set pads values in attributes, if needed.
  if (updatedPad) {
    ArrayRef<int64_t> defaultRefs(actualPads);
    op->padsAttr(builder.getI64ArrayAttr(defaultRefs));
  }
  // In all cases now, the actual pad values are found in the pads attribute.
  op->auto_padAttr(builder.getStringAttr("NOTSET"));
  return success();
}

//===----------------------------------------------------------------------===//
// Support function computing default values for dilations, strides, and pads.
//===----------------------------------------------------------------------===//
template <class T>
static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
  auto builder = mlir::Builder(op->getContext());

  // 1) Get shape of input.
  auto inputShape = inputOperand.getType().cast<RankedTensorType>().getShape();
  auto inputRank = inputShape.size();

  // 2) Get kernel_shape attribute.
  auto kernelShape = op->kernel_shape();

  // Dilation.
  LogicalResult res = processConvDilationParam(op, kernelShape);
  if (failed(res))
    return res;
  auto dilationsOpt = op->dilations();

  // Strides.
  res = processConvStrideParam(op, kernelShape);
  if (failed(res))
    return res;
  auto stridesOpt = op->strides();

  // Pads.
  return processConvPadParam(
      op, inputShape, kernelShape, stridesOpt, dilationsOpt);
}

//===----------------------------------------------------------------------===//
// Compute spatial dimensions given dilations, strides, pads, and ceil mode.
//===----------------------------------------------------------------------===//
static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
    Builder &builder, ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
    Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
  auto spatialRank = ArrayAttrSize(kernelShape);
  auto spatialOffset = xShape.size() - spatialRank;

  // Get an affine map to compute the output dimension.
  AffineMap dimMap = getConvDimMap(builder, ceilMode);
  for (int i = 0; i < spatialRank; ++i) {
    int64_t res = -1;
    if (xShape[spatialOffset + i] != -1) {
      auto inputSize = xShape[spatialOffset + i];
      auto kernelSize = ArrayAttrIntVal(kernelShape, i);
      auto sumOfPads = ArrayAttrIntVal(padsOpt, i) +
                       ArrayAttrIntVal(padsOpt, spatialRank + i);
      auto strideVal = ArrayAttrIntVal(stridesOpt, i);
      int64_t dilationVal = 1;
      if (dilationsOpt.hasValue())
        dilationVal = ArrayAttrIntVal(dilationsOpt, i);
      res = AffineMapIntConstant(builder, dimMap, {inputSize},
          {kernelSize, sumOfPads, strideVal, dilationVal}, 1, 4);
    }
    outputDims->emplace_back(res);
  }
}

//===----------------------------------------------------------------------===//
// Support function that infers shape for RNN operations.
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult RNNShapeInference(T *op) {
  Value X = op->X();
  Value W = op->W();
  Value R = op->R();

  if (!X.getType().isa<RankedTensorType>() ||
      !W.getType().isa<RankedTensorType>() ||
      !R.getType().isa<RankedTensorType>()) {
    return op->emitError("Input tensor not ranked");
  }

  auto xTy = X.getType().cast<RankedTensorType>();
  auto elementType = xTy.getElementType();

  // xShape :: [seq_length, batch_size, input_size]
  auto xShape = xTy.getShape();
  // wShape :: [num_directions, 4*hidden_size, input_size]
  auto wShape = W.getType().cast<RankedTensorType>().getShape();
  // rShape :: [num_directions, 4*hidden_size, hidden_size]
  auto rShape = R.getType().cast<RankedTensorType>().getShape();

  if (xShape.size() != 3) {
    return op->emitError("The first input tensor must have rank 3");
  }
  if (wShape.size() != 3) {
    return op->emitError("The second input tensor must have rank 3");
  }
  if (rShape.size() != 3) {
    return op->emitError("The third input tensor must have rank 3");
  }

  // Get sequence length, batch size and input size.
  auto sequenceLength = xShape[0];
  auto batchSize = xShape[1];
  auto inputSize = xShape[2];

  // Get hidden size from hidden_size attribute.
  int64_t hiddenSize = -1;
  if (op->hidden_size().hasValue()) {
    hiddenSize = op->hidden_size().getValue().getSExtValue();
  } else {
    // Infer hidden_size from wShape and rShape if possible.
    if (rShape[2] != -1)
      hiddenSize = rShape[2];
    else if (rShape[1] != -1)
      hiddenSize = rShape[1] / 4;
    else if (wShape[1] != -1)
      hiddenSize = wShape[1] / 4;
    // Update hidden_size attribute.
    if (hiddenSize != -1) {
      auto builder = mlir::Builder(op->getContext());
      op->hidden_sizeAttr(builder.getI64IntegerAttr(hiddenSize));
    }
  }

  // Get direction.
  int numDirection;
  if ((op->direction() == "forward") || (op->direction() == "reverse"))
    numDirection = 1;
  else if (op->direction() == "bidirectional")
    numDirection = 2;
  else
    numDirection = -1;
  if (numDirection == -1) {
    return op->emitError(
        "direction attribute must be one of the strings: forward, "
        "reverse, and bidirectional");
  }

  // Set result types.
  unsigned numOfResults = op->getNumResults();
  if (numOfResults > 0) {
    // Y :: [seq_length, num_directions, batch_size, hidden_size]
    Type yTy = op->getResults()[0].getType();
    if (!yTy.isa<NoneType>()) {
      yTy = RankedTensorType::get(
          {sequenceLength, numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[0].setType(yTy);
    }
  }
  if (numOfResults > 1) {
    // Y_h :: [num_directions, batch_size, hidden_size]
    Type yhTy = op->getResults()[1].getType();
    if (!yhTy.isa<NoneType>()) {
      yhTy = RankedTensorType::get(
          {numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[1].setType(yhTy);
    }
  }
  if (numOfResults > 2) {
    // Y_c :: [num_directions, batch_size, hidden_size]
    Type ycTy = op->getResults()[2].getType();
    if (!ycTy.isa<NoneType>()) {
      ycTy = RankedTensorType::get(
          {numDirection, batchSize, hiddenSize}, elementType);
      op->getResults()[2].setType(ycTy);
    }
  }
  return success();
}

static void insertConvTransposeSpatialDim(SmallVectorImpl<int64_t> &outputDims,
    ArrayRef<int64_t> xShape, Optional<ArrayAttr> kernelShape,
    Optional<ArrayAttr> padsOpt, Optional<ArrayAttr> stridesOpt,
    Optional<ArrayAttr> outputPadsOpt, Optional<ArrayAttr> outputShapeOpt,
    Optional<ArrayAttr> dilationsOpt = llvm::None, bool ceilMode = false) {
  auto xRank = xShape.size();
  auto spatialRank = ArrayAttrSize(kernelShape);
  auto spatialOffset = xRank - spatialRank;

  int64_t dilationVal = 1;
  int64_t outputPadsVal = 0;
  // output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] +
  //   ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i]
  for (int i = 0; i < spatialRank; ++i) {
    auto inputSize = xShape[spatialOffset + i];
    auto sumOfPads = ArrayAttrIntVal(padsOpt, i) +
                     ArrayAttrIntVal(padsOpt, spatialRank + i);
    auto kernelSize = ArrayAttrIntVal(kernelShape, i);
    if (dilationsOpt.hasValue())
      dilationVal = ArrayAttrIntVal(dilationsOpt, i);
    auto strideVal = ArrayAttrIntVal(stridesOpt, i);
    if (outputPadsOpt.hasValue())
      outputPadsVal = ArrayAttrIntVal(outputPadsOpt, i);
    // Number of useful values: input plus pad - effective size of kernel (see
    // processConvTypeParams comments to see how this value is derived).
    int64_t res = strideVal * (inputSize - 1) + outputPadsVal +
                  ((kernelSize - 1) * dilationVal + 1) - sumOfPads;
    outputDims.emplace_back(res);
  }
}

//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addOperations<
#define GET_OP_LIST
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
      >();
  addTypes<StringType>();
}

mlir::Type ONNXOpsDialect::parseType(mlir::DialectAsmParser &parser) const {
  if (parser.parseKeyword("String"))
    return Type();
  return StringType::get(getContext());
}

void ONNXOpsDialect::printType(
    mlir::Type type, mlir::DialectAsmPrinter &printer) const {
  printer << "String";
}

void ONNXEntryPointOp::build(mlir::OpBuilder &builder,
    mlir::OperationState &state, mlir::FuncOp function, int numInputs,
    int numOutputs) {
  state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
      builder.getSymbolRefAttr(function));
  state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
      builder.getI32IntegerAttr(numInputs));
  state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
      builder.getI32IntegerAttr(numOutputs));
}

ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
    mlir::FuncOp &func, int numInputs, int numOutputs) {
  mlir::OperationState state(location, "onnx.EntryPoint");
  OpBuilder builder(location->getContext());
  mlir::ONNXEntryPointOp::build(builder, state, func, numInputs, numOutputs);
  Operation *op = mlir::Operation::create(state);
  auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
  return onnxEntryOp;
}

//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXExpOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Atan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAtanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAtanOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Tan
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXTanOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXTanhOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sin
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sinh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSinhOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Cosh
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCoshOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Cos
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXCosOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Log
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXLogOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// HardSigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXHardSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sigmoid
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Elu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXEluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Relu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXReluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// LeakyRelu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXLeakyReluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Selu
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSeluOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Reciprocal
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXReciprocalOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softmax
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftmaxOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softplus
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftplusOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Softsign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSoftsignOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sqrt
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSqrtOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Sign
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXSignOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Abs
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAbsOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Add
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAddOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Mul
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMulOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Div
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXDivOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Sub
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSubOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// And
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXAndOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Or
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXOrOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Xor
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXXorOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
  return success();
}

//===----------------------------------------------------------------------===//
// Sum
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXSumOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Max
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMaxOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Min
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXMinOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
  return success();
}

//===----------------------------------------------------------------------===//
// Neg
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXNegOp. This method is required by the
/// shape inference interface.
LogicalResult ONNXNegOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Identity
//===----------------------------------------------------------------------===//
/// Infer the output shape of the ONNXIdentityOp. This method is required by
/// the shape inference interface.
LogicalResult ONNXIdentityOp::inferShapes() {
  getResult().setType(getOperand().getType());
  return success();
}

//===----------------------------------------------------------------------===//
// MatMul
//===----------------------------------------------------------------------===//

LogicalResult ONNXMatMulOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!A().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");

  auto lhsTy = A().getType().cast<RankedTensorType>();
  auto rhsTy = B().getType().cast<RankedTensorType>();

  SmallVector<int64_t, 2> dims;
  auto lhsShape = lhsTy.getShape();
  auto rhsShape = rhsTy.getShape();

  if (lhsShape.size() < 1 && rhsShape.size() < 1) {
    // Multiplication by scalars is not allowed.
    return emitError("Multiplication by scalar arguments not allowed");
  } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
    // Special case when both arrays are 1-dimensional and according to
    // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
    // need to be removed after the multiplication but cannot be removed if
    // all sizes are 1.
    if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");
    dims.emplace_back(1);
  } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
    // If the first argument is 1-D, it is promoted to a matrix by prepending
    // a 1 to its dimensions. After matrix multiplication the prepended 1 is
    // removed.
    //
    // N MATMUL (s1 x s2 x... x sK x N x P)
    // =>
    // (s1 x s2 x... x sK x P)

    // Check legality of matrix multiplication.
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[0] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");

    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      dims.emplace_back(rhsShape[i]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else if (lhsShape.size() >= 2 && rhsShape.size() == 1) {
    // If the second argument is 1-D, it is promoted to a matrix by appending
    // a 1 to its dimensions. After matrix multiplication the appended 1 is
    // removed.
    //
    // (s1 x s2 x... x sK x M x N) MATMUL N
    // =>
    // (s1 x s2 x... x sK x M)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");

    for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
      dims.emplace_back(lhsShape[i]);
    dims.emplace_back(lhsShape[lhsRank - 2]);
  } else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
    // (s1 x s2 x... x sK x M x N) MATMUL (N x P)
    // =>
    // (s1 x s2 x... x sK x M x P)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[0])
      return emitError("Attempt to multiply incompatible matrices");

    for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
      dims.emplace_back(lhsShape[i]);
    dims.emplace_back(rhsShape[1]);
  } else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
    // (M x N) MATMUL (s1 x s2 x... x sK x N x P)
    // =>
    // (s1 x s2 x... x sK x M x P)

    // Check legality of matrix multiplication.
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[1] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");

    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      dims.emplace_back(rhsShape[i]);
    dims.emplace_back(lhsShape[0]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
    // (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
    // =>
    // (u1 x u2 x... x uK x M x P)

    // Check legality of matrix multiplication.
    unsigned lhsRank = lhsShape.size();
    unsigned rhsRank = rhsShape.size();
    if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
        lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
      return emitError("Attempt to multiply incompatible matrices");

    // Check and perform broadcasting for the shapes.
    SmallVector<int64_t, 4> lhsBcastShape;
    for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
      lhsBcastShape.emplace_back(lhsShape[i]);
    SmallVector<int64_t, 4> rhsBcastShape;
    for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
      rhsBcastShape.emplace_back(rhsShape[i]);
    if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
      return emitError("Broadcasted dimensions are incompatible");

    dims.emplace_back(lhsShape[lhsRank - 2]);
    dims.emplace_back(rhsShape[rhsRank - 1]);
  } else {
    // This case covers all remaining combinations of 1 and 2-D matrices.
    int64_t lhsDim = lhsShape[0];
    int64_t rhsDim = rhsShape[0];
    if (lhsShape.size() > 1) {
      lhsDim = lhsShape[1];
      dims.emplace_back(lhsShape[0]);
    }

    // Check legality of matrix multiplication.
    if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
      return emitError("Attempt to multiply incompatible matrices");

    if (rhsShape.size() > 1)
      dims.emplace_back(rhsShape[1]);
  }

  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
  return success();
}

// Gemm
LogicalResult ONNXGemmOp::inferShapes() {
  bool hasBias = !C().getType().isa<NoneType>();
  // Cannot infer shape if no shape exists.
  if (!A().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>() ||
      (hasBias && !C().getType().isa<RankedTensorType>()))
    return emitError("Input tensor(s) not ranked");

  auto lhsTy = A().getType().cast<RankedTensorType>();
  auto rhsTy = B().getType().cast<RankedTensorType>();

  int64_t M, N, K_A, K_B;
  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];

  if ((K_A != -1) && (K_B != -1) && (K_A != K_B))
    return emitError("Tensor shapes mismatched");

  if (hasBias) {
    // Check whether bias is unidirectional broadcasting or not.
    auto biasTy = C().getType().cast<RankedTensorType>();
    auto shape = biasTy.getShape();
    int rank = shape.size();
    if ((rank > 2) ||
        (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
            N != shape[rank - 1] && shape[rank - 1] != 1) ||
        (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
            M != shape[rank - 2] && shape[rank - 2] != 1))
      return emitError("Bias shape mismatched");
  }

  SmallVector<int64_t, 2> dims;
  dims.emplace_back(M);
  dims.emplace_back(N);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
  return success();
}

/// BatchNormalizationTestMode
LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !scale().getType().isa<RankedTensorType>() ||
      !B().getType().isa<RankedTensorType>() ||
      !mean().getType().isa<RankedTensorType>() ||
      !var().getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");

  auto inputTensorTy = X().getType().cast<RankedTensorType>();
  auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
  auto biasTensorTy = B().getType().cast<RankedTensorType>();
  auto meanTensorTy = mean().getType().cast<RankedTensorType>();
  auto varianceTensorTy = var().getType().cast<RankedTensorType>();

  // Check whether the shapes of scale, bias, mean and variance are valid.
  // Operand's dimensions can be in the form of NxCxD1xD2x...xDn or N.
  // In case of N, C is assumed to be 1.
  // Shapes of scale, bias, mean and variance must be C.
  int64_t c = -1;
  if (inputTensorTy.getShape().size() == 1) {
    c = 1;
  } else if (inputTensorTy.getShape().size() > 2) {
    c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
  } else {
    return emitError("Wrong rank for the input");
  }

  if (c != -1) {
    auto s = scaleTensorTy.getShape();
    auto b = biasTensorTy.getShape();
    auto m = meanTensorTy.getShape();
    auto v = varianceTensorTy.getShape();

    if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
      return emitError("Wrong rank for the scale");
    if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
      return emitError("Wrong rank for the bias");
    if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
      return emitError("Wrong rank for the mean");
    if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
      return emitError("Wrong rank for the variance");
  }

  // The output tensor of the same shape as the input.
  getResult().setType(X().getType());
  return success();
}

// TODO:
//   Verify that matrix sizes are valid for multiplication and addition.
//   Take into account the dimensionality of the matrix.

//===----------------------------------------------------------------------===//
// Reshape
//===----------------------------------------------------------------------===//

LogicalResult ONNXReshapeOp::inferShapes() {
  // Cannot infer shape if no shape tensor is specified.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input data tensor not ranked");

  if (!shape().getType().isa<RankedTensorType>())
    return emitError("Shape tensor not ranked");

  auto inputTensorTy = data().getType().cast<RankedTensorType>();
  auto shapeTensorTy = shape().getType().cast<RankedTensorType>();

  // Only rank 1 shape tensors are supported.
  if (shapeTensorTy.getShape().size() != 1)
    return emitError("Shape tensor must have rank one");

  int64_t outputRank = shapeTensorTy.getShape()[0];

  // Shape tensor must have constant shape.
  if (outputRank < 0)
    return emitError("Shape tensor must have constant shape");

  // Compute total number of elements.
  int64_t totalInputSize = 1;
  for (auto inputDim : inputTensorTy.getShape())
    totalInputSize *= inputDim;

  // Check if second argument of ReshapeOp is a constant.
  auto constantOp = getONNXConstantOp(shape());

  SmallVector<int64_t, 4> dims(outputRank, -1);
  if (constantOp) {
    DenseElementsAttr valueAttribute =
        constantOp.valueAttr().dyn_cast<DenseElementsAttr>();

    if (!valueAttribute)
      return emitError("DenseElementsAttr expected");

    // Get dims from valueAttribute.
    auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
    for (int i = 0; i < outputRank; ++i)
      dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();

    if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
      return emitError("Constant value must have same rank as output");

    int64_t numberOfDynamicInputs = 0;
    int64_t totalKnownDimsSize = 1;
    int64_t dynamicValueIndex = -1;
    for (int i = 0; i < outputRank; ++i) {
      // Set output dimension.
      if (dims[i] == 0)
        dims[i] = inputTensorTy.getShape()[i];

      if (dims[i] < 0) {
        numberOfDynamicInputs++;
        dynamicValueIndex = i;
      } else {
        totalKnownDimsSize *= dims[i];
      }
    }

    // If the number of dynamic inputs is 1 then deduce the missing value
    // based on the total input size. The total input size must be greater
    // than 0 i.e. all constant dimensions.
    // TODO: Support dynamic input dimensions.
    if (numberOfDynamicInputs == 1 && totalKnownDimsSize > 0 &&
        totalInputSize > 0)
      dims[dynamicValueIndex] = totalInputSize / totalKnownDimsSize;
  }

  getResult().setType(
      RankedTensorType::get(dims, inputTensorTy.getElementType()));
  return success();
}

// Transpose

LogicalResult ONNXTransposeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  // Naive transposition which handles the default case of
  // reversing the shape of the tensor (similar to numpy.transpose).
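  // For example, with no perm attribute a (2 x 3 x 4) input yields a
  // (4 x 3 x 2) result.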
  auto arrayTy = data().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  auto permutation = ONNXTransposeOp::permAttr();
  if (!permutation) {
    // Generate reverse order for default transpose operation.
    SmallVector<int64_t, 4> defaultVals;
    auto builder = mlir::Builder(getContext());
    auto rank = arrayTy.getShape().size();
    for (int i = rank - 1; i >= 0; --i)
      defaultVals.emplace_back(i);
    // Set default attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    permAttr(builder.getI64ArrayAttr(defaultRefs));
    permutation = permAttr();
  }
  // Perform transposition according to perm attribute.
  for (auto perm : permutation.getValue())
    dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceMax
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMaxOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceMin
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceMinOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceProd
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceProdOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// ReduceSum
//===----------------------------------------------------------------------===//
LogicalResult ONNXReduceSumOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = getOperand().getType().cast<RankedTensorType>();
  getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
  return success();
}

//===----------------------------------------------------------------------===//
// Conv
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - kernelShape: inferred from weight matrix if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.
LogicalResult ONNXConvOp::inferShapes() {
  // Generic shape for data input X, weight tensor W, and optional bias B
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)
  // B: (M) Optional

  bool hasBias = !B().getType().isa<NoneType>();

  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !W().getType().isa<RankedTensorType>() ||
      (hasBias && !B().getType().isa<RankedTensorType>()))
    return emitError("Input tensor not ranked");

  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();
  auto weightTy = W().getType().cast<RankedTensorType>();
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3)
    return emitError("Data input shape must be at least (NxCxD1)");

  // Check that shape of weight and data have same length.
  if (xShape.size() != weightShape.size())
    return emitError("Weight size not compatible with data size");

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[1] != -1 &&
      xShape[1] != (weightShape[1] * group))
    return emitError("Channel dimension mismatch");

  // Check the size of bias.
  if (hasBias) {
    auto bTx = B().getType().cast<RankedTensorType>();
    auto bShape = bTx.getShape();
    if (bShape.size() != 1)
      return emitError("bias should be one dimensional");
    if (bShape[0] != weightShape[0])
      return emitError("bias should have same dimensions "
                       "as weight's first dimension");
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank)
      return emitError(
          "kernel_shape length incompatible with spatial dimensions");
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1)
        return emitError("bad kernel_shape value");
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 4> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    auto builder = mlir::Builder(getContext());
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[0]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvTranspose
//===----------------------------------------------------------------------===//
// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
// Conv attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - kernelShape: inferred from weight matrix if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.
LogicalResult ONNXConvTransposeOp::inferShapes() {
  // Generic shape for data input X, weight tensor W, and optional bias B
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)
  // B: (M) Optional

  bool hasBias = !B().getType().isa<NoneType>();

  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>() ||
      !W().getType().isa<RankedTensorType>() ||
      (hasBias && !B().getType().isa<RankedTensorType>())) {
    return emitError("Input tensor not ranked");
  }

  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();
  auto weightTy = W().getType().cast<RankedTensorType>();
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3) {
    return emitError("Data input shape must be at least (NxCxD1)");
  }

  // Check that shape of weight and data have same length.
  if (xShape.size() != weightShape.size()) {
    return emitError("Weight size not compatible with data size");
  }

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvTransposeOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[0] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[0] != -1 &&
      xShape[1] != (weightShape[0] * group)) {
    return emitError("Channel dimension mismatch");
  }

  // Check the size of bias.
  if (hasBias) {
    auto bTx = B().getType().cast<RankedTensorType>();
    auto bShape = bTx.getShape();
    if (bShape.size() != 1) {
      return emitError("bias should be one dimensional");
    }
    if (bShape[0] != weightShape[1]) {
      return emitError(
          "bias should have same dimensions as weight's second dimension");
    }
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank) {
      return emitError(
          "kernel_shape length incompatible with spatial dimensions");
    }
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1) {
        return emitError("bad kernel_shape value");
      }
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 4> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    auto builder = mlir::Builder(getContext());
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();
  auto outputPads = output_padding();
  auto outputShape = output_shape();
  // TODO: handle the spatial dimension computation if output shape is
  // specified.
  assert(!outputShape.hasValue() && "unhandled option in ConvTranspose");

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[1]);
  // Compute and insert spatial dims.
  insertConvTransposeSpatialDim(outputDims, xShape, kernelShape, padsOpt,
      stridesOpt, outputPads, outputShape, dilationsOpt);

  // Set the output shape if it is not already set.
  if (!outputShape.hasValue()) {
    output_shapeAttr(builder.getI64ArrayAttr(outputDims));
  }

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// AveragePool
//===----------------------------------------------------------------------===//
// Infer shape attributes output:
//   - auto_pad set to NOTSET;
//   - strides: set to 1 if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.
LogicalResult ONNXAveragePoolOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto builder = mlir::Builder(getContext());

  // Get shape of input.
  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();

  // Kernel shape.
  auto kernelShape = kernel_shape();
  if (!kernelShape)
    return emitError(
        "kernel_shape is a mandatory attribute for which there is no default");

  // Ceil mode.
  auto ceilMode = ceil_mode().getSExtValue();

  // Process strides and pads.
  LogicalResult res =
      processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
  if (failed(res))
    return res;
  auto stridesOpt = strides();
  res = processConvPadParam<ONNXAveragePoolOp>(
      this, xShape, kernelShape, stridesOpt, llvm::None);
  if (failed(res))
    return res;
  auto padsOpt = pads();

  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  outputDims.emplace_back(xShape[1]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, llvm::None, ceilMode);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// MaxPoolSingleOut
//===----------------------------------------------------------------------===//
// Infer shape attributes output:
//   - auto_pad set to NOTSET;
//   - dilations, strides: set to 1 if not defined by user;
//   - pads: set to proper value, 0 if not defined by user.
LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!X().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto builder = mlir::Builder(getContext());

  // Get shape of input.
  auto xTy = X().getType().cast<RankedTensorType>();
  auto xShape = xTy.getShape();

  // Kernel shape.
  auto kernelShape = kernel_shape();
  if (!kernelShape)
    return emitError(
        "kernel_shape is a mandatory attribute for which there is no default");

  // Storage order.
  auto storageOrder = storage_order().getSExtValue();
  if (storageOrder != 0)
    return emitError("column major storage order not supported at this time");

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, X());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // Ceil mode.
  auto ceilMode = ceil_mode().getSExtValue();

  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  outputDims.emplace_back(xShape[1]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt, ceilMode);

  getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// Pad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Pad: unknown input shape");

  // Cannot infer shape if the pads attribute is not constant.
  DenseElementsAttr padsAttributes =
      getAttr("pads").dyn_cast_or_null<DenseElementsAttr>();
  if (!padsAttributes)
    return emitError("Pad: unknown pads");

  auto dataTy = data().getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  auto dataRank = dataTy.getRank();
  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());

  // Get pads from valueAttribute.
  SmallVector<int64_t, 2> pads(dataRank * 2, -1);
  auto valueIt = padsAttributes.getValues<IntegerAttr>().begin();
  for (int64_t i = 0; i < dataRank * 2; ++i)
    pads[i] = (*valueIt++).cast<IntegerAttr>().getInt();

  // Pads consists of two values for each axis of data.
  // The two values specify the number of elements padded before and after
  // respectively.
  for (int64_t i = 0; i < dataRank; ++i) {
    int64_t p1 = pads[i];
    int64_t p2 = pads[i + dataRank];
    // Have to be non-negative constants.
    if (p1 < 0 || p2 < 0)
      return emitError("padding value can not be negative");
    if (outputShape[i] != -1)
      outputShape[i] += p1 + p2;
  }

  auto outputType = RankedTensorType::get(outputShape, dataTy.getElementType());
  getResult().setType(outputType);
  return success();
}

static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
  // Cannot infer shape if no shape exists.
  if (!data.getType().isa<RankedTensorType>())
    return (Type)NULL;
  auto dataTy = data.getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  auto dataRank = dataShape.size();
  SmallVector<int64_t, 4> outputShape(dataShape.begin(), dataShape.end());
  if (padsOpt) {
    auto padsArray = padsOpt.getValue();
    // Pads consists of two values for each axis of data.
    // The two values specify the number of elements padded before and after
    // respectively.
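    // Illustrative example (not from the spec text above): with data of shape
    // (3, 4) and pads = [1, 0, 2, 0], i.e. begin pads [1, 0] and end pads
    // [2, 0], the output shape becomes (3 + 1 + 2, 4 + 0 + 0) = (6, 4).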
    for (int i = 0; i < dataRank; ++i) {
      int64_t p1 = (padsArray[i]).cast<IntegerAttr>().getInt();
      int64_t p2 = (padsArray[i + dataRank]).cast<IntegerAttr>().getInt();
      // Have to be non-negative constants.
      if (p1 < 0 || p2 < 0)
        return (Type)NULL;
      if (outputShape[i] != -1)
        outputShape[i] += p1 + p2;
    }

    return (RankedTensorType::get(outputShape, dataTy.getElementType()));
  } else {
    return (Type)NULL;
  }
}

//===----------------------------------------------------------------------===//
// PadConstantPad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadConstantPadOp::inferShapes() {
  auto outputType = padShapeInferenceHelper(data(), pads());
  if (!outputType)
    return emitError("missing output");
  getResult().setType(outputType);
  return success();
}

//===----------------------------------------------------------------------===//
// PadConstantValuePad
//===----------------------------------------------------------------------===//

LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
  auto outputType = padShapeInferenceHelper(data(), pads());
  if (!outputType)
    return emitError("missing output");
  getResult().setType(outputType);
  return success();
}

void ONNXPadConstantValuePadOp::build(OpBuilder &builder,
    OperationState &state, Value data, ArrayAttr pads,
    FloatAttr constant_value, StringAttr mode) {
  Type outputType = padShapeInferenceHelper(data, pads);
  if (!outputType) {
    auto elementType = data.getType().cast<TensorType>().getElementType();
    outputType = UnrankedTensorType::get(elementType);
  }
  build(builder, state, outputType, data, pads, constant_value, mode);
}

//===----------------------------------------------------------------------===//
// Unsqueeze
//===----------------------------------------------------------------------===//

LogicalResult ONNXUnsqueezeOp::inferShapes() {
  if (!data().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  auto operandTy = data().getType().cast<RankedTensorType>();
  int inRank = operandTy.getRank();

  ArrayAttr axisAttrs = axesAttr();
  SmallVector<int, 4> axes;
  int outRank = 0;
  if (axisAttrs) {
    outRank = inRank + axisAttrs.getValue().size();
    for (auto axisAttr : axisAttrs.getValue()) {
      int axis = axisAttr.cast<IntegerAttr>().getInt();
      axis = axis >= 0 ? axis : (outRank + axis);
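      // Negative axes count from the end of the output rank (per the ONNX
      // Unsqueeze spec), e.g. with outRank = 4, axis -1 maps to 3.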
      // Valid range.
      assert(axis >= -outRank && axis <= outRank - 1);
      if (std::find(axes.begin(), axes.end(), axis) == axes.end())
        axes.emplace_back(axis);
      else
        return emitError("Duplicated axes");
    }
  } else
    return emitError("Axes attribute is required");

  SmallVector<int64_t, 4> dims;
  for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
    if (std::find(axes.begin(), axes.end(), i) != axes.end()) {
      dims.emplace_back(1);
    } else {
      dims.emplace_back(operandTy.getShape()[j++]);
    }
  }
  getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// Cast
//===----------------------------------------------------------------------===//

LogicalResult ONNXCastOp::inferShapes() {
  ShapedType inputType = input().getType().dyn_cast<ShapedType>();
  if (!inputType) {
    return emitError("Non-shaped input type");
  }

  auto getOutputType = [&inputType](Type elementType) -> Type {
    if (inputType.hasRank()) {
      return RankedTensorType::get(inputType.getShape(), elementType);
    }
    return UnrankedTensorType::get(elementType);
  };

  int64_t targetType = toAttr().getInt();
  OpBuilder builder(getContext());
  if (auto elementType = convertONNXTypeToMLIRType(
          builder, static_cast<onnx::TensorProto_DataType>(targetType))) {
    getResult().setType(getOutputType(elementType));
  } else {
    return emitOpError("Unable to get the element type for to = " +
                       std::to_string(targetType));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//

LogicalResult ONNXConstantOp::inferShapes() {
  if ((sparse_value().hasValue() && value().hasValue()) ||
      (!sparse_value().hasValue() && !value().hasValue()))
    return emitError("Require exactly one of the two attributes, "
                     "either value or sparse_value");
  ElementsAttr valAttr;
  if (sparse_value().hasValue())
    valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
  else
    valAttr = valueAttr().cast<DenseElementsAttr>();
  getResult().setType(valAttr.getType());
  return success();
}

//===----------------------------------------------------------------------===//
// Concat
//===----------------------------------------------------------------------===//

LogicalResult ONNXConcatOp::inferShapes() {
  int inputNum = getNumOperands();
  for (int i = 0; i < inputNum; ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return emitError("Input tensor(s) not ranked");
  }

  // Checking value of axis parameter.
  auto commonType = getOperand(0).getType().cast<RankedTensorType>();
  auto commonShape = commonType.getShape();
  auto commonRank = commonShape.size();
  auto axisIndex = axis().getSExtValue();
  // Negative axis means values are counted from the opposite side.
  if (axisIndex < 0) {
    axisIndex = commonRank + axisIndex;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }
  if (axisIndex >= commonRank)
    return emitError("Concat axis value out of bound");
  // Initial cumulative size is that of the first operand.
  int cummulativeAxisSize = commonShape[axisIndex];

  // Compute the cumulative size with all of the other ones, and make sure
  // that the other sizes are all alike.
  for (int i = 1; i < inputNum; ++i) {
    auto currShape =
        getOperand(i).getType().cast<RankedTensorType>().getShape();
    if (currShape.size() != commonRank)
      return emitError("Concat input must all have the same rank");
    for (int j = 0; j < commonRank; ++j) {
      if (j == axisIndex) {
        // Check that the value is positive.
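        // (The size along the concatenation axis must be static so that the
        // per-operand sizes can be summed at compile time; dynamic axis sizes
        // are not handled here yet.)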
        if (currShape[j] <= 0)
          return emitError("Concat axis being concatenated is "
                           "expected to be known at compile time for now");
      } else if (currShape[j] != commonShape[j]) {
        return emitError("Concat input dimensions must be all identical, "
                         "except for dimension on the axis of the "
                         "concatenation");
      }
    }
    cummulativeAxisSize += currShape[axisIndex];
  }

  // Set output size and type.
  SmallVector<int64_t, 4> outputDims;
  for (int j = 0; j < commonRank; ++j)
    outputDims.emplace_back(
        j == axisIndex ? cummulativeAxisSize : commonShape[j]);
  getResult().setType(
      RankedTensorType::get(outputDims, commonType.getElementType()));
  return success();
}

//===----------------------------------------------------------------------===//
// RNN
//===----------------------------------------------------------------------===//

LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// LSTM
//===----------------------------------------------------------------------===//

LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// GRU
//===----------------------------------------------------------------------===//

LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }

//===----------------------------------------------------------------------===//
// Split
//===----------------------------------------------------------------------===//

LogicalResult ONNXSplitOp::inferShapes() {
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  int numOfResults = getNumResults();
  auto inputType = getOperand().getType().cast<RankedTensorType>();
  auto inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();

  // Checking value of axis parameter.
  auto axisIndex = axis().getSExtValue();
  if (axisIndex < -inputRank || axisIndex >= inputRank)
    return emitError("Split axis value out of bound");
  // Negative axis means values are counted from the opposite side.
  if (axisIndex < 0) {
    axisIndex = inputRank + axisIndex;
    auto builder = mlir::Builder(getContext());
    axisAttr(builder.getI64IntegerAttr(axisIndex));
  }

  // Checking value of split parameter.
  auto splitAttribute = split();
  SmallVector<int64_t, 4> splitLengths;
  if (splitAttribute.hasValue()) {
    if (ArrayAttrSize(splitAttribute) != numOfResults)
      return emitError("Split size not equal to the number of results");
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
  } else {
    if (inputShape[axisIndex] <= 0)
      return emitError("The dimension at the split axis is "
                       "expected to be known at compile time");
    if (inputShape[axisIndex] % numOfResults != 0)
      return emitError("The dimension at the split axis is "
                       "expected to be divisible by the number of results");
    // If split parameter is not specified, the dimension is split into
    // equal-sized parts.
    for (int i = 0; i < numOfResults; ++i)
      splitLengths.emplace_back(inputShape[axisIndex] / numOfResults);
    // Build attribute and store attribute.
    auto builder = mlir::Builder(getContext());
    splitAttr(builder.getI64ArrayAttr(llvm::makeArrayRef(splitLengths)));
  }

  // Build result types.
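  // Illustrative example (not from the code above): splitting a (6, 4) tensor
  // along axis 0 into 3 results with no explicit split attribute yields three
  // tensors of shape (2, 4), since splitLengths defaults to 6 / 3 = 2 each.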
  for (int i = 0; i < numOfResults; ++i) {
    SmallVector<int64_t, 4> resultShape;
    for (int j = 0; j < inputRank; ++j) {
      if (j == axisIndex) {
        resultShape.emplace_back(splitLengths[i]);
      } else {
        resultShape.emplace_back(inputShape[j]);
      }
    }
    getResults()[i].setType(
        RankedTensorType::get(resultShape, inputType.getElementType()));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Flatten
//===----------------------------------------------------------------------===//

LogicalResult ONNXFlattenOp::inferShapes() {
  assert(axis() == 1 && "ONNXFlattenOp can only handle axis=1 for now");
  auto inTy = input().getType().dyn_cast<ShapedType>();
  if (!inTy) {
    return emitOpError("Input is a non-shaped type");
  }
  auto outTy = output().getType().dyn_cast<ShapedType>();
  if (!outTy) {
    return emitOpError("Output is a non-shaped type");
  }

  // TODO(tjingrant): Seems like we can also fairly easily support the case
  // where the batch dimension is dynamic.
  if (!outTy.hasStaticShape()) {
    auto inShape = inTy.getShape();
    assert(inShape.size() >= 1 && "ONNXFlattenOp inShape.size() should be > 0");
    uint64_t outDim = 1;
    for (auto it = inShape.begin() + 1; it < inShape.end(); it++) {
      outDim *= *it;
    }

    SmallVector<int64_t, 2> dims;
    // https://pytorch.org/docs/master/generated/torch.nn.Flatten.html
    dims.emplace_back(inShape[0]);
    dims.emplace_back(outDim);
    getResult().setType(RankedTensorType::get(dims, outTy.getElementType()));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DynamicQuantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXDynamicQuantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();
  auto yScaleTy = y_scale().getType().cast<ShapedType>();
  auto yZPTy = y_zero_point().getType().cast<ShapedType>();

  IntegerType ui8Type =
      IntegerType::get(8, IntegerType::Unsigned, getContext());
  FloatType f32Type = FloatType::getF32(getContext());

  RankedTensorType scalarType = RankedTensorType::get({}, f32Type);
  RankedTensorType y_zero_point_type = RankedTensorType::get({}, ui8Type);

  // Set the types for the scalars.
  if (!yScaleTy.hasStaticShape()) {
    y_scale().setType(scalarType);
  }

  if (!yZPTy.hasStaticShape()) {
    y_zero_point().setType(y_zero_point_type);
  }

  if (!yTy.hasStaticShape()) {
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), ui8Type);
    y().setType(outType);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// QuantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXQuantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();

  if (!yTy.hasStaticShape()) {
    // TODO: Unfortunately, we can't tell if this should be signed or unsigned
    // here...
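    // A possible refinement (not implemented here, just a sketch of the
    // idea): when y_zero_point is provided, its element type could be
    // propagated to y instead of defaulting to a signless 8-bit integer
    // below.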
    IntegerType i8Type = IntegerType::get(8, getContext());
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), i8Type);
    y().setType(outType);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// DequantizeLinear
//===----------------------------------------------------------------------===//

LogicalResult ONNXDequantizeLinearOp::inferShapes() {
  auto inTy = x().getType().dyn_cast<RankedTensorType>();
  if (!inTy || !inTy.hasStaticShape()) {
    return emitOpError("Input is not a statically-shaped type");
  }

  auto yTy = y().getType().cast<ShapedType>();

  if (!yTy.hasStaticShape()) {
    FloatType f32 = FloatType::getF32(getContext());
    RankedTensorType outType = RankedTensorType::get(inTy.getShape(), f32);
    y().setType(outType);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// ConvInteger - copied almost exactly from Conv (X -> x, W -> w, no bias)
//===----------------------------------------------------------------------===//

LogicalResult ONNXConvIntegerOp::inferShapes() {
  // Generic shape for data input X, weight tensor W
  // X: (N x C x D1 x D2 ... x Dn)
  // W: (M x C/group x k1 x k2 x ... x kn)

  // Cannot infer shape if no shape exists.
  if (!x().getType().isa<RankedTensorType>() ||
      !w().getType().isa<RankedTensorType>()) {
    return emitOpError("Input tensor not ranked");
  }

  auto xTy = x().getType().cast<RankedTensorType>();
  if (!xTy.getElementType().isInteger(8)) {
    return emitOpError("Invalid input type");
  }
  auto xShape = xTy.getShape();
  auto weightTy = w().getType().cast<RankedTensorType>();
  if (!weightTy.getElementType().isInteger(8)) {
    return emitOpError("Invalid input type");
  }
  auto weightShape = weightTy.getShape();
  auto builder = mlir::Builder(this->getContext());

  // Lowest supported convolution is a one dimensional convolution.
  if (xShape.size() < 3) {
    return emitOpError("Data input shape must be at least (NxCxD1)");
  }

  // Check that shape of weight and data have same length.
  if (xShape.size() != weightShape.size()) {
    return emitError("Weight size not compatible with data size");
  }

  // Group is a required attribute and should have default value of 1.
  int64_t group = ONNXConvIntegerOp::group().getSExtValue();

  // Check if the attribute actually exists. If it does not then add it.
  if (!groupAttr())
    groupAttr(builder.getI64IntegerAttr(group));

  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
  if (xShape[1] != -1 && weightShape[1] != -1 &&
      xShape[1] != (weightShape[1] * group)) {
    return emitOpError("Channel dimension mismatch");
  }

  // Note: the value of the group attribute only impacts the way the
  // computation is carried out and not the actual output size.

  // Number of spatial dimensions.
  auto spatialOffset = 2;
  int32_t spatialRank = xShape.size() - spatialOffset;

  // Use kernel_shape attribute if present otherwise use size from weight
  // argument.
  auto kernelShape = kernel_shape();
  if (kernelShape.hasValue()) {
    if (ArrayAttrSize(kernelShape) != spatialRank) {
      return emitOpError(
          "kernel_shape length incompatible with spatial dimensions");
    }
    // Have the right number of values, check them.
    for (int i = 0; i < spatialRank; ++i)
      if (ArrayAttrIntVal(kernelShape, i) < 1) {
        return emitError("bad kernel_shape value");
      }
  } else {
    // Deduce shape from weight input.
    SmallVector<int64_t, 2> defaultVals;
    for (int i = 0; i < spatialRank; ++i)
      defaultVals.emplace_back(weightShape[spatialOffset + i]);
    // Convert to ArrayRef, then build attribute, then store attribute.
    ArrayRef<int64_t> defaultRefs(defaultVals);
    auto builder = mlir::Builder(getContext());
    kernel_shapeAttr(builder.getI64ArrayAttr(defaultRefs));
    kernelShape = kernel_shape();
  }

  // Process strides, dilations, and pads.
  processConvTypeParams<>(this, x());
  auto dilationsOpt = dilations();
  auto stridesOpt = strides();
  auto padsOpt = pads();

  // First two output dimensions consist of the number of batches and the
  // number of kernels being applied.
  SmallVector<int64_t, 4> outputDims;
  // Insert batch size.
  outputDims.emplace_back(xShape[0]);
  // Insert number of filters being applied (number of output channels).
  outputDims.emplace_back(weightShape[0]);
  // Compute and insert spatial dims.
  insertConvSpatialDim(&outputDims, builder, xShape, kernelShape, padsOpt,
      stridesOpt, dilationsOpt);

  // The ONNX spec specifies the output type as int32.
  Type outputType = IntegerType::get(32, getContext());
  getResult().setType(RankedTensorType::get(outputDims, outputType));
  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/Dialect/ONNX/ONNXOps.cpp.inc"
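// Example (illustrative only): after shape inference, an op such as
//   %1 = "onnx.MaxPoolSingleOut"(%0) {kernel_shape = [2, 2]} :
//        (tensor<1x3x32x32xf32>) -> tensor<*xf32>
// is refined to return tensor<1x3x31x31xf32> (default stride 1, no padding,
// floor mode), using ONNXMaxPoolSingleOutOp::inferShapes() above.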