//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
//
// Copyright 2019 The IBM Research Authors.
//
// =============================================================================
//
// This file defines ONNX operations in the MLIR operation set.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"

#include "onnx_ops.hpp"

using namespace mlir;
using namespace mlir::OpTrait::util;

//===----------------------------------------------------------------------===//
// ONNXOpsDialect
//===----------------------------------------------------------------------===//

/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  addOperations<
#define GET_OP_LIST
#include "src/onnx.cpp.inc"
      >();
}

void ONNXEntryPointOp::build(mlir::Builder *builder,
                             mlir::OperationState &state,
                             mlir::FuncOp function, int numInputs,
                             int numOutputs) {
  state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
                     builder->getSymbolRefAttr(function));
  state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
                     builder->getI32IntegerAttr(numInputs));
  state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
                     builder->getI32IntegerAttr(numOutputs));
}

ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
                                          mlir::FuncOp &func, int numInputs,
                                          int numOutputs) {
  mlir::OperationState state(location, "onnx.EntryPoint");
  Builder builder(location->getContext());
  mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
  Operation *op = mlir::Operation::create(state);
  auto onnxEntryOp = llvm::cast<mlir::ONNXEntryPointOp>(op);
  return onnxEntryOp;
}

//===----------------------------------------------------------------------===//
// ONNX Operations
//===----------------------------------------------------------------------===//
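// The unary elementwise ops below (Exp through Softmax) all infer shapes the
// same way: the result simply takes the operand's type. As a hypothetical
// illustration (not taken from an actual test), shape inference would refine
//
//   %1 = "onnx.Exp"(%0) : (tensor<2x3xf32>) -> tensor<*xf32>
//
// to
//
//   %1 = "onnx.Exp"(%0) : (tensor<2x3xf32>) -> tensor<2x3xf32>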
//===----------------------------------------------------------------------===//
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// Log
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
void ONNXSigmoidOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }

//===----------------------------------------------------------------------===//
// Reciprocal
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
void ONNXReciprocalOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}
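// Note on the broadcasting ops (Add above; Mul, Div, Sub, And, Or, Xor
// below): getBroadcastedType, from mlir::OpTrait::util, implements
// multidirectional (numpy-style) broadcasting. As an illustrative example,
// operands of types tensor<5x1xf32> and tensor<4xf32> are aligned from the
// trailing dimension, size-1 dimensions stretch to match, and the inferred
// result type is tensor<5x4xf32>.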
//===----------------------------------------------------------------------===//
// Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
void ONNXMulOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// Div
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
void ONNXDivOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
void ONNXSubOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// And
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
void ONNXAndOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
void ONNXOrOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
void ONNXXorOp::inferShapes() {
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  getResult().setType(getBroadcastedType(lhsTy, rhsTy));
}

//===----------------------------------------------------------------------===//
// Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
void ONNXSumOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
}
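// Note on the variadic ops (Sum above; Max and Min below): the result type is
// a left fold of getBroadcastedType over all operand types. As an
// illustrative example, operands of types tensor<2x1xf32>, tensor<2x4xf32>,
// and tensor<1xf32> fold pairwise to the inferred result type tensor<2x4xf32>.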
//===----------------------------------------------------------------------===//
// Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
void ONNXMaxOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
}

//===----------------------------------------------------------------------===//
// Min
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
void ONNXMinOp::inferShapes() {
  for (int i = 0; i < getNumOperands(); ++i) {
    if (!getOperand(i).getType().isa<RankedTensorType>())
      return;
  }
  Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
  for (int i = 1; i < getNumOperands(); ++i) {
    Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
    resultTy = getBroadcastedType(resultTy, nextTy);
  }
  getResult().setType(resultTy);
}

//===----------------------------------------------------------------------===//
// Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by
/// the shape inference interface.
void ONNXIdentityOp::inferShapes() {
  getResult().setType(getOperand().getType());
}

//===----------------------------------------------------------------------===//
// MatMul

void ONNXMatMulOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  dims.emplace_back(lhsTy.getShape()[0]);
  dims.emplace_back(rhsTy.getShape()[1]);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

// TODO:
//   Verify that matrix sizes are valid.
//   Take into account the dimensionality of the matrix.

//===----------------------------------------------------------------------===//
// Gemm

void ONNXGemmOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  dims.emplace_back(lhsTy.getShape()[0]);
  dims.emplace_back(rhsTy.getShape()[1]);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

// GemmNoBias

void ONNXGemmNoBiasOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;
  dims.emplace_back(lhsTy.getShape()[0]);
  dims.emplace_back(rhsTy.getShape()[1]);
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
}

// TODO:
//   Verify that matrix sizes are valid for multiplication and addition.
//   Take into account the dimensionality of the matrix.
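// Note on MatMul, Gemm, and GemmNoBias above: shape inference currently
// assumes plain 2-D matrix multiplication, combining the first dimension of
// the left operand with the second dimension of the right one. As an
// illustrative example, tensor<8x16xf32> times tensor<16x4xf32> infers
// tensor<8x4xf32>. Per the TODOs, the shared inner dimension (16 here) is not
// yet verified to match, and higher-rank or 1-D operands are not handled.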
//===----------------------------------------------------------------------===//
// Reshape

void ONNXReshapeOp::inferShapes() {
  // Cannot infer shape if no shape tensor is specified.
  if (!getOperand(1).getType().isa<RankedTensorType>())
    emitError("Shape tensor not ranked.");

  auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
  auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();

  // Only rank 1 shape tensors are supported.
  if (shapeTensorTy.getShape().size() != 1)
    emitError("Shape tensor must have rank one.");

  int64_t outputRank = shapeTensorTy.getShape()[0];

  // Shape tensor must have constant shape.
  if (outputRank < 0)
    emitError("Shape tensor must have constant shape.");

  // The extent of the shape tensor gives the output rank, but the dimension
  // values themselves are not known statically, so mark every output
  // dimension as dynamic (-1).
  SmallVector<int64_t, 2> dims;
  for (int i = 0; i < outputRank; ++i)
    dims.emplace_back(-1);

  getResult().setType(
      RankedTensorType::get(dims, inputTensorTy.getElementType()));
}

//===----------------------------------------------------------------------===//
// Transpose

void ONNXTransposeOp::inferShapes() {
  // Cannot infer shape if no shape exists.
  if (!getOperand().getType().isa<RankedTensorType>())
    emitError("Input tensor not ranked.");

  // Naive transposition which handles the default case of
  // reversing the shape of the tensor (similar to numpy.transpose).
  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims;

  if (auto permutation =
          getAttrOfType<ArrayAttr>(ONNXTransposeOp::getPermAttrName())) {
    // Perform transposition according to perm attribute.
    for (auto perm : permutation.getValue())
      dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
  } else {
    // Default: reverse the shape.
    for (auto dim : llvm::reverse(arrayTy.getShape()))
      dims.emplace_back(dim);
  }

  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}

LogicalResult verify(ONNXTransposeOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    op.emitError("Expected to belong to a module.");

  if (auto permutation =
          op.getAttrOfType<ArrayAttr>(ONNXTransposeOp::getPermAttrName())) {
    for (auto perm : permutation.getValue())
      if (perm.cast<IntegerAttr>().getInt() < 0)
        op.emitError("Cannot transpose, permutation contains negative index.");
  }

  return success();
}
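// Note on the perm attribute above: output dimension i takes the size of
// input dimension perm[i]. As an illustrative example, an input of type
// tensor<1x2x3xf32> with perm = [2, 0, 1] infers the result type
// tensor<3x1x2xf32>; with no perm attribute, the shape is reversed, giving
// tensor<3x2x1xf32>.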
//===----------------------------------------------------------------------===//
// Conv

// For this operation, we define the attributes once in the original Conv
// operation class. There is no need to redefine the attribute names for the
// other classes based on Conv.
void ONNXConvNoBiasOp::inferShapes() {
  // Generic shape for data input X and weight tensor W:
  //   X: (N x C x D1 x D2 ... x Dn)
  //   W: (M x C/group x k1 x k2 x ... x kn)

  // Cannot infer shape if no shape exists.
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
      !getOperand(1).getType().isa<RankedTensorType>())
    return;

  auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
  auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
  auto dataShape = dataTy.getShape();
  auto weightShape = weightTy.getShape();

  if (dataShape.size() != weightShape.size())
    emitError("ConvNoBias: weight size not compatible with data size.");

  // Group is a required attribute and should have default value of 1.
  auto groupAttr = getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
  if (!groupAttr)
    emitError("ConvNoBias: group attribute missing.");
  int64_t group = groupAttr.getInt();

  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
  if (dataShape[1] != (weightShape[1] * group))
    emitError("ConvNoBias: channel dimension mismatch.");

  // Required attributes.
  auto auto_pad = getAttrOfType<StringAttr>(ONNXConvOp::getAutoPadAttrName());
  auto pads = getAttrOfType<ArrayAttr>(ONNXConvOp::getPadsAttrName());

  SmallVector<int64_t, 2> dims;
  // Insert batch size.
  dims.emplace_back(dataShape[0]);
  // Insert number of filters being applied (number of output channels).
  dims.emplace_back(weightShape[0]);

  // TODO: the spatial dimensions of the output are not inferred yet; the
  // sketch below stays disabled until pads and auto_pad handling is complete.
  // // Compute the spatial dimensions.
  // SmallVector<int64_t, 2> spatialDims;
  // // Number of spatial dimensions.
  // int32_t nDims = dataShape.size() - 2;
  // // Initialize dimensions based on the input and weight spatial dimensions.
  // for (int i = 2; i < dataShape.size(); ++i)
  //   spatialDims.emplace_back(dataShape[i] - weightShape[i]);
  // // Add padding information.
  // if (pads) {
  //   for (int i = 0; i < nDims; ++i) {
  //     // Padding for beginning of axis.
  //     int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
  //     spatialDims[i] += p;
  //     // Padding for end of axis.
  //     p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
  //     spatialDims[i] += p;
  //   }
  // } else if (auto_pad) {
  //   // Attribute pads has not been provided; derive padding from auto_pad.
  // }

  getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
}

LogicalResult verify(ONNXConvNoBiasOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    op.emitError("expected to belong to a module");

  auto autoPadAttr =
      op.getAttrOfType<StringAttr>(ONNXConvOp::getAutoPadAttrName());
  if (!autoPadAttr)
    op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");

  auto groupAttr =
      op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
  if (!groupAttr)
    op.emitError("ONNXConvNoBiasOp: group attribute not specified.");

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "src/onnx.cpp.inc"