[MLIR] Add support for Max, Min, Sum, Elu, Selu, LeakyRelu, HardSigmoid (#395)
* Lower ONNXSumOp
* Add inferShapes() and test cases
* Load the first operand to the result
* Update SharingWork.md
* Add support for Max, Min
* Pass operation instead of location to mapToLowerScalarOp
* Add support for Elu, Selu, LeakyRelu, HardSigmoid
* Add test cases
* Rewrite the part of lowering variadic ops and use it for binary ops
* Use two different templates for Unary and Variadic Ops
* Revise the code

This commit is contained in: parent fb1b43f842, commit 5ed79083d5
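The central refactoring below is that the scalar lowering helper now receives the operation itself instead of only its location, so each specialization can recover the location and read attributes (alpha, beta, gamma) directly from the op. A condensed view of the signature change, trimmed from the diff that follows (not the verbatim source):

```cpp
// Before: only the location was available to the scalar lowering hook.
template <typename UnaryOp>
Value* mapToLowerScalarOp(Location loc, ArrayRef<Type> result_types,
    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter);

// After: the op is passed; location and attributes are recovered from it.
template <typename UnaryOp>
Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
  auto loc = op->getLoc();
  // e.g. auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
  // ... build the scalar Standard-dialect ops here ...
}
```

The hunks below make up the change, starting with the status table in SharingWork.md.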
@@ -6,27 +6,34 @@ ONNX operations for which some work is needed.
 * M for multi-broadcast, U for unidir-broadcast
 
-ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast)
-----------|----------------------|--------------|---------------------|----------------------------------------
-Add | ? | v | v | noM
-And | ? | v | v | noM
-Cosh | ? | v | v | noM
-Div | ? | v | v |
-Exp | ? | v | v |
-FullGemm | | | | noU
-Gemm | | | | noU
-MatMul | | | | noM
-Mul | ? | v | v | noM
-Or | ? | v | v | noM
-Relu | ? | v | v |
-Sigmoid | ? | v | v |
-Sinh | ? | v | v |
-Sub | ? | v | v | noM
-Tanh | ? | v | v |
-Xor | ? | v | v | noM
+| ONNX Oper | Person working on it | ONNX 2 KRNL | Basic functionality | Extended functionality (e.g. broadcast) |
+| ---------- | --------------------- | -------------- | --------------------- | ---------------------------------------- |
+| Add | Tung (updated) | v | v | noM |
+| And | Tung | v | v | noM |
+| Cosh | Tung | v | v | noM |
+| Div | Tung | v | v | |
+| Elu | Tung | v | v | |
+| Exp | Tung | v | v | |
+| FullGemm | | | | noU |
+| Gemm | | | | noU |
+| HardSigmoid | Tung | v | v | |
+| LeakyRelu | Tung | v | v | |
+| MatMul | | | | noM |
+| Max | Tung | v | v | noM |
+| Min | Tung | v | v | noM |
+| Mul | Tung | v | v | noM |
+| Or | Tung | v | v | noM |
+| Relu | Tung | v | v | |
+| Selu | Tung | v | v | |
+| Sigmoid | Tung | v | v | |
+| Sinh | Tung | v | v | |
+| Sub | Tung | v | v | noM |
+| Sum | Tung | v | v | noM |
+| Tanh | Tung | v | v | |
+| Xor | Tung | v | v | noM |
 
 ONNX operations for which the work is completed (full functionality) and tested
 
-ONNX Oper | Person working on it | Initial work | Basic functionality | Extended functionality (e.g. broadcast)
-----------|----------------------|--------------|---------------------|----------------------------------------
+| ONNX Oper | Person working on it | Initial work | Basic functionality | Extended functionality (e.g. broadcast) |
+| ---------- | ---------------------- | -------------- | --------------------- | ---------------------------------------- |
@@ -265,7 +265,8 @@ def collect_types(schema, input) :
 
 def gen_schema(schema) :
     ShapeInferenceList=['Exp', 'Tanh', 'Sinh', 'Cosh', 'Sigmoid', 'Relu',
                 'Add', 'Mul', 'Div', 'Sub', 'And', 'Or', 'Xor',
-                'MatMul', 'Gemm']
+                'Sum', 'Max', 'Min', 'MatMul', 'Gemm', 'LeakyRelu',
+                'Elu', 'Selu', 'HardSigmoid']
     CanonicalList=['Add', 'Identity']
     line_indent = '  '
@@ -70,6 +70,14 @@ void ONNXCoshOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// HardSigmoid
+/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
+/// the shape inference interface.
+void ONNXHardSigmoidOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Sigmoid
 /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
@@ -78,6 +86,14 @@ void ONNXSigmoidOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// Elu
+/// Infer the output shape of the ONNXEluOp. This method is required by the
+/// shape inference interface.
+void ONNXEluOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
@@ -86,6 +102,22 @@ void ONNXReluOp::inferShapes() {
   getResult()->setType(getOperand()->getType());
 }
 
+//===----------------------------------------------------------------------===//
+// LeakyRelu
+/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
+/// the shape inference interface.
+void ONNXLeakyReluOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Selu
+/// Infer the output shape of the ONNXSeluOp. This method is required by
+/// the shape inference interface.
+void ONNXSeluOp::inferShapes() {
+  getResult()->setType(getOperand()->getType());
+}
+
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
@@ -144,6 +176,32 @@ void ONNXXorOp::inferShapes() {
 
 //===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Sum
+/// Infer the output shape of the ONNXSumOp. This method is required by the
+/// shape inference interface.
+void ONNXSumOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Max
+/// Infer the output shape of the ONNXMaxOp. This method is required by the
+/// shape inference interface.
+void ONNXMaxOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
+// Min
+/// Infer the output shape of the ONNXMinOp. This method is required by the
+/// shape inference interface.
+void ONNXMinOp::inferShapes() {
+  getResult()->setType(getOperand(0)->getType());
+}
+
+//===----------------------------------------------------------------------===//
 
 // MatMul
 
 void ONNXMatMulOp::inferShapes() {
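All of the inferShapes() bodies added above use the same identity rule: the result type is simply copied from an operand, and for the variadic Sum, Max and Min the type of the first operand is propagated, which is consistent with the "noM" (no multidirectional broadcast yet) entries in the table above. A small helper could factor the repetition; the sketch below is purely illustrative (the helper name and its placement are hypothetical, not part of this commit):

```cpp
// Hypothetical helper (not in this commit): copy the type of one operand to
// the single result. Each inferShapes() body above is equivalent to calling
// this with operandIndex == 0.
static void inferIdentityShape(mlir::Operation* op, unsigned operandIndex = 0) {
  op->getResult(0)->setType(op->getOperand(operandIndex)->getType());
}
```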
@@ -531,7 +531,7 @@ def ONNXDynamicQuantizeLinearOp:ONNX_Op<"DynamicQuantizeLinear",
 }
 
 def ONNXEluOp:ONNX_Op<"Elu",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Elu operation";
   let description = [{
     "Elu takes one input data (Tensor<T>) and produces one output data"
@@ -991,7 +991,7 @@ def ONNXGreaterOp:ONNX_Op<"Greater",
 }
 
 def ONNXHardSigmoidOp:ONNX_Op<"HardSigmoid",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX HardSigmoid operation";
   let description = [{
     "HardSigmoid takes one input data (Tensor<T>) and produces one output data"
@@ -1191,7 +1191,7 @@ def ONNXLSTMOp:ONNX_Op<"LSTM",
 }
 
 def ONNXLeakyReluOp:ONNX_Op<"LeakyRelu",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX LeakyRelu operation";
   let description = [{
     "LeakyRelu takes input data (Tensor<T>) and an argument alpha, and produces one"
@@ -1436,7 +1436,7 @@ def ONNXMatMulIntegerOp:ONNX_Op<"MatMulInteger",
 }
 
 def ONNXMaxOp:ONNX_Op<"Max",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Max operation";
   let description = [{
     "Element-wise max of each of the input tensors (with Numpy-style broadcasting support)."
@@ -1548,7 +1548,7 @@ def ONNXMeanVarianceNormalizationOp:ONNX_Op<"MeanVarianceNormalization",
 }
 
 def ONNXMinOp:ONNX_Op<"Min",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Min operation";
   let description = [{
     "Element-wise min of each of the input tensors (with Numpy-style broadcasting support)."
@@ -2625,7 +2625,7 @@ def ONNXScatterNDOp:ONNX_Op<"ScatterND",
 }
 
 def ONNXSeluOp:ONNX_Op<"Selu",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Selu operation";
   let description = [{
     "Selu takes one input data (Tensor<T>) and produces one output data"
@@ -2972,7 +2972,7 @@ def ONNXSubOp:ONNX_Op<"Sub",
 }
 
 def ONNXSumOp:ONNX_Op<"Sum",
-    [NoSideEffect]> {
+    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Sum operation";
   let description = [{
     "Element-wise sum of each of the input tensors (with Numpy-style broadcasting support)."
@@ -148,6 +148,12 @@ struct ScalarOp<ONNXExpOp> {
   using IOp = ExpOp; // not use
 };
 
+template <>
+struct ScalarOp<ONNXSumOp> {
+  using FOp = AddFOp;
+  using IOp = AddIOp;
+};
+
 template <typename ElementwiseNaryOp>
 using ScalarFOp = typename ScalarOp<ElementwiseNaryOp>::FOp;
 template <typename ElementwiseNaryOp>
@@ -157,11 +163,11 @@ using ScalarIOp = typename ScalarOp<ElementwiseNaryOp>::IOp;
 // Scalar unary ops for lowering to Krnl dialect.
 //===----------------------------------------------------------------------===//
 template <typename UnaryOp>
-Value* mapToLowerScalarOp(Location loc, ArrayRef<Type> result_types,
+Value* mapToLowerScalarOp(Operation* op, ArrayRef<Type> result_types,
     ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
   /* Lower UnaryOp to Ops in the Standard dialect.
   */
+  auto loc = op->getLoc();
   Type element_type = operands.front()->getType();
   if (element_type.isa<IntegerType>()) {
     return rewriter.create<ScalarIOp<UnaryOp>>(
@@ -179,11 +185,14 @@ Value* mapToLowerScalarOp(Location loc, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXTanhOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value* mapToLowerScalarOp<ONNXTanhOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
   // ONNXTanhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))))
+  auto loc = op->getLoc();
   Value* operand = operands[0];
 
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto neg = rewriter.create<SubFOp>(loc, zero, operand);
   auto exp = rewriter.create<ExpOp>(loc, operand);
@@ -191,6 +200,7 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc, ArrayRef<Type> result_types,
   auto result =
       rewriter.create<DivFOp>(loc, rewriter.create<SubFOp>(loc, exp, negExp),
           rewriter.create<AddFOp>(loc, exp, negExp));
 
   return result;
 }
 
@@ -198,11 +208,14 @@ Value* mapToLowerScalarOp<ONNXTanhOp>(Location loc, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXSinhOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value* mapToLowerScalarOp<ONNXSinhOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
   // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         ConstantOp 2)
+  auto loc = op->getLoc();
   Value* operand = operands[0];
 
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
   auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@@ -210,6 +223,7 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc, ArrayRef<Type> result_types,
   auto negExp = rewriter.create<ExpOp>(loc, neg);
   auto result = rewriter.create<DivFOp>(
       loc, rewriter.create<SubFOp>(loc, exp, negExp), two);
 
   return result;
 }
 
@@ -217,11 +231,14 @@ Value* mapToLowerScalarOp<ONNXSinhOp>(Location loc, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXCoshOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value* mapToLowerScalarOp<ONNXCoshOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
   // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
   //                         ConstantOp 2)
+  auto loc = op->getLoc();
   Value* operand = operands[0];
 
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto two = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(2.0f));
   auto neg = rewriter.create<SubFOp>(loc, zero, operand);
@@ -229,6 +246,7 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc, ArrayRef<Type> result_types,
   auto negExp = rewriter.create<ExpOp>(loc, neg);
   auto result = rewriter.create<DivFOp>(
       loc, rewriter.create<AddFOp>(loc, exp, negExp), two);
 
   return result;
 }
 
@@ -236,18 +254,84 @@ Value* mapToLowerScalarOp<ONNXCoshOp>(Location loc, ArrayRef<Type> result_types,
 // Scalar unary ops for lowering ONNXSigmoidOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
+Value* mapToLowerScalarOp<ONNXSigmoidOp>(Operation* op,
     ArrayRef<Type> result_types, ArrayRef<Value*> operands,
     ConversionPatternRewriter& rewriter) {
   // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
   //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
+  auto loc = op->getLoc();
   Value* operand = operands[0];
 
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
   auto neg = rewriter.create<SubFOp>(loc, zero, operand);
   auto negExp = rewriter.create<ExpOp>(loc, neg);
   auto result = rewriter.create<DivFOp>(
       loc, one, rewriter.create<AddFOp>(loc, one, negExp));
 
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXHardSigmoidOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXHardSigmoidOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // %Y = AddFOp(MulFOp(alpha, %X), beta)
+  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
+  //               %Y,
+  //               Constant 0)
+  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
+  //                                  %Z,
+  //                                  Constant 1)
+  auto loc = op->getLoc();
+  Value* operand = operands[0];
+  auto alphaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.alpha");
+  auto betaAttr = op->getAttrOfType<FloatAttr>("HardSigmoid.beta");
+
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
+
+  auto add = rewriter.create<AddFOp>(
+      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
+  auto maxPredicate =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
+  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
+  auto minPredicate =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
+  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXEluOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXEluOp>(Operation* op, ArrayRef<Type> result_types,
+    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
+  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
+  //                          %X)
+  auto loc = op->getLoc();
+  Value* operand = operands[0];
+
+  auto alphaAttr = op->getAttrOfType<FloatAttr>("Elu.alpha");
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto lessThanZero =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
+  auto result = rewriter.create<SelectOp>(loc, lessThanZero,
+      rewriter.create<MulFOp>(
+          loc, alpha, rewriter.create<SubFOp>(loc, exp, one)),
+      operand);
+
   return result;
 }
 
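For reference, the closed forms implemented by the two new scalar lowerings above, matching the pseudo-code comments in the diff and the ONNX definitions:

$$\mathrm{HardSigmoid}(x) = \max\bigl(0,\ \min(1,\ \alpha x + \beta)\bigr), \qquad \mathrm{Elu}(x) = \begin{cases} x, & x \ge 0 \\ \alpha\,(e^{x} - 1), & x < 0 \end{cases}$$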
@@ -255,30 +339,122 @@ Value* mapToLowerScalarOp<ONNXSigmoidOp>(Location loc,
 // Scalar unary ops for lowering ONNXReluOp
 //===----------------------------------------------------------------------===//
 template <>
-Value* mapToLowerScalarOp<ONNXReluOp>(Location loc, ArrayRef<Type> result_types,
-    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+Value* mapToLowerScalarOp<ONNXReluOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
   // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
   //                           ConstantOp 0,
   //                           %X)
+  auto loc = op->getLoc();
   Value* operand = operands[0];
 
   auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
   auto lessThanZero =
       rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
   auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
 
   return result;
 }
 
 //===----------------------------------------------------------------------===//
-// Element-wise n-ary ops lowering to Krnl dialect.
+// Scalar unary ops for lowering ONNXLeakyReluOp
 //===----------------------------------------------------------------------===//
-template <typename ElementwiseNaryOp, unsigned numArgs>
-struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
-  ONNXElementwiseNaryOpLowering(MLIRContext* ctx)
-      : ConversionPattern(ElementwiseNaryOp::getOperationName(), 1, ctx) {}
+template <>
+Value* mapToLowerScalarOp<ONNXLeakyReluOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
+  //                                MulFOp(alpha, %X),
+  //                                %X)
+  auto loc = op->getLoc();
+  Value* operand = operands[0];
+
+  auto alphaAttr = op->getAttrOfType<FloatAttr>("LeakyRelu.alpha");
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto lessThanZero =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
+  auto result = rewriter.create<SelectOp>(
+      loc, lessThanZero, rewriter.create<MulFOp>(loc, alpha, operand), operand);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXSeluOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXSeluOp>(Operation* op,
+    ArrayRef<Type> result_types, ArrayRef<Value*> operands,
+    ConversionPatternRewriter& rewriter) {
+  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
+  //                           MulFOp(gamma, %X),
+  //                           MulFOp(gamma,
+  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
+  //                                         alpha)))
+  auto loc = op->getLoc();
+  Value* operand = operands[0];
+  auto alphaAttr = op->getAttrOfType<FloatAttr>("Selu.alpha");
+  auto gammaAttr = op->getAttrOfType<FloatAttr>("Selu.gamma");
+
+  auto zero = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
+  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
+  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttr);
+  auto exp = rewriter.create<ExpOp>(loc, operand);
+  auto greaterThanZero =
+      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
+  auto select = rewriter.create<SelectOp>(loc, greaterThanZero, operand,
+      rewriter.create<SubFOp>(
+          loc, rewriter.create<MulFOp>(loc, alpha, exp), alpha));
+  auto result = rewriter.create<MulFOp>(loc, gamma, select);
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXMaxOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXMaxOp>(Operation* op, ArrayRef<Type> result_types,
+    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
+  //                              %X,
+  //                              %Y)
+  auto loc = op->getLoc();
+  Value* lhs = operands[0];
+  Value* rhs = operands[1];
+  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
+  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
+// Scalar unary ops for lowering ONNXMinOp
+//===----------------------------------------------------------------------===//
+template <>
+Value* mapToLowerScalarOp<ONNXMinOp>(Operation* op, ArrayRef<Type> result_types,
+    ArrayRef<Value*> operands, ConversionPatternRewriter& rewriter) {
+  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
+  //                              %X,
+  //                              %Y)
+  auto loc = op->getLoc();
+  Value* lhs = operands[0];
+  Value* rhs = operands[1];
+  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
+  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
+  return result;
+}
+
+// Element-wise unary ops lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+template <typename ElementwiseUnaryOp>
+struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
+  ONNXElementwiseUnaryOpLowering(MLIRContext* ctx)
+      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
   PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
       ConversionPatternRewriter& rewriter) const final {
     // TODO: Check that the types are valid.
-    // An element-wise binary operation must have all operands and the result of
+    // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
@@ -287,7 +463,7 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
     auto memRefType = convertTensorToMemRef(tensorType);
 
     // If the output has a dynamic dimension, pass the operands required for
-    // each dynamic dimension to the AllocOp. The first operand of the binary
+    // each dynamic dimension to the AllocOp. The first operand of the
     // operation is used. The operands of the op need to match in terms of
     // dimensions with the result at this pre-optimization phase.
     // TODO: verify that dimensions match.
@@ -359,15 +535,9 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
     for (auto arg : iterationBlock.getArguments())
       loopIVs.push_back(arg);
 
-    SmallVector<Value*, numArgs> loadedVals;
-    for (unsigned i = 0; i < numArgs; i++) {
-      auto loadedVal = rewriter.create<LoadOp>(loc, operands[i], loopIVs);
-      loadedVals.push_back(loadedVal);
-    }
-
-    auto loweredOpResult = mapToLowerScalarOp<ElementwiseNaryOp>(
-        loc, memRefType.getElementType(), loadedVals, rewriter);
+    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
+        op, memRefType.getElementType(), {loadedVal}, rewriter);
 
     // Store result in the resulting array.
     rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);
@@ -377,12 +547,113 @@ struct ONNXElementwiseNaryOpLowering : public ConversionPattern {
   }
 };
 
-template <typename ElementwiseNaryOp>
-using ONNXElementwiseUnaryOpLowering =
-    ONNXElementwiseNaryOpLowering<ElementwiseNaryOp, 1>;
-template <typename ElementwiseNaryOp>
-using ONNXElementwiseBinaryOpLowering =
-    ONNXElementwiseNaryOpLowering<ElementwiseNaryOp, 2>;
+// Element-wise variadic ops lowering to Krnl dialect.
+//===----------------------------------------------------------------------===//
+template <typename ElementwiseVariadicOp>
+struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
+  ONNXElementwiseVariadicOpLowering(MLIRContext* ctx)
+      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
+  PatternMatchResult matchAndRewrite(Operation* op, ArrayRef<Value*> operands,
+      ConversionPatternRewriter& rewriter) const final {
+    // TODO: Check that the types are valid.
+    // An element-wise variadic operation must have all operands and the result
+    // of the same type. This should have been verified by the verifier.
+    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto loc = op->getLoc();
+    auto numArgs = op->getNumOperands();
+
+    // Insert an allocation and deallocation for the result of this operation.
+    auto memRefType = convertTensorToMemRef(tensorType);
+
+    // If the output has a dynamic dimension, pass the operands required for
+    // each dynamic dimension to the AllocOp. The first operand of the
+    // operation is used. The operands of the op need to match in terms of
+    // dimensions with the result at this pre-optimization phase.
+    // TODO: verify that dimensions match.
+    // TODO: can the dimension of the result differ after optimizations?
+    Value* alloc;
+    bool insertDealloc = checkInsertDealloc(op);
+
+    if (hasAllConstantDimensions(memRefType))
+      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
+    else
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, operands[0]);
+
+    // Number of loops
+    auto memRefShape = memRefType.getShape();
+    int64_t rank = memRefShape.size();
+
+    // Define loops.
+    auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, rank);
+    std::vector<Value*> originalLoops;
+    originalLoops.reserve(rank);
+    for (auto result : loopsOp.getResults()) {
+      originalLoops.push_back(result);
+    }
+
+    // Define loop optimization.
+    auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, rank);
+    std::vector<Value*> optimizedLoops;
+    optimizedLoops.reserve(rank);
+    for (auto result : optimizedLoopsOp.getResults()) {
+      optimizedLoops.push_back(result);
+    }
+    Block& optimizationBlock = optimizedLoopsOp.region().front();
+
+    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
+    // Iterate over the loop nest.
+    // TODO (Tian): move this logic inside KrnlIterateOp. Pass MemRefShape
+    // to KrnlIterateOp instead.
+    for (int i = 0; i < rank; ++i) {
+      if (memRefShape[i] < 0) {
+        pack.pushConstantBound(0);
+        pack.pushOperandBound(
+            rewriter.create<DimOp>(loc, operands[0], i).getResult());
+      } else {
+        pack.pushConstantBound(0);
+        pack.pushConstantBound(memRefShape[i]);
+      }
+    }
+
+    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
+    Block& iterationBlock = iterateOp.bodyRegion().front();
+
+    // Now perform the insertions into the body of the
+    // just generated instructions:
+
+    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
+    rewriter.setInsertionPointToEnd(&optimizationBlock);
+    // Return from KrnlOptimizeLoopsOp body.
+    // When no optimizations are present we just return the loops
+    // unchaged.
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.setInsertionPoint(optimizedLoopsOp);
+
+    // 2. Insert instructions inside the KernelIterateOp body.
+    rewriter.setInsertionPointToStart(&iterationBlock);
+
+    // Handle the operation:
+    SmallVector<Value*, 4> loopIVs;
+    for (auto arg : iterationBlock.getArguments())
+      loopIVs.push_back(arg);
+
+    // Fold over operands for each of their scalar values
+    Value *accumulated, *next;
+    accumulated = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
+    for (unsigned i = 1; i < numArgs; i++) {
+      next = rewriter.create<LoadOp>(loc, operands[i], loopIVs);
+      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
+          op, memRefType.getElementType(), {accumulated, next}, rewriter);
+    }
+    // Store result in the resulting array.
+    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);
+
+    rewriter.replaceOp(op, alloc);
+
+    return matchSuccess();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Conversion from Tensor type to the Standard dialect MemRef type.
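The LeakyRelu and Selu specializations above implement

$$\mathrm{LeakyRelu}(x) = \begin{cases} x, & x \ge 0 \\ \alpha x, & x < 0 \end{cases}, \qquad \mathrm{Selu}(x) = \begin{cases} \gamma x, & x > 0 \\ \gamma\,(\alpha e^{x} - \alpha), & x \le 0 \end{cases}$$

and the new ONNXElementwiseVariadicOpLowering evaluates an n-ary op at each loop iteration as a left fold of its two-operand scalar form, which is why Sum, Max and Min only need a binary mapToLowerScalarOp specialization:

$$\mathrm{op}(x_1, x_2, \dots, x_n) = f\bigl(\dots f\bigl(f(x_1, x_2),\ x_3\bigr)\dots,\ x_n\bigr), \qquad f \in \{+,\ \max,\ \min\}$$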
@@ -469,14 +740,21 @@ void FrontendToKrnlLoweringPass::runOnModule() {
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
       ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXAddOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXMulOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXDivOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXSubOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXAndOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXOrOp>,
-      ONNXElementwiseBinaryOpLowering<mlir::ONNXXorOp>>(&getContext());
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
+      ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
+      ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>>(&getContext());
 
     // With the target and rewrite patterns defined, we can now attempt the
     // conversion. The conversion will signal failure if any of our `illegal`
@@ -93,7 +93,11 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
             op->getName().getStringRef() != "onnx.Sinh" &&
             op->getName().getStringRef() != "onnx.Cosh" &&
             op->getName().getStringRef() != "onnx.Sigmoid" &&
+            op->getName().getStringRef() != "onnx.HardSigmoid" &&
+            op->getName().getStringRef() != "onnx.Elu" &&
             op->getName().getStringRef() != "onnx.Relu" &&
+            op->getName().getStringRef() != "onnx.LeakyRelu" &&
+            op->getName().getStringRef() != "onnx.Selu" &&
             op->getName().getStringRef() != "onnx.Mul" &&
            op->getName().getStringRef() != "onnx.Add" &&
             op->getName().getStringRef() != "onnx.Div" &&
@@ -101,6 +105,9 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {
             op->getName().getStringRef() != "onnx.And" &&
             op->getName().getStringRef() != "onnx.Or" &&
             op->getName().getStringRef() != "onnx.Xor" &&
+            op->getName().getStringRef() != "onnx.Sum" &&
+            op->getName().getStringRef() != "onnx.Max" &&
+            op->getName().getStringRef() != "onnx.Min" &&
             op->getName().getStringRef() != "onnx.MatMul" &&
             op->getName().getStringRef() != "onnx.Gemm" &&
             op->getName().getStringRef() != "onnx.FullGemm")
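This name-based list in the shape inference pass has to be extended by hand for every new op. A possible simplification, sketched here only as an idea (the interface class name is assumed from the DeclareOpInterfaceMethods&lt;ShapeInferenceOpInterface&gt; trait added above; this is not part of this commit):

```cpp
// Sketch only, not in this commit: dispatch on the shape inference interface
// instead of comparing operation names one by one.
if (auto shapeOp = llvm::dyn_cast<ShapeInferenceOpInterface>(op)) {
  shapeOp.inferShapes();  // the op knows how to infer its own result type
}
```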
@@ -278,3 +278,169 @@ func @test_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
   // CHECK: store [[RELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
   // CHECK: return [[RES]] : memref<?x10xf32>
 }
+
+func @test_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_sum
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_max(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Max"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_max
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_min(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Min"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_min
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
+  // CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
+  // CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_elu
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
+  // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
+  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_leakyrelu
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
+  // CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_selu
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
+  // CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
+  // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
+  // CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
+  // CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
+  // CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
+  // CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
+
+func @test_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
+  %0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
+  "std.return"(%0) : (tensor<*xf32>) -> ()
+
+  // CHECK-LABEL: test_hardsigmoid
+  // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
+  // CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
+  // CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
+  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
+  // CHECK: } : () -> (!krnl.loop, !krnl.loop)
+  // CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
+  // CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
+  // CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
+  // CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
+  // CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
+  // CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
+  // CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
+  // CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
+  // CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
+  // CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
+  // CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>
+  // CHECK: return [[RES]] : memref<?x10xf32>
+}
@ -571,3 +571,342 @@ func @test_relu_relu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
|
||||||
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
// CHECK: return [[RET_RES]] : memref<?x10xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @test_sum_sum(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
|
||||||
|
%0 = "onnx.Sum"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
%1 = "onnx.Sum"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
|
||||||
|
"std.return"(%1) : (tensor<*xf32>) -> ()
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_sum_sum
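/// Each Sum below lowers to an element-wise addf; the second Sum consumes the buffer produced by the first.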
/// First Sum
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADD]], [[RES]][%arg2, %arg3] : memref<?x10xf32>

/// Second Sum
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[ADD:%.+]] = addf [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[ADD]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_max_max(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Max"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Max"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_max_max
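/// Each Max below lowers to cmpf "ogt" followed by select, i.e. an element-wise maximum.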
/// First Max
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>

/// Second Max
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MAX:%.+]] = cmpf "ogt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MAX]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_min_min(%arg0 : tensor<?x10xf32>, %arg1 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Min"(%arg0, %arg1) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Min"(%0, %arg1) : (tensor<*xf32>, tensor<?x10xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_min_min
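/// Each Min below lowers to cmpf "olt" followed by select, i.e. an element-wise minimum.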
/// First Min
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load %arg0[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RES]][%arg2, %arg3] : memref<?x10xf32>

/// Second Min
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
// CHECK: [[LOAD1:%.+]] = load [[RES]][%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[LOAD2:%.+]] = load %arg1[%arg2, %arg3] : memref<?x10xf32>
// CHECK: [[MIN:%.+]] = cmpf "olt", [[LOAD1]], [[LOAD2]] : f32
// CHECK: [[RELU_RES:%.+]] = select [[MIN]], [[LOAD1]], [[LOAD2]] : f32
// CHECK: store [[RELU_RES]], [[RET_RES]][%arg2, %arg3] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_elu_elu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Elu"(%arg0) {Elu.alpha=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Elu"(%0) {Elu.alpha=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_elu_elu
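/// Each Elu below lowers to select(x < 0, alpha * (exp(x) - 1), x).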
/// First Elu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>

/// Second Elu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[SUB:%.+]] = subf [[EXP]], [[ONE]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[SUB]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_leakyrelu_leakyrelu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.LeakyRelu"(%arg0) {LeakyRelu.alpha=1.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.LeakyRelu"(%0) {LeakyRelu.alpha=1.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_leakyrelu_leakyrelu
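/// Each LeakyRelu below lowers to select(x < 0, alpha * x, x).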
/// First LeakyRelu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RES]][%arg1, %arg2] : memref<?x10xf32>

/// Second LeakyRelu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[CMP:%.+]] = cmpf "olt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[MUL]], [[LOAD]] : f32
// CHECK: store [[SELECT]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_selu_selu(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Selu"(%arg0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.Selu"(%0) {Selu.alpha=1.0:f32, Selu.gamma=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_selu_selu
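/// Each Selu below lowers to gamma * select(x > 0, x, alpha * exp(x) - alpha).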
/// First Selu
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
// CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
// CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
// CHECK: store [[SELU_RES]], [[RES]][%arg1, %arg2] : memref<?x10xf32>

/// Second Selu
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[GAMMA:%.+]] = constant {{2.+}} : f32
// CHECK: [[EXP:%.+]] = exp [[LOAD]] : f32
// CHECK: [[CMP:%.+]] = cmpf "ogt", [[LOAD]], [[ZERO]] : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[EXP]] : f32
// CHECK: [[SUB:%.+]] = subf [[MUL]], [[ALPHA]] : f32
// CHECK: [[SELECT:%.+]] = select [[CMP]], [[LOAD]], [[SUB]] : f32
// CHECK: [[SELU_RES:%.+]] = mulf [[GAMMA]], [[SELECT]] : f32
// CHECK: store [[SELU_RES]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}

func @test_hardsigmoid_hardsigmoid(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
%0 = "onnx.HardSigmoid"(%arg0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<?x10xf32>) -> tensor<*xf32>
%1 = "onnx.HardSigmoid"(%0) {HardSigmoid.alpha=1.0:f32, HardSigmoid.beta=2.0:f32} : (tensor<*xf32>) -> tensor<*xf32>
"std.return"(%1) : (tensor<*xf32>) -> ()

// CHECK-LABEL: test_hardsigmoid_hardsigmoid
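/// Each HardSigmoid below lowers to min(max(alpha * x + beta, 0), 1), built from two cmpf/select pairs.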
/// First HardSigmoid
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load %arg0[%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
// CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
// CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
// CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
// CHECK: store [[SELECT2]], [[RES]][%arg1, %arg2] : memref<?x10xf32>

/// Second HardSigmoid
// CHECK: [[DIM_0:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: [[RET_RES:%.+]] = alloc([[DIM_0]]) : memref<?x10xf32>
// CHECK: [[DEF_LOOPS:%.+]]:2 = krnl.define_loops 2
// CHECK: [[OPT_LOOPS:%.+]]:2 = krnl.optimize_loops {
// CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1
// CHECK: } : () -> (!krnl.loop, !krnl.loop)
// CHECK: [[DIM_2:%.+]] = dim [[RES]], 0 : memref<?x10xf32>
// CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg1 = 0 to [[DIM_2]], [[DEF_LOOPS]]#1 -> %arg2 = 0 to 10) {
// CHECK: [[LOAD:%.+]] = load [[RES]][%arg1, %arg2] : memref<?x10xf32>
// CHECK: [[ZERO:%.+]] = constant {{0.+}} : f32
// CHECK: [[ONE:%.+]] = constant {{1.+}} : f32
// CHECK: [[ALPHA:%.+]] = constant {{1.+}} : f32
// CHECK: [[BETA:%.+]] = constant {{2.+}} : f32
// CHECK: [[MUL:%.+]] = mulf [[ALPHA]], [[LOAD]] : f32
// CHECK: [[ADD:%.+]] = addf [[MUL]], [[BETA]] : f32
// CHECK: [[CMP1:%.+]] = cmpf "ogt", [[ADD]], [[ZERO]] : f32
// CHECK: [[SELECT1:%.+]] = select [[CMP1]], [[ADD]], [[ZERO]] : f32
// CHECK: [[CMP2:%.+]] = cmpf "olt", [[SELECT1]], [[ONE]] : f32
// CHECK: [[SELECT2:%.+]] = select [[CMP2]], [[SELECT1]], [[ONE]] : f32
// CHECK: store [[SELECT2]], [[RET_RES]][%arg1, %arg2] : memref<?x10xf32>

/// Dealloc of first result.
// CHECK: dealloc [[RES]] : memref<?x10xf32>
// CHECK-NOT: dealloc [[RET_RES]] : memref<?x10xf32>

// CHECK: return [[RET_RES]] : memref<?x10xf32>
}