Merge branch 'master' into matmul-shape
This commit is contained in:
commit
a87f01747a
|
@ -611,7 +611,7 @@ private:
|
||||||
* a specialized function is used
|
* a specialized function is used
|
||||||
*/
|
*/
|
||||||
void ImportNodeConv(
|
void ImportNodeConv(
|
||||||
onnx::NodeProto node, int nIn, int nOut,
|
onnx::NodeProto node, int nOut,
|
||||||
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
std::initializer_list<std::tuple<std::string, std::string, std::string>>
|
||||||
attrs) {
|
attrs) {
|
||||||
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
// Conv has attribute dilations, kernel_shape, pads, the default value of
|
||||||
|
@ -624,12 +624,12 @@ private:
|
||||||
// similar situation for pads, strides in AveragePool
|
// similar situation for pads, strides in AveragePool
|
||||||
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
|
// axes of ReduceSum, pads, strides, dilations and kernel_shape of MaxPool
|
||||||
// TODO: fix this after type inference
|
// TODO: fix this after type inference
|
||||||
|
int nOps = node.input().size();
|
||||||
|
|
||||||
if (node.input().size() == 1) {
|
if (nOps == 2)
|
||||||
ImportNodeOneOut<mlir::ONNXConv1Op>(node, nIn, nOut, attrs);
|
ImportNodeOneOut<mlir::ONNXConvNoBiasOp>(node, nOps, nOut, attrs);
|
||||||
} else {
|
else
|
||||||
ImportNodeOneOut<mlir::ONNXConv3Op>(node, nIn, nOut, attrs);
|
ImportNodeOneOut<mlir::ONNXConvOp>(node, nOps, nOut, attrs);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
|
|
|
@ -87,7 +87,7 @@
|
||||||
{"value","", ""}
|
{"value","", ""}
|
||||||
});
|
});
|
||||||
}else if (OpName == "Conv") {
|
}else if (OpName == "Conv") {
|
||||||
ImportNodeConv(node, 3, 1, {
|
ImportNodeConv(node, 1, {
|
||||||
{"auto_pad","str","NOTSET"}
|
{"auto_pad","str","NOTSET"}
|
||||||
,{"dilations","", ""}
|
,{"dilations","", ""}
|
||||||
,{"group","int", "1"}
|
,{"group","int", "1"}
|
||||||
|
|
|
@ -90,39 +90,28 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
|
||||||
// or outputs. This decision affects only ONNX operations with optional
|
// or outputs. This decision affects only ONNX operations with optional
|
||||||
// arguments not ONNX operations with variadic operands.
|
// arguments not ONNX operations with variadic operands.
|
||||||
|
|
||||||
def ONNXFullGemmOp: ONNX_Op<"FullGemm",
|
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
|
||||||
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
[NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
|
||||||
let summary = "ONNX general matrix multiply operation";
|
let summary = "ONNX general matrix multiply operation without bias.";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
||||||
The "onnx.gemm" generic matrix multiplication with bias.
|
The "onnx.Gemm" generic matrix multiplication without bias.
|
||||||
|
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins AnyTensor:$lhs_in, AnyTensor:$rhs_in, AnyTensor:$bias_in);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$lhs_in, AnyTypeOf<[AnyMemRef, AnyTensor]>:$rhs_in);
|
||||||
let results = (outs AnyTensor);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConv1Op:ONNX_Op<"Conv1",
|
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "ONNX Conv operation";
|
let summary = "ONNX Conv operation with no Bias operand.";
|
||||||
let description = [{
|
let description = [{
|
||||||
"The convolution operator consumes an input tensor and a filter, and"
|
"The convolution operator consumes an input tensor and a filter, and"
|
||||||
"computes the output."
|
"computes the output."
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyTensor:$X);
|
let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
|
||||||
let results = (outs AnyTensor);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
}
|
|
||||||
|
|
||||||
def ONNXConv3Op:ONNX_Op<"Conv3",
|
|
||||||
[NoSideEffect]> {
|
|
||||||
let summary = "ONNX Conv operation";
|
|
||||||
let description = [{
|
|
||||||
"The convolution operator consumes an input tensor and a filter, and"
|
|
||||||
"computes the output."
|
|
||||||
}];
|
|
||||||
let arguments = (ins AnyTensor:$X, AnyTensor:$W, AnyTensor:$B);
|
|
||||||
let results = (outs AnyTensor);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
|
||||||
|
|
|
@ -424,9 +424,9 @@ void ONNXGemmOp::inferShapes() {
|
||||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// FullGemm
|
// GemmNoBias
|
||||||
|
|
||||||
void ONNXFullGemmOp::inferShapes() {
|
void ONNXGemmNoBiasOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
|
|
|
@ -315,7 +315,7 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape",
|
||||||
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
|
||||||
}
|
}
|
||||||
|
|
||||||
def ONNXConvOp:ONNX_Op<"Conv",
|
def ONNXConvOp:ONNX_Op<"Conv",
|
||||||
[NoSideEffect]> {
|
[NoSideEffect]> {
|
||||||
let summary = "ONNX Conv operation";
|
let summary = "ONNX Conv operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
|
|
@ -30,9 +30,9 @@ def HasOneUse : Constraint<CPred<"$0.hasOneUse()">>;
|
||||||
// Pattern-Match and Rewrite
|
// Pattern-Match and Rewrite
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.FullGemm(%X, %Y, %Z)
|
// onnx.add(onnx.matmul(%X, %Y), %Z) = onnx.Gemm(%X, %Y, %Z)
|
||||||
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
def MulAddToGemmOptPattern : Pat<(ONNXAddOp (ONNXMatMulOp:$res $m1, $m2), $m3),
|
||||||
(ONNXFullGemmOp $m1, $m2, $m3),
|
(ONNXGemmOp $m1, $m2, $m3),
|
||||||
[(HasOneUse $res)]>;
|
[(HasOneUse $res)]>;
|
||||||
|
|
||||||
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
// ONNX_Op (onnx.Identity (%X)) = ONNX_Op (%X)
|
||||||
|
|
|
@ -114,7 +114,7 @@ public:
|
||||||
op->getName().getStringRef() != "onnx.Identity" &&
|
op->getName().getStringRef() != "onnx.Identity" &&
|
||||||
op->getName().getStringRef() != "onnx.MatMul" &&
|
op->getName().getStringRef() != "onnx.MatMul" &&
|
||||||
op->getName().getStringRef() != "onnx.Gemm" &&
|
op->getName().getStringRef() != "onnx.Gemm" &&
|
||||||
op->getName().getStringRef() != "onnx.FullGemm" &&
|
op->getName().getStringRef() != "onnx.GemmNoBias" &&
|
||||||
op->getName().getStringRef() != "onnx.Reshape" &&
|
op->getName().getStringRef() != "onnx.Reshape" &&
|
||||||
op->getName().getStringRef() != "onnx.Transpose")
|
op->getName().getStringRef() != "onnx.Transpose")
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
func @test_matmul_add_simplification(%a0: tensor<10x10xf32>, %a1: tensor<10x10xf32>, %a2: tensor<10x10xf32>) -> tensor<10x10xf32> {
|
||||||
// CHECK-LABEL: test_matmul_add_simplification
|
// CHECK-LABEL: test_matmul_add_simplification
|
||||||
// CHECK: %{{[0-9]+}} = "onnx.FullGemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
// CHECK: %{{[0-9]+}} = "onnx.Gemm"(%{{.*}}, %{{.*}}, %{{.*}}) : (tensor<10x10xf32>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%0 = "onnx.MatMul"(%a0, %a1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
%1 = "onnx.Add"(%0, %a2) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32>
|
||||||
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
"std.return"(%1) : (tensor<10x10xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue