Unify Conv implementation (#54)
* fixed readme for new git repo * conv with bias as an optional input
This commit is contained in:
		
							parent
							
								
									1777c07b1e
								
							
						
					
					
						commit
						653fa69102
					
				|  | @ -636,35 +636,6 @@ ONNX ConvInteger operation | ||||||
| 
 | 
 | ||||||
| 1. `y`: memref of any type values or tensor of any type values | 1. `y`: memref of any type values or tensor of any type values | ||||||
| 
 | 
 | ||||||
| ### onnx.ConvNoBias (ONNXConvNoBiasOp) |  | ||||||
| ONNX Conv operation with no Bias operand. |  | ||||||
| 
 |  | ||||||
| #### Description: |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| "The convolution operator consumes an input tensor and a filter, and" |  | ||||||
| "computes the output." |  | ||||||
| 
 |  | ||||||
| #### Operands: |  | ||||||
| 
 |  | ||||||
| 1. `X`: memref of any type values or tensor of any type values |  | ||||||
| 1. `W`: memref of any type values or tensor of any type values |  | ||||||
| 
 |  | ||||||
| #### Attributes: |  | ||||||
| 
 |  | ||||||
| | Attribute | MLIR Type | Description | |  | ||||||
| | :-------: | :-------: | ----------- | |  | ||||||
| | `auto_pad` | `StringAttr` | string attribute attribute | |  | ||||||
| | `dilations` | `ArrayAttr` | 64-bit integer array attribute attribute | |  | ||||||
| | `group` | `IntegerAttr` | 64-bit integer attribute attribute | |  | ||||||
| | `kernel_shape` | `ArrayAttr` | 64-bit integer array attribute attribute | |  | ||||||
| | `pads` | `ArrayAttr` | 64-bit integer array attribute attribute | |  | ||||||
| | `strides` | `ArrayAttr` | 64-bit integer array attribute attribute | |  | ||||||
| 
 |  | ||||||
| #### Results: |  | ||||||
| 
 |  | ||||||
| 1. `o_Y`: memref of any type values or tensor of any type values |  | ||||||
| 
 |  | ||||||
| ### onnx.Conv (ONNXConvOp) | ### onnx.Conv (ONNXConvOp) | ||||||
| ONNX Conv operation | ONNX Conv operation | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -32,7 +32,6 @@ special_attr_defaults = dict([ | ||||||
| 
 | 
 | ||||||
| # Special operation importing handlers. | # Special operation importing handlers. | ||||||
| special_op_handler = dict([ | special_op_handler = dict([ | ||||||
|     ("Conv", "ImportNodeConv"), |  | ||||||
|     ("MaxPool", "ImportNodeMaxPool"), |     ("MaxPool", "ImportNodeMaxPool"), | ||||||
|     ("BatchNormalization", "ImportNodeBatchNormalization"), |     ("BatchNormalization", "ImportNodeBatchNormalization"), | ||||||
|     ("Pad", "ImportNodePad"), |     ("Pad", "ImportNodePad"), | ||||||
|  | @ -47,11 +46,11 @@ OpsWithShapeInference = [ | ||||||
|     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', |     'LeakyRelu', 'Elu', 'Selu', 'HardSigmoid', 'Reshape', 'Reciprocal', | ||||||
|     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', |     'Identity', 'Cos', 'Log', 'Transpose', 'Softmax', 'ReduceMax', 'ReduceMin', | ||||||
|     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', |     'ReduceProd', 'ReduceSum', 'Softplus', 'Softsign', 'Sqrt', 'Unsqueeze', | ||||||
|     'Sign', 'Constant', 'AveragePool', 'Abs' |     'Sign', 'Constant', 'AveragePool', 'Abs', 'Conv' | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
| # Operations supporting canonicalization. | # Operations supporting canonicalization. | ||||||
| OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm'] | OpsWithCanonicalizer = ['Add', 'Identity', 'Gemm', 'Conv'] | ||||||
| 
 | 
 | ||||||
| # Operations who have operands that, if produced by constant operations, should | # Operations who have operands that, if produced by constant operations, should | ||||||
| # be promoted to become an attribute (via attribute promotion). | # be promoted to become an attribute (via attribute promotion). | ||||||
|  |  | ||||||
|  | @ -303,30 +303,6 @@ private: | ||||||
|     buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut); |     buildOutputAndOperation<mlir::ONNXReshapeOp>(node, inputs, nIn, nOut); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /*!
 |  | ||||||
|    * Special handle for Conv operations. |  | ||||||
|    * c++ does not allow template specialization inside a class scope |  | ||||||
|    * a specialized function is used |  | ||||||
|    */ |  | ||||||
|   void ImportNodeConv(onnx::NodeProto node, int nIn, int nOut) { |  | ||||||
|     // Conv has attribute dilations, kernel_shape, pads, the default value of
 |  | ||||||
|     // which  is determined by the shape of first argument. However, since the
 |  | ||||||
|     // shape is unknown now, these attributes can be not generated auto
 |  | ||||||
|     // dilations_attr = get_attr_ints(node, "dilations",
 |  | ||||||
|     //    std::vector<int>(inputs[0]->getType().cast<RankedTensorType>.getDims()-2,
 |  | ||||||
|     //    1));
 |  | ||||||
|     // attributes.push_back(dilations_attr)
 |  | ||||||
|     // similar situation for pads, strides in AveragePool
 |  | ||||||
|     // axes of ReduceSum,  pads, strides, dilations and kernel_shape of MaxPool
 |  | ||||||
|     // TODO: fix this after type inference
 |  | ||||||
|     int nOps = node.input().size(); |  | ||||||
| 
 |  | ||||||
|     if (nOps == 2) |  | ||||||
|       buildOperation<mlir::ONNXConvNoBiasOp>(node, nOps, nOut); |  | ||||||
|     else |  | ||||||
|       buildOperation<mlir::ONNXConvOp>(node, nOps, nOut); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /*!
 |   /*!
 | ||||||
|    * Special handle for MaxPool operations. |    * Special handle for MaxPool operations. | ||||||
|    */ |    */ | ||||||
|  |  | ||||||
|  | @ -50,7 +50,7 @@ if (opName == "Constant") | ||||||
| if (opName == "ConstantOfShape") | if (opName == "ConstantOfShape") | ||||||
|   return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); |   return buildOperation<mlir::ONNXConstantOfShapeOp>(node, /* expected_num_operands = */ 1, /* expected_num_results = */ 1); | ||||||
| if (opName == "Conv") | if (opName == "Conv") | ||||||
|   return ImportNodeConv(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); |   return buildOperation<mlir::ONNXConvOp>(node, /* expected_num_operands = */ 3, /* expected_num_results = */ 1); | ||||||
| if (opName == "ConvInteger") | if (opName == "ConvInteger") | ||||||
|   return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); |   return buildOperation<mlir::ONNXConvIntegerOp>(node, /* expected_num_operands = */ 4, /* expected_num_results = */ 1); | ||||||
| if (opName == "ConvTranspose") | if (opName == "ConvTranspose") | ||||||
|  |  | ||||||
|  | @ -12,18 +12,19 @@ | ||||||
| 
 | 
 | ||||||
| using namespace mlir; | using namespace mlir; | ||||||
| 
 | 
 | ||||||
| struct ONNXConvNoBiasOpLowering : public ConversionPattern { | struct ONNXConvOpLowering : public ConversionPattern { | ||||||
|   ONNXConvNoBiasOpLowering(MLIRContext *ctx) |   ONNXConvOpLowering(MLIRContext *ctx) | ||||||
|       : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {} |       : ConversionPattern(mlir::ONNXConvOp::getOperationName(), 1, ctx) {} | ||||||
| 
 | 
 | ||||||
|   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, |   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||||
|       ConversionPatternRewriter &rewriter) const final { |       ConversionPatternRewriter &rewriter) const final { | ||||||
|     auto loc = op->getLoc(); |     auto loc = op->getLoc(); | ||||||
|  |     ONNXConvOpOperandAdaptor operandAdaptor(operands); | ||||||
|     // Insert an allocation and deallocation for the result of this operation.
 |     // Insert an allocation and deallocation for the result of this operation.
 | ||||||
|     auto memRefType = convertToMemRefType(*op->result_type_begin()); |     auto memRefType = convertToMemRefType(*op->result_type_begin()); | ||||||
|     Value alloc; |     Value alloc; | ||||||
|     bool insertDealloc = checkInsertDealloc(op); |     bool insertDealloc = checkInsertDealloc(op); | ||||||
|     ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op); |     ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op); | ||||||
| 
 | 
 | ||||||
|     if (hasAllConstantDimensions(memRefType)) |     if (hasAllConstantDimensions(memRefType)) | ||||||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); |       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||||
|  | @ -32,12 +33,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { | ||||||
|           memRefType, loc, rewriter, insertDealloc, {operands[0]}); |           memRefType, loc, rewriter, insertDealloc, {operands[0]}); | ||||||
| 
 | 
 | ||||||
|     auto resultShape = memRefType.getShape(); |     auto resultShape = memRefType.getShape(); | ||||||
|     auto &inputOperand = operands[0]; |     auto inputOperand = operandAdaptor.X(); | ||||||
|     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); |     auto inputShape = inputOperand.getType().cast<MemRefType>().getShape(); | ||||||
|     auto &kernelOperand = operands[1]; |     auto kernelOperand = operandAdaptor.W(); | ||||||
|     auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape(); |     auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape(); | ||||||
|  |     auto biasOperand = operandAdaptor.B(); | ||||||
|  |     bool hasBias = !biasOperand.getType().isa<NoneType>(); | ||||||
| 
 | 
 | ||||||
|     // R = ConvNoBias(D, K)
 |     // R = Conv(D, K)
 | ||||||
|     //
 |     //
 | ||||||
|     // The input/output shapes will look like this:
 |     // The input/output shapes will look like this:
 | ||||||
|     //
 |     //
 | ||||||
|  | @ -169,8 +172,23 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
|         // 3.4 Emit inner loop nest.
 |         // 3.4 Emit inner loop nest.
 | ||||||
|         innerLoops.createIterateOp(); |         innerLoops.createIterateOp(); | ||||||
|         rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); |  | ||||||
| 
 | 
 | ||||||
|  |         // Emit the bias, if needed.
 | ||||||
|  |         if (hasBias) { | ||||||
|  |           auto loadResult = | ||||||
|  |               rewriter.create<LoadOp>(loc, alloc, resultIndices); | ||||||
|  |           SmallVector<Value, 4> biasIndices; | ||||||
|  |           biasIndices.emplace_back(kernel); | ||||||
|  |           auto loadBias = | ||||||
|  |               rewriter.create<LoadOp>(loc, biasOperand, kernel); | ||||||
|  |           auto resultWithBias = rewriter.create<MulFOp>( | ||||||
|  |             loc, loadResult, loadBias); | ||||||
|  |           // Store initializer value into output location.
 | ||||||
|  |           rewriter.create<StoreOp>(loc, resultWithBias, alloc, resultIndices); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         //
 | ||||||
|  |         rewriter.setInsertionPointToStart(innerLoops.getIterateBlock()); | ||||||
|         { |         { | ||||||
|           // 4. Emit inner loop body
 |           // 4. Emit inner loop body
 | ||||||
|           // R[n][kernel][r1][r2] =
 |           // R[n][kernel][r1][r2] =
 | ||||||
|  | @ -238,5 +256,5 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern { | ||||||
| 
 | 
 | ||||||
| void populateLoweringONNXConvOpPattern( | void populateLoweringONNXConvOpPattern( | ||||||
|     OwningRewritePatternList &patterns, MLIRContext *ctx) { |     OwningRewritePatternList &patterns, MLIRContext *ctx) { | ||||||
|   patterns.insert<ONNXConvNoBiasOpLowering>(ctx); |   patterns.insert<ONNXConvOpLowering>(ctx); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1022,14 +1022,18 @@ void ONNXReduceSumOp::inferShapes() { | ||||||
| //   -  kernelShape: inferred from weight matrix if not defined by user;
 | //   -  kernelShape: inferred from weight matrix if not defined by user;
 | ||||||
| //   -  pads: set to proper value, 0 if not defined by user.
 | //   -  pads: set to proper value, 0 if not defined by user.
 | ||||||
| 
 | 
 | ||||||
| void ONNXConvNoBiasOp::inferShapes() { | void ONNXConvOp::inferShapes() { | ||||||
|   // Generic shape for data input X and weight tensor W:
 |   // Generic shape for data input X, weight tensor W, and optional bias B
 | ||||||
|   // X: (N x C x D1 x D2 ... x Dn)
 |   // X: (N x C x D1 x D2 ... x Dn)
 | ||||||
|   // W: (M x C/group x k1 x k2 x ... x kn)
 |   // W: (M x C/group x k1 x k2 x ... x kn)
 | ||||||
|  |   // B: (M) Optional
 | ||||||
|  | 
 | ||||||
|  |   bool hasBias = !B().getType().isa<NoneType>(); | ||||||
| 
 | 
 | ||||||
|   // Cannot infer shape if no shape exists.
 |   // Cannot infer shape if no shape exists.
 | ||||||
|   if (!X().getType().isa<RankedTensorType>() || |   if (!X().getType().isa<RankedTensorType>() || | ||||||
|       !W().getType().isa<RankedTensorType>()) |       !W().getType().isa<RankedTensorType>() || | ||||||
|  |       (hasBias && !B().getType().isa<RankedTensorType>())) | ||||||
|     return; |     return; | ||||||
| 
 | 
 | ||||||
|   auto xTy = X().getType().cast<RankedTensorType>(); |   auto xTy = X().getType().cast<RankedTensorType>(); | ||||||
|  | @ -1047,7 +1051,7 @@ void ONNXConvNoBiasOp::inferShapes() { | ||||||
|     emitError("Weight size not compatible with data size"); |     emitError("Weight size not compatible with data size"); | ||||||
| 
 | 
 | ||||||
|   // Group is a required attribute and should have default value of 1.
 |   // Group is a required attribute and should have default value of 1.
 | ||||||
|   int64_t group = ONNXConvNoBiasOp::group().getSExtValue(); |   int64_t group = ONNXConvOp::group().getSExtValue(); | ||||||
| 
 | 
 | ||||||
|   // Check if the attribute actually exists. If it does not then add it.
 |   // Check if the attribute actually exists. If it does not then add it.
 | ||||||
|   if (!groupAttr()) |   if (!groupAttr()) | ||||||
|  | @ -1058,6 +1062,16 @@ void ONNXConvNoBiasOp::inferShapes() { | ||||||
|       xShape[1] != (weightShape[1] * group)) |       xShape[1] != (weightShape[1] * group)) | ||||||
|     emitError("Channel dimension mismatch"); |     emitError("Channel dimension mismatch"); | ||||||
| 
 | 
 | ||||||
|  |   // Check the size of bias.
 | ||||||
|  |   if (hasBias) { | ||||||
|  |     auto bTx = B().getType().cast<RankedTensorType>(); | ||||||
|  |     auto bShape = bTx.getShape(); | ||||||
|  |     if (bShape.size() != 1) | ||||||
|  |       emitError("bias should be one dimensional"); | ||||||
|  |     if (bShape[0] != weightShape[0]) | ||||||
|  |       emitError("bias should have same dimensions as weight's first dimension"); | ||||||
|  |   } | ||||||
|  |    | ||||||
|   // Note: the value of the group attribut only impacts the way the
 |   // Note: the value of the group attribut only impacts the way the
 | ||||||
|   // computation is carried out and not the actual output size.
 |   // computation is carried out and not the actual output size.
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -95,25 +95,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> { | ||||||
| // or outputs. This decision affects only ONNX operations with optional | // or outputs. This decision affects only ONNX operations with optional | ||||||
| // arguments not ONNX operations with variadic operands. | // arguments not ONNX operations with variadic operands. | ||||||
| 
 | 
 | ||||||
| def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias", |  | ||||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { |  | ||||||
|   let hasCanonicalizer = 1; |  | ||||||
|   let summary = "ONNX Conv operation with no Bias operand."; |  | ||||||
|   let description = [{ |  | ||||||
|     "The convolution operator consumes an input tensor and a filter, and" |  | ||||||
|     "computes the output." |  | ||||||
|   }]; |  | ||||||
|   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, |  | ||||||
|            AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, |  | ||||||
|            DefaultValuedAttr<StrAttr, "NOTSET">:$auto_pad, |  | ||||||
|            OptionalAttr<I64ArrayAttr>:$dilations, |  | ||||||
|            DefaultValuedAttr<I64Attr, "1">:$group, |  | ||||||
|            OptionalAttr<I64ArrayAttr>:$kernel_shape, |  | ||||||
|            OptionalAttr<I64ArrayAttr>:$pads, |  | ||||||
|            OptionalAttr<I64ArrayAttr>:$strides); |  | ||||||
|   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y); |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut", | ||||||
|     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { |     [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||||
|   let hasCanonicalizer = 1; |   let hasCanonicalizer = 1; | ||||||
|  |  | ||||||
|  | @ -363,7 +363,8 @@ def ONNXConstantOfShapeOp:ONNX_Op<"ConstantOfShape", | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| def ONNXConvOp:ONNX_Op<"Conv", | def ONNXConvOp:ONNX_Op<"Conv", | ||||||
|   [NoSideEffect]> { |   [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> { | ||||||
|  |   let hasCanonicalizer = 1; | ||||||
|   let summary = "ONNX Conv operation"; |   let summary = "ONNX Conv operation"; | ||||||
|   let description = [{ |   let description = [{ | ||||||
|   "The convolution operator consumes an input tensor and a filter, and" |   "The convolution operator consumes an input tensor and a filter, and" | ||||||
|  |  | ||||||
|  | @ -72,7 +72,7 @@ ArrayAttr insertZerosForNonPaddedDims( | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| // Rewrite:
 | // Rewrite:
 | ||||||
| // %0 = onnx.ConvNoBiasOp(%D : tensor<DShape>, %K)
 | // %0 = onnx.Conv(%D : tensor<DShape>, %K)
 | ||||||
| //     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
 | //     {pads = [b0, b1, ... bK, e0, e1, ..., eK]} ->
 | ||||||
| //         tensor<OutShape>
 | //         tensor<OutShape>
 | ||||||
| //
 | //
 | ||||||
|  | @ -80,14 +80,14 @@ ArrayAttr insertZerosForNonPaddedDims( | ||||||
| // %0 = onnx.PadConstantValuePasOp(%D)
 | // %0 = onnx.PadConstantValuePasOp(%D)
 | ||||||
| //     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
 | //     {pads = [0, 0, b0, b1, ... bK, 0, 0, e0, e1, ..., eK]} ->
 | ||||||
| //     tensor<DPaddedShape>
 | //     tensor<DPaddedShape>
 | ||||||
| // %1 = onnx.ConvNoBias(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
 | // %1 = onnx.Conv(%0 : tensor<DPaddedShape>, %K) {pads = [0, ..., 0]} ->
 | ||||||
| //     tensor<OutShape>
 | //     tensor<OutShape>
 | ||||||
| //===----------------------------------------------------------------------===//
 | //===----------------------------------------------------------------------===//
 | ||||||
| struct SplitConvOpPattern : public RewritePattern { | struct SplitConvOpPattern : public RewritePattern { | ||||||
|   SplitConvOpPattern(MLIRContext *context) |   SplitConvOpPattern(MLIRContext *context) | ||||||
|       : RewritePattern(ONNXConvNoBiasOp::getOperationName(), |       : RewritePattern(ONNXConvOp::getOperationName(), | ||||||
|                        {ONNXPadConstantValuePadOp::getOperationName(), |                        {ONNXPadConstantValuePadOp::getOperationName(), | ||||||
|                         ONNXConvNoBiasOp::getOperationName()}, |                         ONNXConvOp::getOperationName()}, | ||||||
|                        1, context) {} |                        1, context) {} | ||||||
| 
 | 
 | ||||||
|   PatternMatchResult matchAndRewrite(Operation *op, |   PatternMatchResult matchAndRewrite(Operation *op, | ||||||
|  | @ -95,7 +95,7 @@ struct SplitConvOpPattern : public RewritePattern { | ||||||
|     auto loc = op->getLoc(); |     auto loc = op->getLoc(); | ||||||
| 
 | 
 | ||||||
|     // If convolution does not use padding then no rewrite is required.
 |     // If convolution does not use padding then no rewrite is required.
 | ||||||
|     ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op); |     ONNXConvOp convOp = llvm::dyn_cast<ONNXConvOp>(op); | ||||||
|     auto padsAttribute = convOp.padsAttr(); |     auto padsAttribute = convOp.padsAttr(); | ||||||
|     if (!padsAttribute) |     if (!padsAttribute) | ||||||
|       return matchFailure(); |       return matchFailure(); | ||||||
|  | @ -155,8 +155,9 @@ struct SplitConvOpPattern : public RewritePattern { | ||||||
| 
 | 
 | ||||||
|     SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0); |     SmallVector<int64_t, 4> newConvPads(2 * inputDims, 0); | ||||||
|     auto tensorType = (*op->result_type_begin()).cast<TensorType>(); |     auto tensorType = (*op->result_type_begin()).cast<TensorType>(); | ||||||
|     ONNXConvNoBiasOp newConvOp = rewriter.create<ONNXConvNoBiasOp>( |     ONNXConvOp newConvOp = rewriter.create<ONNXConvOp>( | ||||||
|             loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1], |             loc, tensorType, paddingOp.getResult(), convOp.getOperands()[1], | ||||||
|  |             convOp.getOperands()[2], | ||||||
|             convOp.auto_padAttr(), convOp.dilationsAttr(), |             convOp.auto_padAttr(), convOp.dilationsAttr(), | ||||||
|             convOp.groupAttr(), convOp.kernel_shapeAttr(), |             convOp.groupAttr(), convOp.kernel_shapeAttr(), | ||||||
|             rewriter.getI64ArrayAttr(newConvPads), |             rewriter.getI64ArrayAttr(newConvPads), | ||||||
|  | @ -173,8 +174,8 @@ void ONNXMaxPoolSingleOutOp::getCanonicalizationPatterns( | ||||||
|     OwningRewritePatternList &results, MLIRContext *context) { |     OwningRewritePatternList &results, MLIRContext *context) { | ||||||
|   results.insert<MaxPoolSingleOutOpPaddingPattern>(context); |   results.insert<MaxPoolSingleOutOpPaddingPattern>(context); | ||||||
| } | } | ||||||
| /// on the ONNXConvNoBiasOp.
 | /// on the ONNXConvOp.
 | ||||||
| void ONNXConvNoBiasOp::getCanonicalizationPatterns( | void ONNXConvOp::getCanonicalizationPatterns( | ||||||
|     OwningRewritePatternList &results, MLIRContext *context) { |     OwningRewritePatternList &results, MLIRContext *context) { | ||||||
|   results.insert<SplitConvOpPattern>(context); |   results.insert<SplitConvOpPattern>(context); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -113,7 +113,7 @@ public: | ||||||
|         op->getName().getStringRef() != "onnx.ReduceSum" && |         op->getName().getStringRef() != "onnx.ReduceSum" && | ||||||
|         op->getName().getStringRef() != "onnx.Softmax" && |         op->getName().getStringRef() != "onnx.Softmax" && | ||||||
|         op->getName().getStringRef() != "onnx.Sqrt" && |         op->getName().getStringRef() != "onnx.Sqrt" && | ||||||
|         op->getName().getStringRef() != "onnx.ConvNoBias" && |         op->getName().getStringRef() != "onnx.Conv" && | ||||||
|         op->getName().getStringRef() != "onnx.PadConstantPad" && |         op->getName().getStringRef() != "onnx.PadConstantPad" && | ||||||
|         op->getName().getStringRef() != "onnx.PadConstantValuePad" && |         op->getName().getStringRef() != "onnx.PadConstantValuePad" && | ||||||
|         op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && |         op->getName().getStringRef() != "onnx.BatchNormalizationTestMode" && | ||||||
|  |  | ||||||
|  | @ -48,10 +48,12 @@ func @test_constant_pad(%arg0 : tensor<?x?xf32>) -> tensor<*xf32> { | ||||||
| 
 | 
 | ||||||
| // CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> { | // CHECK-LABEL: @test_conv_split(%{{.*}}: tensor<1x9x32x64xf32>, %{{.*}}: tensor<5x9x6x7xf32>) -> tensor<*xf32> { | ||||||
| func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { | func @test_conv_split(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 3, 4, 5]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  |   // CHECK-NEXT: %cst = constant unit | ||||||
|   // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> |   // CHECK-NEXT: %0 = "onnx.PadConstantValuePad"(%arg0) {constant_value = 0.000000e+00 : f32, mode = "constant", pads = [0, 0, 2, 3, 0, 0, 4, 5]} : (tensor<1x9x32x64xf32>) -> tensor<1x9x38x72xf32> | ||||||
|   // CHECK-NEXT: %1 = "onnx.ConvNoBias"(%0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> |   // CHECK-NEXT: %1 = "onnx.Conv"(%0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [0, 0, 0, 0]} : (tensor<1x9x38x72xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   // CHECK-NEXT: return %1 : tensor<*xf32> |   // CHECK-NEXT: return %1 : tensor<*xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1149,7 +1149,8 @@ func @test_matmul7(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<*xf32 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_no_pad |   // CHECK-LABEL: test_conv_no_bias_no_pad | ||||||
|  | @ -1191,8 +1192,55 @@ func @test_conv_no_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2 | ||||||
|   // CHECK: return [[RES]] : memref<1x5x27x58xf32> |   // CHECK: return [[RES]] : memref<1x5x27x58xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func @test_conv_bias_no_pad(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, tensor<5xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: test_conv_bias_no_pad | ||||||
|  |   // CHECK: [[RES:%.+]] = alloc() : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: [[CONST0:%.+]] = constant 5 : index | ||||||
|  |   // CHECK: [[CONST1:%.+]] = constant 0.000000e+00 : f32 | ||||||
|  |   // CHECK: [[CONST2:%.+]] = constant 2 : index | ||||||
|  |   // CHECK: [[OUTER_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||||
|  |   // CHECK: [[OPT_OUTER_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||||
|  |   // CHECK: krnl.return_loops [[OUTER_LOOPS]]#0, [[OUTER_LOOPS]]#1 | ||||||
|  |   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||||
|  | 
 | ||||||
|  |   // CHECK: krnl.iterate([[OPT_OUTER_LOOPS]]#0, [[OPT_OUTER_LOOPS]]#1) with ([[OUTER_LOOPS]]#0 -> %arg3 = 0 to 1, [[OUTER_LOOPS]]#1 -> %arg4 = 0 to 5) { | ||||||
|  |   // CHECK: [[SPATIAL_LOOPS:%.+]]:2 = krnl.define_loops 2 | ||||||
|  |   // CHECK: [[OPT_SPATIAL_LOOPS:%.+]]:2 = krnl.optimize_loops  { | ||||||
|  |   // CHECK: krnl.return_loops [[SPATIAL_LOOPS]]#0, [[SPATIAL_LOOPS]]#1 | ||||||
|  |   // CHECK: } : () -> (!krnl.loop, !krnl.loop) | ||||||
|  | 
 | ||||||
|  |   // CHECK: krnl.iterate([[OPT_SPATIAL_LOOPS]]#0, [[OPT_SPATIAL_LOOPS]]#1) with ([[SPATIAL_LOOPS]]#0 -> %arg5 = 0 to 27, [[SPATIAL_LOOPS]]#1 -> %arg6 = 0 to 58) { | ||||||
|  |   // CHECK: store [[CONST1]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: [[INNER_LOOPS:%.+]]:3 = krnl.define_loops 3 | ||||||
|  |   // CHECK: [[OPT_INNER_LOOPS:%.+]]:3 = krnl.optimize_loops  { | ||||||
|  |   // CHECK: krnl.return_loops [[INNER_LOOPS]]#0, [[INNER_LOOPS]]#1, [[INNER_LOOPS]]#2 | ||||||
|  |   // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop) | ||||||
|  | 
 | ||||||
|  |   // CHECK: krnl.iterate([[OPT_INNER_LOOPS]]#0, [[OPT_INNER_LOOPS]]#1, [[OPT_INNER_LOOPS]]#2) with ([[INNER_LOOPS]]#0 -> %arg7 = 0 to 2, [[INNER_LOOPS]]#1 -> %arg8 = 0 to 6, [[INNER_LOOPS]]#2 -> %arg9 = 0 to 7) { | ||||||
|  |   // CHECK: [[R1PLUSK1:%.+]] = addi %arg5, %arg8 : index | ||||||
|  |   // CHECK: [[R2PLUSK2:%.+]] = addi %arg6, %arg9 : index | ||||||
|  |   // CHECK: [[DATA:%.+]] = load %arg0[%arg3, %arg7, [[R1PLUSK1]], [[R2PLUSK2]]] : memref<1x2x32x64xf32> | ||||||
|  |   // CHECK: [[KERNEL:%.+]] = load %arg1[%arg4, %arg7, %arg8, %arg9] : memref<5x2x6x7xf32> | ||||||
|  |   // CHECK: [[ACC_RES:%.+]] = load %0[%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: [[MUL:%.+]] = mulf [[DATA]], [[KERNEL]] : f32 | ||||||
|  |   // CHECK: [[ADD:%.+]] = addf [[ACC_RES]], [[MUL]] : f32 | ||||||
|  |   // CHECK: store [[ADD]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: } | ||||||
|  |   // CHECK: [[BIAS1:%.+]] = load [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: [[BIAS2:%.+]] = load %arg2[%arg4] : memref<5xf32> | ||||||
|  |   // CHECK: [[BIAS3:%.+]] = mulf [[BIAS1]], [[BIAS2]] : f32 | ||||||
|  |   // CHECK: store [[BIAS3]], [[RES]][%arg3, %arg4, %arg5, %arg6] : memref<1x5x27x58xf32> | ||||||
|  |   // CHECK: } | ||||||
|  |   // CHECK: } | ||||||
|  |   // CHECK: return [[RES]] : memref<1x5x27x58xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x3x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 3 : i64} : (tensor<1x9x32x64xf32>, tensor<5x3x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_no_pad_w_group |   // CHECK-LABEL: test_conv_no_bias_no_pad_w_group | ||||||
|  | @ -1239,7 +1287,8 @@ func @test_conv_no_bias_no_pad_w_group(%arg0 : tensor<1x9x32x64xf32>, %arg1 : te | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_no_pad_w_strides(%arg0 : tensor<1x9x32x64xf32>, %arg1 : tensor<5x9x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 2]} : (tensor<1x9x32x64xf32>, tensor<5x9x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides |   // CHECK-LABEL: test_conv_no_bias_no_pad_w_strides | ||||||
|  |  | ||||||
|  | @ -140,39 +140,42 @@ func @test_matmul_10(%arg0 : tensor<?x42x32xf32>, %arg1 : tensor<32xf32>) -> ten | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| /// Test shape inference for ConvNoBias operation and all its attributes. | /// Test shape inference for Conv (first with no bias) operation and all its attributes. | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| 
 | 
 | ||||||
| /// Default and required attributes for 1-D convolution. | /// Default and required attributes for 1-D convolution. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_0(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_0 |   // CHECK-LABEL: test_conv_no_bias_0 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>) -> tensor<1x5x27xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, none) -> tensor<1x5x27xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Default and required attributes. | /// Default and required attributes. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_1(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_1 |   // CHECK-LABEL: test_conv_no_bias_1 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x27x58xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x27x58xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x58xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// kernel_shape attribute. | /// kernel_shape attribute. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, kernel_shape = [8, 9]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_2 |   // CHECK-LABEL: test_conv_no_bias_2 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x25x56xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [8, 9], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x25x56xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x25x56xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -180,53 +183,58 @@ func @test_conv_no_bias_2(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x | ||||||
| /// Use pads to make output size equal to input size by adding K - 1 to the result. | /// Use pads to make output size equal to input size by adding K - 1 to the result. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_3(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, pads = [2, 4, 3, 5]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_3 |   // CHECK-LABEL: test_conv_no_bias_3 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// auto_pad set to SAME_UPPER and SAME_LOWER. | /// auto_pad set to SAME_UPPER and SAME_LOWER. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_4(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_4 |   // CHECK-LABEL: test_conv_no_bias_4 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [2, 4, 3, 5], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_5(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_LOWER", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_5 |   // CHECK-LABEL: test_conv_no_bias_5 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x32x64xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [3, 5, 2, 4], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x32x64xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// auto_pad set to VALID. | /// auto_pad set to VALID. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_6(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x10xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "VALID", group = 1 : i64} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_6 |   // CHECK-LABEL: test_conv_no_bias_6 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>) -> tensor<1x5x27x55xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 10], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x10xf32>, none) -> tensor<1x5x27x55xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x27x55xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// With strides attribute. | /// With strides attribute. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_7 |   // CHECK-LABEL: test_conv_no_bias_7 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x14x20xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x14x20xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x14x20xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -234,45 +242,61 @@ func @test_conv_no_bias_7(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7x | ||||||
| /// The auto_pad will pas as if stride is equal to 1. | /// The auto_pad will pas as if stride is equal to 1. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_8(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_8 |   // CHECK-LABEL: test_conv_no_bias_8 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x16x22xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [1, 1], group = 1 : i64, kernel_shape = [6, 7], pads = [2, 3, 2, 3], strides = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x16x22xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x16x22xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// dilations attribute. | /// dilations attribute. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_9(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_9 |   // CHECK-LABEL: test_conv_no_bias_9 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x22x46xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x22x46xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x22x46xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// dilations attribute with stride. | /// dilations attribute with stride. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_10(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", group = 1 : i64, dilations = [2, 3], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| 
 | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_10 |   // CHECK-LABEL: test_conv_no_bias_10 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x11x23xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [0, 0, 0, 0], strides = [2, 2]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x11x23xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x11x23xf32> | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// dilations attribute with auto_pad set to SAME_UPPER. | /// dilations attribute with auto_pad set to SAME_UPPER. | ||||||
| 
 | 
 | ||||||
| func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | func @test_conv_no_bias_11(%arg0 : tensor<1x2x32x64xf32>, %arg1 : tensor<5x2x6x7xf32>) -> tensor<*xf32> { | ||||||
|   %0 = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<*xf32> |   %cst = constant unit | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "SAME_UPPER", group = 1 : i64, dilations = [2, 3]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<*xf32> | ||||||
|   "std.return"(%0) : (tensor<*xf32>) -> () |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
| } | 
 | ||||||
|   // CHECK-LABEL: test_conv_no_bias_11 |   // CHECK-LABEL: test_conv_no_bias_11 | ||||||
|   // CHECK: [[RES_ATTR:%.+]] = "onnx.ConvNoBias"(%arg0, %arg1) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>) -> tensor<1x5x32x64xf32> |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %cst) {auto_pad = "NOTSET", dilations = [2, 3], group = 1 : i64, kernel_shape = [6, 7], pads = [5, 9, 5, 9], strides = [1, 1]} : (tensor<1x2x32x64xf32>, tensor<5x2x6x7xf32>, none) -> tensor<1x5x32x64xf32> | ||||||
|   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> |   // CHECK: return [[RES_ATTR]] : tensor<1x5x32x64xf32> | ||||||
|  | } | ||||||
|  |   | ||||||
|  | // Test convolution with bias input. | ||||||
|  | 
 | ||||||
|  | func @test_conv_12(%arg0 : tensor<1x2x32xf32>, %arg1 : tensor<5x2x6xf32>, %arg2 : tensor<5xf32>) -> tensor<*xf32> { | ||||||
|  |   %0 = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", group = 1 : i64} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<*xf32> | ||||||
|  |   "std.return"(%0) : (tensor<*xf32>) -> () | ||||||
|  | 
 | ||||||
|  |   // CHECK-LABEL: test_conv_12 | ||||||
|  |   // CHECK: [[RES_ATTR:%.+]] = "onnx.Conv"(%arg0, %arg1, %arg2) {auto_pad = "NOTSET", dilations = [1], group = 1 : i64, kernel_shape = [6], pads = [0, 0], strides = [1]} : (tensor<1x2x32xf32>, tensor<5x2x6xf32>, tensor<5xf32>) -> tensor<1x5x27xf32> | ||||||
|  |   // CHECK: return [[RES_ATTR]] : tensor<1x5x27xf32> | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
| //===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||||
| /// Test shape inference for PadConstantValuePad. | /// Test shape inference for PadConstantValuePad. | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue