// 2019-11-02 05:09:48 +08:00
|
|
|
//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// This file defines ONNX operations in the MLIR operation set.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
|
#include "llvm/ADT/SmallBitVector.h"
|
|
|
|
#include "mlir/IR/Block.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/IntegerSet.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
|
|
|
|
|
|
|
#include "onnx_ops.hpp"
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNXOpsDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Dialect creation, the instance will be owned by the context. This is the
|
|
|
|
/// point of registration of custom types and operations for the dialect.
|
|
|
|
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext* ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  // Register every ONNX operation with this dialect. GET_OP_LIST makes the
  // TableGen-generated onnx.cpp.inc expand to the comma-separated list of op
  // classes (the same .inc file later provides the class definitions under
  // GET_OP_CLASSES at the bottom of this file).
addOperations<
#define GET_OP_LIST
#include "src/compiler/onnx.cpp.inc"
>();
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNX Operations
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
// 2019-11-13 02:37:46 +08:00
|
|
|
// Add
|
// 2019-11-02 05:09:48 +08:00
|
|
|
|
// 2019-11-08 00:42:40 +08:00
|
|
|
/// Shape inference for elementwise Add: the result adopts the type of the
/// first operand. (This assumes both operands already agree on shape;
/// NOTE(review): ONNX-style broadcasting is not handled here — confirm
/// whether callers guarantee equal shapes.)
void ONNXAddOp::inferShapes() {
  auto resultType = getOperand(0)->getType();
  getResult()->setType(resultType);
}
|
|
|
|
|
// 2019-11-13 02:37:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// MatMul
|
|
|
|
|
|
|
|
void ONNXMatMulOp::inferShapes() {
|
|
|
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
|
|
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
|
|
|
dims.emplace_back(rhsTy.getShape()[1]);
|
|
|
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO:
|
|
|
|
// Verify that matrix sizes are valid.
|
|
|
|
// Take into account the dimensionality of the matrix.
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Gemm
|
|
|
|
|
|
|
|
void ONNXGemmOp::inferShapes() {
|
|
|
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
|
|
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
|
|
|
dims.emplace_back(rhsTy.getShape()[1]);
|
|
|
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
|
|
|
// FullGemm
|
|
|
|
|
|
|
|
void ONNXFullGemmOp::inferShapes() {
|
|
|
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
|
|
|
SmallVector<int64_t, 2> dims(lhsTy.getShape()[0]);
|
|
|
|
dims.emplace_back(rhsTy.getShape()[1]);
|
|
|
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
|
|
|
// TODO:
|
|
|
|
// Verify that matrix sizes are valid for multiplication and addition.
|
|
|
|
// Take into account the dimensionality of the matrix.
|
|
|
|
|
// 2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TableGen'd op method definitions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
|
|
|
#include "src/compiler/onnx.cpp.inc"
|