2019-11-02 05:09:48 +08:00
|
|
|
//===- onnx_ops.cpp - MLIR ONNX Operations --------------------------------===//
|
|
|
|
//
|
|
|
|
// Copyright 2019 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// This file defines ONNX operations in the MLIR operation set.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
2019-12-20 00:28:06 +08:00
|
|
|
#include "mlir/Dialect/Traits.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/Block.h"
|
|
|
|
#include "mlir/IR/Builders.h"
|
|
|
|
#include "mlir/IR/Function.h"
|
|
|
|
#include "mlir/IR/IntegerSet.h"
|
|
|
|
#include "mlir/IR/Matchers.h"
|
2020-01-21 03:46:54 +08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
#include "mlir/IR/OpImplementation.h"
|
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2019-12-17 07:45:39 +08:00
|
|
|
#include "llvm/ADT/SetVector.h"
|
|
|
|
#include "llvm/ADT/SmallBitVector.h"
|
2019-11-02 05:09:48 +08:00
|
|
|
|
|
|
|
#include "onnx_ops.hpp"
|
|
|
|
|
|
|
|
using namespace mlir;
|
2019-12-20 00:28:06 +08:00
|
|
|
using namespace mlir::OpTrait::util;
|
2019-11-02 05:09:48 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNXOpsDialect
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Dialect creation, the instance will be owned by the context. This is the
|
|
|
|
/// point of registration of custom types and operations for the dialect.
|
2019-12-17 07:45:39 +08:00
|
|
|
ONNXOpsDialect::ONNXOpsDialect(mlir::MLIRContext *ctx)
    : mlir::Dialect(getDialectNamespace(), ctx) {
  // Register with this dialect every ONNX operation emitted by TableGen
  // into the generated op-list include below.
  addOperations<
#define GET_OP_LIST
#include "src/onnx.cpp.inc"
      >();
}
|
|
|
|
|
2019-12-22 13:25:02 +08:00
|
|
|
void ONNXEntryPointOp::build(mlir::Builder *builder,
                             mlir::OperationState &state, mlir::FuncOp function,
                             int numInputs, int numOutputs) {
  // Record the wrapped function as a symbol reference, plus the number of
  // inputs/outputs it exposes, as attributes on the entry-point op.
  auto funcRefAttr = builder->getSymbolRefAttr(function);
  state.addAttribute(ONNXEntryPointOp::getEntryPointFuncAttrName(),
                     funcRefAttr);
  state.addAttribute(ONNXEntryPointOp::getNumInputsAttrName(),
                     builder->getI32IntegerAttr(numInputs));
  state.addAttribute(ONNXEntryPointOp::getNumOutputsAttrName(),
                     builder->getI32IntegerAttr(numOutputs));
}
|
|
|
|
|
|
|
|
ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
                                          mlir::FuncOp &func, int numInputs,
                                          int numOutputs) {
  // Build a free-standing onnx.EntryPoint operation describing `func`.
  mlir::OperationState state(location, "onnx.EntryPoint");
  Builder builder(location->getContext());
  mlir::ONNXEntryPointOp::build(&builder, state, func, numInputs, numOutputs);
  Operation *op = mlir::Operation::create(state);
  return llvm::cast<mlir::ONNXEntryPointOp>(op);
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// ONNX Operations
|
2019-12-06 09:08:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Exp
|
|
|
|
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXExpOp::inferShapes() {
  // Elementwise op: the result type matches the operand type exactly.
  getResult().setType(getOperand().getType());
}
|
2019-12-06 09:08:09 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Tanh
|
|
|
|
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXTanhOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 09:08:09 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sinh
|
|
|
|
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSinhOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 09:08:09 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Cosh
|
|
|
|
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXCoshOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 09:08:09 +08:00
|
|
|
}
|
|
|
|
|
2020-01-08 11:11:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Cos
|
|
|
|
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXCosOp::inferShapes() {
  // Elementwise op: the result type matches the operand type exactly.
  getResult().setType(getOperand().getType());
}
|
2020-01-08 11:11:21 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Log
|
|
|
|
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXLogOp::inferShapes() {
  // Elementwise op: the result type matches the operand type exactly.
  getResult().setType(getOperand().getType());
}
|
2020-01-08 11:11:21 +08:00
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// HardSigmoid
|
|
|
|
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXHardSigmoidOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2019-12-06 09:08:09 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sigmoid
|
|
|
|
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSigmoidOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 09:08:09 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Elu
|
|
|
|
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2020-01-14 01:21:29 +08:00
|
|
|
void ONNXEluOp::inferShapes() {
  // Elementwise op: the result type matches the operand type exactly.
  getResult().setType(getOperand().getType());
}
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
|
2019-12-06 13:31:17 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Relu
|
|
|
|
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXReluOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-06 13:31:17 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// LeakyRelu
|
|
|
|
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXLeakyReluOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Selu
|
|
|
|
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXSeluOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-16 14:23:33 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Reciprocal
|
|
|
|
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
void ONNXReciprocalOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2020-01-21 10:57:32 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Softmax
|
|
|
|
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
|
|
|
/// the shape inference interface.
|
|
|
|
// Softmax preserves the shape of its input, so the result type is simply
// the operand type.
void ONNXSoftmaxOp::inferShapes() {
  getResult().setType(getOperand().getType());
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2019-11-13 02:37:46 +08:00
|
|
|
// Add
|
2019-11-19 10:08:21 +08:00
|
|
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2019-11-08 00:42:40 +08:00
|
|
|
void ONNXAddOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2019-11-08 00:42:40 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Mul
|
|
|
|
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMulOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Div
|
|
|
|
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXDivOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sub
|
|
|
|
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSubOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// And
|
|
|
|
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXAndOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Or
|
|
|
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
2019-12-20 00:28:06 +08:00
|
|
|
void ONNXOrOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
2019-12-20 00:28:06 +08:00
|
|
|
}
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Xor
|
|
|
|
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXXorOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
|
|
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
[MLIR] Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor (#388)
* Lower ONNX element-wise binary ops: Mul, Div, Sub, And, Or, Xor
* Edit gen_doc.py to avoid changes about AnyTypeOf<[AnyMemRef, AnyTensor]>
* Miss a space
* Add tests
* Shorten ONNXElementWiseBinaryOpLowering into ONNXEWBinaryOpLowering
* Move lowering patterns into runOnModule()
* Redundant space
2019-12-04 00:17:21 +08:00
|
|
|
}
|
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Sum
|
|
|
|
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXSumOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Max
|
|
|
|
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMaxOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Min
|
|
|
|
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXMinOp::inferShapes() {
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 0; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
2019-12-20 00:28:06 +08:00
|
|
|
return;
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
for (int i = 1; i < getNumOperands(); ++i) {
|
2020-01-14 01:21:29 +08:00
|
|
|
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
2019-12-20 00:28:06 +08:00
|
|
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
|
|
|
}
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(resultTy);
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
}
|
|
|
|
|
2019-12-17 07:45:39 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Identity
|
|
|
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
|
|
|
/// shape inference interface.
|
|
|
|
void ONNXIdentityOp::inferShapes() {
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(getOperand().getType());
|
2019-12-17 07:45:39 +08:00
|
|
|
}
|
|
|
|
|
[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Update SharingWork.md
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Update SharingWork.md
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two diffenrent templates for Unary and Variadic Ops
* Revise the code
2019-12-12 10:49:50 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2019-11-13 02:37:46 +08:00
|
|
|
// MatMul
|
|
|
|
|
|
|
|
void ONNXMatMulOp::inferShapes() {
|
2019-11-16 02:10:41 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-11-16 02:10:41 +08:00
|
|
|
return;
|
2020-01-10 04:30:57 +08:00
|
|
|
|
|
|
|
auto lhsTy = getOperand(0)->getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1)->getType().cast<RankedTensorType>();
|
|
|
|
|
2019-11-16 02:10:41 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
2020-01-10 04:30:57 +08:00
|
|
|
auto lhsShape = lhsTy.getShape();
|
|
|
|
auto rhsShape = rhsTy.getShape();
|
|
|
|
if (lhsShape.size() == 1 && rhsShape.size() == 1) {
|
|
|
|
// Special case when both arrays are 1-dimensional and according to
|
|
|
|
// numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
|
|
|
|
// need to be removed after the multiplication but cannot be removed if all
|
2020-01-11 00:22:41 +08:00
|
|
|
// sizes are 1.
|
2020-01-10 04:30:57 +08:00
|
|
|
if (lhsShape[0] != -1 && rhsShape[0] != -1 &&
|
|
|
|
lhsShape[0] != rhsShape[0])
|
|
|
|
emitError("Attempt to multiply incompatible matrices.");
|
|
|
|
dims.emplace_back(1);
|
2020-01-11 01:30:12 +08:00
|
|
|
} else if (lhsShape.size() > 2 && rhsShape.size() == 2) {
|
|
|
|
// (s1 x s2 x... x sK x M x N) MATMUL (N x P)
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
unsigned leftDims = lhsShape.size();
|
|
|
|
if (lhsShape[leftDims - 1] != -1 && rhsShape[0] != -1 &&
|
|
|
|
lhsShape[leftDims - 1] != rhsShape[0])
|
|
|
|
emitError("Attempt to multiply incompatible matrices.");
|
|
|
|
|
|
|
|
for (int i = 0; i < leftDims - 1; ++i)
|
|
|
|
dims.emplace_back(lhsShape[i]);
|
|
|
|
dims.emplace_back(rhsShape[1]);
|
|
|
|
} else if (lhsShape.size() == 2 && rhsShape.size() > 2) {
|
|
|
|
// (M x N) MATMUL (s1 x s2 x... x sK x N x P)
|
|
|
|
// =>
|
|
|
|
// (s1 x s2 x... x sK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
unsigned rightDims = rhsShape.size();
|
|
|
|
if (lhsShape[1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
|
|
|
lhsShape[1] != rhsShape[rightDims - 2])
|
|
|
|
emitError("Attempt to multiply incompatible matrices.");
|
|
|
|
|
|
|
|
for (int i = 0; i < rightDims - 2; ++i)
|
|
|
|
dims.emplace_back(rhsShape[i]);
|
|
|
|
dims.emplace_back(lhsShape[0]);
|
|
|
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
|
|
|
} else if (lhsShape.size() > 2 && rhsShape.size() > 2) {
|
|
|
|
// (s1 x s2 x... x sK x M x N) MATMUL (t1 x t2 x... x tK x N x P)
|
|
|
|
// =>
|
|
|
|
// (u1 x u2 x... x uK x M x P)
|
|
|
|
|
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
unsigned leftDims = lhsShape.size();
|
|
|
|
unsigned rightDims = rhsShape.size();
|
|
|
|
if (lhsShape[leftDims - 1] != -1 && rhsShape[rightDims - 2] != -1 &&
|
|
|
|
lhsShape[leftDims - 1] != rhsShape[rightDims - 2])
|
|
|
|
emitError("Attempt to multiply incompatible matrices.");
|
|
|
|
|
|
|
|
// Check and perform broadcasting for the shapes.
|
|
|
|
SmallVector<int64_t, 2> lhsBcastShape;
|
|
|
|
for (int i = 0; i < leftDims - 2; ++i)
|
|
|
|
lhsBcastShape.emplace_back(lhsShape[i]);
|
|
|
|
SmallVector<int64_t, 2> rhsBcastShape;
|
|
|
|
for (int i = 0; i < rightDims - 2; ++i)
|
|
|
|
rhsBcastShape.emplace_back(rhsShape[i]);
|
|
|
|
if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
|
|
|
|
emitError("Broadcasted dimensions are incompatible.");
|
|
|
|
|
|
|
|
dims.emplace_back(lhsShape[leftDims - 2]);
|
|
|
|
dims.emplace_back(rhsShape[rightDims - 1]);
|
2020-01-10 04:30:57 +08:00
|
|
|
} else {
|
2020-01-11 01:30:12 +08:00
|
|
|
// This case covers all remaining combinations of 1 and 2-D matrices.
|
2020-01-11 04:26:29 +08:00
|
|
|
int64_t lhsDim = lhsShape[0];
|
|
|
|
int64_t rhsDim = rhsShape[0];
|
|
|
|
if (lhsShape.size() > 1) {
|
|
|
|
lhsDim = lhsShape[1];
|
2020-01-11 04:16:45 +08:00
|
|
|
dims.emplace_back(lhsShape[0]);
|
2020-01-11 04:26:29 +08:00
|
|
|
}
|
2020-01-11 04:16:45 +08:00
|
|
|
|
2020-01-11 04:26:29 +08:00
|
|
|
// Check legality of matrix multiplication.
|
|
|
|
if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
|
|
|
|
emitError("Attempt to multiply incompatible matrices.");
|
|
|
|
|
|
|
|
if (rhsShape.size() > 1)
|
2020-01-11 04:16:45 +08:00
|
|
|
dims.emplace_back(rhsShape[1]);
|
2020-01-10 04:30:57 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
getResult()->setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
2019-11-13 02:37:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Gemm
|
|
|
|
|
|
|
|
void ONNXGemmOp::inferShapes() {
|
2019-11-16 02:10:41 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-11-16 02:10:41 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
2019-11-16 02:10:41 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
|
|
|
dims.emplace_back(lhsTy.getShape()[0]);
|
2019-11-13 02:37:46 +08:00
|
|
|
dims.emplace_back(rhsTy.getShape()[1]);
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
2019-11-13 02:37:46 +08:00
|
|
|
}
|
|
|
|
|
2020-01-16 03:27:21 +08:00
|
|
|
// GemmNoBias
|
2019-11-13 02:37:46 +08:00
|
|
|
|
2020-01-16 03:11:32 +08:00
|
|
|
void ONNXGemmNoBiasOp::inferShapes() {
|
2019-11-16 02:10:41 +08:00
|
|
|
// Cannot infer shape if no shape exists.
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
|
|
|
!getOperand(1).getType().isa<RankedTensorType>())
|
2019-11-16 02:10:41 +08:00
|
|
|
return;
|
2020-01-14 01:21:29 +08:00
|
|
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
2019-11-16 02:10:41 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
|
|
|
dims.emplace_back(lhsTy.getShape()[0]);
|
2019-11-13 02:37:46 +08:00
|
|
|
dims.emplace_back(rhsTy.getShape()[1]);
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
2019-11-13 02:37:46 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
// TODO:
|
|
|
|
// Verify that matrix sizes are valid for multiplication and addition.
|
|
|
|
// Take into account the dimensionality of the matrix.
|
|
|
|
|
2019-12-14 04:28:56 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Reshape
|
|
|
|
|
|
|
|
void ONNXReshapeOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape tensor is specified.
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand(1).getType().isa<RankedTensorType>())
|
2019-12-14 04:28:56 +08:00
|
|
|
emitError("Shape tensor not ranked.");
|
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
auto inputTensorTy = getOperand(0).getType().cast<RankedTensorType>();
|
|
|
|
auto shapeTensorTy = getOperand(1).getType().cast<RankedTensorType>();
|
2019-12-14 04:28:56 +08:00
|
|
|
|
|
|
|
// Only rank 1 shape tensors are supported.
|
|
|
|
if (shapeTensorTy.getShape().size() != 1)
|
|
|
|
emitError("Shape tensor must have rank one.");
|
|
|
|
|
|
|
|
int64_t outputRank = shapeTensorTy.getShape()[0];
|
|
|
|
|
|
|
|
// Shape tensor must have constant shape.
|
|
|
|
if (outputRank < 0)
|
|
|
|
emitError("Shape tensor must have constant shape.");
|
|
|
|
|
|
|
|
SmallVector<int64_t, 2> dims;
|
|
|
|
for (int i = 0; i < outputRank; ++i)
|
|
|
|
dims.emplace_back(-1);
|
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(
|
2019-12-14 04:28:56 +08:00
|
|
|
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
|
|
|
}
|
|
|
|
|
2020-01-08 03:48:01 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
// Transpose
|
|
|
|
|
|
|
|
void ONNXTransposeOp::inferShapes() {
|
|
|
|
// Cannot infer shape if no shape exists.
|
2020-01-14 01:21:29 +08:00
|
|
|
if (!getOperand().getType().isa<RankedTensorType>())
|
2020-01-08 03:48:01 +08:00
|
|
|
emitError("Shape tensor not ranked.");
|
|
|
|
|
|
|
|
// Naive transposition which handles the default case of
|
|
|
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
2020-01-14 01:21:29 +08:00
|
|
|
auto arrayTy = getOperand().getType().cast<RankedTensorType>();
|
2020-01-14 07:08:19 +08:00
|
|
|
SmallVector<int64_t, 2> dims;
|
|
|
|
|
|
|
|
if (auto permutation = getAttrOfType<ArrayAttr>(
|
|
|
|
ONNXTransposeOp::getPermAttrName())) {
|
|
|
|
// Perform transposition according to perm attribute.
|
2020-01-21 03:46:54 +08:00
|
|
|
for (auto perm : permutation.getValue())
|
|
|
|
dims.emplace_back(arrayTy.getShape()[perm.cast<IntegerAttr>().getInt()]);
|
2020-01-14 07:08:19 +08:00
|
|
|
} else {
|
|
|
|
// Default
|
2020-01-21 03:46:54 +08:00
|
|
|
for (auto dim : llvm::reverse(arrayTy.getShape()))
|
|
|
|
dims.emplace_back(dim);
|
2020-01-14 07:08:19 +08:00
|
|
|
}
|
|
|
|
|
2020-01-14 01:21:29 +08:00
|
|
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
2020-01-08 03:48:01 +08:00
|
|
|
}
|
|
|
|
|
2020-01-21 03:46:54 +08:00
|
|
|
// Verify the Transpose op: it must live inside a module, and every entry of
// an optional "perm" attribute must be a non-negative index.
// The original emitted diagnostics but unconditionally returned success(),
// so verification could never actually fail; each error is now returned
// (emitError's InFlightDiagnostic converts to a failure LogicalResult).
LogicalResult verify(ONNXTransposeOp op) {
  auto module = op.getParentOfType<ModuleOp>();
  if (!module)
    return op.emitError("Expected to belong to a module.");

  if (auto permutation = op.getAttrOfType<ArrayAttr>(
          ONNXTransposeOp::getPermAttrName())) {
    // TODO(review): also reject indices >= input rank when the input type is
    // ranked; only negative indices are rejected here.
    for (auto perm : permutation.getValue())
      if (perm.cast<IntegerAttr>().getInt() < 0)
        return op.emitError(
            "Cannot transpose, permutation contains negative index.");
  }

  return success();
}
|
|
|
|
|
2019-11-02 05:09:48 +08:00
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// TableGen'd op method definitions
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#define GET_OP_CLASSES
|
2019-12-23 13:13:52 +08:00
|
|
|
#include "src/onnx.cpp.inc"
|