Introduce helper class to generate KRNL code and apply it to Convolution (#93)

* helper to gen krnl code, applied to conv
* suggested changes, name, removed set insertion point
* format
* suggested changes
* added comments and made a small name change
commit fcb5f35993 (parent 9c398c0121)
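The new BuildKrnlLoop helper replaces the hand-written defineLoops / KrnlIterateOperandPack / insertion-point sequence with a single define/optimize/iterate combo. Here is a minimal sketch of the resulting usage pattern, distilled from the conv lowering changes below; the function name emitLoopSketch and the operand argument are illustrative, while the BuildKrnlLoop calls are the API added by this commit:

// Sketch only: one define/optimize/iterate combo over dimension 0 of a
// memref-typed value, assuming we are inside a ConversionPattern rewrite.
void emitLoopSketch(mlir::ConversionPatternRewriter &rewriter,
                    mlir::Location loc, mlir::Value operand) {
  mlir::BuildKrnlLoop loops(rewriter, loc, /*loopNum=*/1);
  // Emit the paired define and (empty) optimize operations.
  loops.createDefineAndOptimizeOp();
  // Lower bound 0; upper bound is dimension 0 of the memref (a DimOp is
  // emitted automatically when that dimension is dynamic).
  int i0 = loops.pushBounds(0, operand, 0);
  // All bounds pushed; emit the iterate op and enter its body.
  loops.createIterateOp();
  rewriter.setInsertionPointToStart(loops.getIterateBlock());
  // Retrieve the induction variable via the index returned by pushBounds.
  mlir::Value iv = loops.getInductionVar(i0);
  (void)iv; // use iv in the loads/stores emitted here
}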
.clang-format:

@@ -1,2 +1,3 @@
 BasedOnStyle: LLVM
 AlwaysBreakTemplateDeclarations: Yes
+AlignAfterOpenBracket: DontAlign
Convolution lowering (ONNXConvNoBiasOpLowering):

@@ -12,8 +12,7 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   ONNXConvNoBiasOpLowering(MLIRContext *ctx)
       : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}

-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
     // Insert an allocation and deallocation for the result of this operation.
@@ -25,12 +24,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
     if (hasAllConstantDimensions(memRefType))
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     else
-      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
-                                    {operands[0]});
+      alloc = insertAllocAndDealloc(
+          memRefType, loc, rewriter, insertDealloc, {operands[0]});

     auto resultShape = memRefType.getShape();
-    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
+    auto &inputOperand = operands[0];
+    auto inputShape = inputOperand.getType().cast<MemRefType>().getShape();
+    auto &kernelOperand = operands[1];
+    auto kernelShape = kernelOperand.getType().cast<MemRefType>().getShape();

     // R = ConvNoBias(D, K)
     //
@@ -91,123 +92,82 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
         loc, FloatAttr::get(memRefType.getElementType(), 0));
     Value subchannels;
     if (kernelShape[1] < 0) {
-      subchannels =
-          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
+      subchannels = rewriter.create<DimOp>(loc, kernelOperand, 1).getResult();
     } else {
-      subchannels = rewriter.create<ConstantIndexOp>(
-          loc, kernelShape[1]);
+      subchannels = rewriter.create<ConstantIndexOp>(loc, kernelShape[1]);
     }

     // 1. Define outer loops and emit empty optimization block:
     int64_t nOuterLoops = (group > 1) ? 3 : 2;
-    std::vector<Value> outerLoops;
-    std::vector<Value> optimizedOuterLoops;
-    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
-        optimizedOuterLoops, nOuterLoops);
-
-    // Prepare iteration arguments over outer loop nest.
-    KrnlIterateOperandPack pack(
-        rewriter, outerLoops, optimizedOuterLoops);
+    BuildKrnlLoop outerLoops(rewriter, loc, nOuterLoops);
+    outerLoops.createDefineAndOptimizeOp();
     //   for n = 0 .. N:
-    pack.pushConstantBound(0);
-    if (inputShape[0] < 0)
-      pack.pushOperandBound(
-          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
-    else
-      pack.pushConstantBound(inputShape[0]);
+    int nIndex = outerLoops.pushBounds(0, inputOperand, 0);
     //   for g = 0 .. N:
-    if (group > 1) {
-      pack.pushConstantBound(0);
-      pack.pushConstantBound(group);
-    }
+    int gIndex = -1;
+    if (group > 1)
+      gIndex = outerLoops.pushBounds(0, group);
     //   for m = 0 .. kernelsPerGroup:
-    pack.pushConstantBound(0);
-    pack.pushConstantBound(kernelsPerGroup);
-    // Outer loop iteration.
-    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
-    Block &outerIterationBlock = iterateOp.bodyRegion().front();
-    // Emit optimizations for outer loops:
-    rewriter.setInsertionPointToEnd(optimizationBlock);
-    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
-    rewriter.setInsertionPointToStart(&outerIterationBlock);
+    int mIndex = outerLoops.pushBounds(0, kernelsPerGroup);
+    // Outer loop iteration
+    outerLoops.createIterateOp();
+    rewriter.setInsertionPointToStart(outerLoops.getIterateBlock());
     {
       // 2. Emit the body of the outer loop nest.

       // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
       // If group is not set then the value of the kernel ID is
       // identical to that of the loop over kernels.
-      Value kernel = outerIterationBlock.getArguments()[1];
+      Value kernel = outerLoops.getInductionVar(mIndex);
       if (group > 1) {
         // Middle loop is over groups and third loop is over the
         // kernel identifiers in the current group.
-        auto kernelsOffset = rewriter.create<MulIOp>(loc,
-            outerIterationBlock.getArguments()[1],
-            kernelsPerGroupValue);
-        kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
-            outerIterationBlock.getArguments()[2]);
+        auto kernelsOffset = rewriter.create<MulIOp>(
+            loc, outerLoops.getInductionVar(gIndex), kernelsPerGroupValue);
+        kernel = rewriter.create<AddIOp>(
+            loc, kernelsOffset, outerLoops.getInductionVar(mIndex));
       }

       // 2.2 Define spatial loops
       int64_t nSpatialLoops = resultShape.size() - 2;
-      std::vector<Value> spatialLoops;
-      std::vector<Value> optimizedSpatialLoops;
-      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
-        optimizedSpatialLoops, nSpatialLoops);
-
-      // 2.3 Prepare iteration arguments for spatial loop nest.
-      KrnlIterateOperandPack spatialPack(
-        rewriter, spatialLoops, optimizedSpatialLoops);
+      BuildKrnlLoop spatialLoops(rewriter, loc, nSpatialLoops);
+      spatialLoops.createDefineAndOptimizeOp();
       for (int i = 2; i < resultShape.size(); ++i)
-        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
+        spatialLoops.pushBounds(0, alloc, i);

       // 2.4 Emit loop nest over output spatial dimensions.
       //   for rX = 0 .. RX
-      auto spatialIterateOp =
-          rewriter.create<KrnlIterateOp>(loc, spatialPack);
-      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
-      // 2.5 Emit optimizations for outer loops:
-      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
-      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
-      rewriter.setInsertionPointToStart(&spatialIterationBlock);
+      spatialLoops.createIterateOp();
+      rewriter.setInsertionPointToStart(spatialLoops.getIterateBlock());
       {
         // 3. Emit the body of the spatial loop nest.
         // 3.1 Emit: R[n][kernel][r1][r2] = 0;
         SmallVector<Value, 4> resultIndices;
         // n
-        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+        resultIndices.emplace_back(outerLoops.getInductionVar(nIndex));
         // kernel
         resultIndices.emplace_back(kernel);
         // rX
-        for (auto arg : spatialIterationBlock.getArguments())
+        for (auto arg : spatialLoops.getIterateBlock()->getArguments())
           resultIndices.emplace_back(arg);
         // Store initializer value into output location.
         rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);

         // 3.2 Define inner loops.
         int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
-        std::vector<Value> innerLoops;
-        std::vector<Value> optimizedInnerLoops;
-        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
-            optimizedInnerLoops, nInnerLoops);
-
-        // 3.3 Prepare iteration arguments for inner loop nest.
-        KrnlIterateOperandPack innerPack(
-            rewriter, innerLoops, optimizedInnerLoops);
+        BuildKrnlLoop innerLoops(rewriter, loc, nInnerLoops);
+        innerLoops.createDefineAndOptimizeOp();
         //   for c = 0 .. C/group
-        innerPack.pushConstantBound(0);
-        innerPack.pushConstantBound(kernelShape[1]);
+        int cIndex = innerLoops.pushBounds(0, kernelShape[1]);
         //   for Kx = 0 .. KX
         for (int i = 2; i < kernelShape.size(); ++i)
-          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
+          innerLoops.pushBounds(0, kernelOperand, i);

         // 3.4 Emit inner loop nest.
-        auto innerIterateOp =
-            rewriter.create<KrnlIterateOp>(loc, innerPack);
-        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
-        // 3.5 Emit optimizations for outer loops:
-        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
-        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
-        rewriter.setInsertionPointToStart(&innerIterationBlock);
+        innerLoops.createIterateOp();
+        rewriter.setInsertionPointToStart(innerLoops.getIterateBlock());
         {
           // 4. Emit inner loop body
           // R[n][kernel][r1][r2] =
@@ -217,13 +177,13 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
           // 4.1 Prepare indices for accessing the data tensor.
           SmallVector<Value, 4> dataIndices;
           // n
-          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
+          dataIndices.emplace_back(outerLoops.getInductionVar(nIndex));
           // g * (C / group) + c
-          Value channelDepth = innerIterationBlock.getArguments()[0];
+          Value channelDepth = innerLoops.getInductionVar(cIndex);
           if (group > 1)
             channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
-                rewriter.create<MulIOp>(loc, subchannels,
-                    outerIterationBlock.getArguments()[1]));
+                rewriter.create<MulIOp>(
+                    loc, subchannels, outerLoops.getInductionVar(gIndex)));
           dataIndices.emplace_back(channelDepth);
           // sX * rX + kX
           auto stridesAttribute = convOp.stridesAttr();
@@ -233,15 +193,14 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
             for (auto stride : stridesAttribute.getValue())
               strides.emplace_back(stride.cast<IntegerAttr>().getInt());
           for (int i = 0; i < kernelShape.size() - 2; ++i) {
-            Value spatialIndex = spatialIterationBlock.getArguments()[i];
+            Value spatialIndex = spatialLoops.getInductionVar(i);
             // If strides are present then emit the correct access index.
             if (stridesAttribute && strides[i] > 1)
               spatialIndex = rewriter.create<MulIOp>(loc,
                   rewriter.create<ConstantIndexOp>(loc, strides[i]),
-                  spatialIterationBlock.getArguments()[i]);
-            dataIndices.emplace_back(
-                rewriter.create<AddIOp>(loc, spatialIndex,
-                    innerIterationBlock.getArguments()[i+1]));
+                  spatialLoops.getInductionVar(i));
+            dataIndices.emplace_back(rewriter.create<AddIOp>(
+                loc, spatialIndex, innerLoops.getInductionVar(i + 1)));
           }

           // 4.2 Prepare indices for accessing the kernel tensor.
@@ -249,17 +208,16 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
           // kernel
           kernelIndices.emplace_back(kernel);
           // c
-          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
+          kernelIndices.emplace_back(innerLoops.getInductionVar(cIndex));
           // kX
           for (int i = 0; i < kernelShape.size() - 2; ++i)
-            kernelIndices.emplace_back(
-                innerIterationBlock.getArguments()[i+1]);
+            kernelIndices.emplace_back(innerLoops.getInductionVar(i + 1));

           // 4.3 Compute convolution.
           auto loadData =
-              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
+              rewriter.create<LoadOp>(loc, inputOperand, dataIndices);
           auto loadKernel =
-              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
+              rewriter.create<LoadOp>(loc, kernelOperand, kernelIndices);
           auto loadPartialSum =
               rewriter.create<LoadOp>(loc, alloc, resultIndices);
           Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
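The conv rewrite above nests three such combos (outer, spatial, inner). The nesting falls out of the rewriter's insertion point: each combo is emitted wherever the rewriter currently points. A minimal sketch of the nesting idiom, where batchSize and numChannels are illustrative constants and the rest is the API from this commit:

// Sketch only: the inner combo is emitted inside the outer iterate block
// because the insertion point was moved there first.
int64_t batchSize = 4, numChannels = 8; // illustrative constants
mlir::BuildKrnlLoop outer(rewriter, loc, /*loopNum=*/1);
outer.createDefineAndOptimizeOp();
int n = outer.pushBounds(0, batchSize);
outer.createIterateOp();
rewriter.setInsertionPointToStart(outer.getIterateBlock());
{
  mlir::BuildKrnlLoop inner(rewriter, loc, /*loopNum=*/1);
  inner.createDefineAndOptimizeOp(); // emitted inside `outer`
  int c = inner.pushBounds(0, numChannels);
  inner.createIterateOp();
  rewriter.setInsertionPointToStart(inner.getIterateBlock());
  // Body: index with outer.getInductionVar(n) and inner.getInductionVar(c).
}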
Krnl helper implementation:

@@ -1,4 +1,5 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
 #include "mlir/IR/AffineExpr.h"

 #include "src/dialect/krnl/krnl_ops.hpp"
@@ -9,9 +10,8 @@ namespace onnf {

 using namespace mlir;

-ParseResult
-KrnlDialectOperandParser::ParseOptionalOperand(const Type &operandType,
-                                               Value &operand) {
+ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
+    const Type &operandType, Value &operand) {
   // If operand queue is empty, parse more operands and cache them.
   if (_operandRefQueue.empty()) {
     // Parse operand types:
@@ -19,7 +19,7 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
     _parser.parseOperandList(operand_refs);

     // Record operands:
-    for (auto& operand_ref : operand_refs)
+    for (auto &operand_ref : operand_refs)
       _operandRefQueue.emplace(operand_ref);
   }

@@ -48,8 +48,8 @@ ParseResult KrnlDialectOperandParser::ParseOptionalOperand(
   return success();
 }

-ParseResult KrnlDialectOperandParser::ParseOperand(const Type &operandType,
-                                                   Value &operand) {
+ParseResult KrnlDialectOperandParser::ParseOperand(
+    const Type &operandType, Value &operand) {
   if (ParseOptionalOperand(operandType, operand))
     return _parser.emitError(
         _parser.getCurrentLocation(), "Expecting an operand.");
@@ -65,8 +65,8 @@ ParseResult KrnlDialectOperandParser::ParseOperand(
   return success();
 }

-void printDimAndSymbolList(Operation::operand_iterator& begin, unsigned numDims,
-    unsigned numSymbols, OpAsmPrinter& p) {
+void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
+    unsigned numSymbols, OpAsmPrinter &p) {
   p << '(';
   p.printOperands(begin, begin + numDims);
   p << ')';
@@ -81,8 +81,8 @@ void printDimAndSymbolList(Operation::operand_iterator &begin, unsigned numDims,
 }

 void printBound(AffineMapAttr boundMap,
-    Operation::operand_iterator& boundOperandsBeg, const char* prefix,
-    OpAsmPrinter& p) {
+    Operation::operand_iterator &boundOperandsBeg, const char *prefix,
+    OpAsmPrinter &p) {
   AffineMap map = boundMap.getValue();

   // Check if this bound should be printed using custom assembly form.
@@ -123,6 +123,7 @@ void printBound(AffineMapAttr boundMap,
 } // namespace onnf

 namespace mlir {
+
 void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
   if (boundMaps.size() % 2 == 0)
     _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
@@ -137,4 +138,125 @@ void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
   boundMaps.emplace_back(AffineMapAttr::get(map));
   _operands.emplace_back(operand);
 }

+BuildKrnlLoop::BuildKrnlLoop(
+    ConversionPatternRewriter &rewriter, Location loc, int loopNum)
+    : rewriter(rewriter), loc(loc), originalLoopNum(loopNum), pack(NULL),
+      pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
+      createdIterateOp(false) {
+  if (originalLoopNum <= 0)
+    emitError(loc, "expected positive number of original loops");
+}
+
+BuildKrnlLoop::BuildKrnlLoop(
+    ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand)
+    : BuildKrnlLoop(rewriter, loc,
+          memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
+
+BuildKrnlLoop::~BuildKrnlLoop() {
+  if (!createdDefineOp)
+    emitError(loc, "expected to create define op");
+  if (!createdIterateOp)
+    emitError(loc, "expected to create iterate op");
+  if (pack)
+    delete pack;
+}
+
+void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
+  // Insert the define-loops op.
+  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
+  originalLoops.reserve(originalLoopNum);
+  for (auto result : loopsOp.getResults())
+    originalLoops.push_back(result);
+  createdDefineOp = true;
+  // Insert the optimize-loops op.
+  auto optimizedLoopsOp =
+      rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
+  optLoops.reserve(originalLoopNum);
+  // Emit empty optimizations.
+  if (withEmptyOptimization) {
+    for (auto result : optimizedLoopsOp.getResults())
+      optLoops.push_back(result);
+    optBlock = &optimizedLoopsOp.region().front();
+    auto ip = rewriter.saveInsertionPoint();
+    rewriter.setInsertionPointToEnd(optBlock);
+    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
+    rewriter.restoreInsertionPoint(ip);
+  }
+  // Prepare the data structure used to push bounds.
+  pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
+  createdOptimizeOp = true;
+}
+
+// Push bounds (lower and upper) and return the index for this loop.
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
+  pack->pushConstantBound(lowerBound);
+  pack->pushConstantBound(upperBound);
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
+  pack->pushConstantBound(lowerBound);
+  pack->pushOperandBound(upperBound);
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
+    int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
+  pack->pushConstantBound(lowerBound);
+  // Process the upper bound as a dimension of the memref; the dimension may
+  // be non-constant, in which case a DimOp supplies the bound.
+  auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
+  if (shape[upperBoundMemRefIndex] < 0) {
+    if (upperBoundMustBeConstant)
+      emitError(loc, "bound expected to be constant");
+    pack->pushOperandBound(
+        rewriter
+            .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
+            .getResult());
+  } else
+    pack->pushConstantBound(shape[upperBoundMemRefIndex]);
+  return pushCount++;
+}
+
+int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
+  pack->pushOperandBound(lowerBound);
+  pack->pushOperandBound(upperBound);
+  return pushCount++;
+}
+
+// Create the iterate op.
+void BuildKrnlLoop::createIterateOp() {
+  if (!createdDefineOp)
+    emitError(loc, "must create define op before iterate op");
+  // Right now, optimize (possibly empty) is mandatory. This may change.
+  if (!createdOptimizeOp)
+    emitError(loc, "must create optimize op before iterate op");
+  // Bounds must have been pushed for all original loops.
+  if (pushCount != originalLoopNum) {
+    printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
+    emitError(loc, "must push bounds for all original loops");
+  }
+  // Create the iterate op.
+  auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
+  iterBlock = &iterateOp.bodyRegion().front();
+  createdIterateOp = true;
+}
+
+void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
+    Value memRefOperand, bool withEmptyOptimization) {
+  int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
+  if (originalLoopNum != loopNum)
+    emitError(loc, "mismatch in loop numbers from constructor and define");
+  createDefineAndOptimizeOp(withEmptyOptimization);
+  for (int i = 0; i < originalLoopNum; ++i)
+    pushBounds(0, memRefOperand, i);
+  createIterateOp();
+}
+
+// Get the induction variable to be used within the iterate block.
+BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
+  if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
+    emitError(loc, "original loop index is out of bounds");
+  return iterBlock->getArguments()[originalLoopIndex];
+}
+
 } // namespace mlir
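For the common case of iterating over every dimension of a memref, the memref-based constructor plus createDefineOptimizeAndIterateOp (implemented above, declared in the header below) collapse the whole sequence into two calls. A minimal sketch, where resultMemRef is an illustrative memref-typed Value:

// Sketch only: loop over all dimensions of `resultMemRef`, lower bounds 0,
// upper bounds taken from its shape (DimOp for dynamic dimensions).
mlir::BuildKrnlLoop loops(rewriter, loc, resultMemRef); // loopNum = rank
loops.createDefineOptimizeAndIterateOp(resultMemRef);
rewriter.setInsertionPointToStart(loops.getIterateBlock());
// Induction variables are indexed 0 .. rank-1, in dimension order.
mlir::Value iv0 = loops.getInductionVar(0);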
Krnl helper header:

@@ -8,39 +8,38 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Transforms/DialectConversion.h"

 namespace onnf {

 class KrnlDialectOperandParser {
- public:
-  explicit KrnlDialectOperandParser(mlir::OpAsmParser& parser)
+public:
+  explicit KrnlDialectOperandParser(mlir::OpAsmParser &parser)
       : _parser(parser), _builder(parser.getBuilder()){};

   // Parse an optional operand.
-  mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
-                                         mlir::Value &operand);
+  mlir::ParseResult ParseOptionalOperand(
+      const mlir::Type &operandType, mlir::Value &operand);

   // Parse an optional operand and push it to an operand list.
-  mlir::ParseResult
-  ParseOptionalOperand(const mlir::Type &operandType,
+  mlir::ParseResult ParseOptionalOperand(const mlir::Type &operandType,
       llvm::SmallVectorImpl<mlir::Value> &operandList);

   // Parse a required operand.
-  mlir::ParseResult ParseOperand(const mlir::Type &operandType,
-                                 mlir::Value &operand);
+  mlir::ParseResult ParseOperand(
+      const mlir::Type &operandType, mlir::Value &operand);

   // Parse a required operand and push it to an operand list.
-  mlir::ParseResult
-  ParseOperand(const mlir::Type &operandType,
+  mlir::ParseResult ParseOperand(const mlir::Type &operandType,
       llvm::SmallVectorImpl<mlir::Value> &operandList);

   // Do we have more operands to parse?
   bool hasOperandLeft() { return !_operandRefQueue.empty(); }

- private:
-  mlir::OpAsmParser& _parser;
+private:
+  mlir::OpAsmParser &_parser;

-  mlir::Builder& _builder;
+  mlir::Builder &_builder;

   // A queue storing the parsed SSA id references.
   std::queue<mlir::OpAsmParser::OperandType> _operandRefQueue;
@@ -50,16 +49,16 @@ class KrnlDialectOperandParser {
 // https://github.com/tensorflow/mlir/blob/6a150d70c7e06fb37cddd7188fa48cde9a90fe59/lib/Dialect/StandardOps/Ops.cpp#L197
 // Main difference is that it advances the iterator `begin` as it consumes
 // dimension and symbol operands.
-void printDimAndSymbolList(mlir::Operation::operand_iterator& begin,
-    unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter& p);
+void printDimAndSymbolList(mlir::Operation::operand_iterator &begin,
+    unsigned numDims, unsigned numSymbols, mlir::OpAsmPrinter &p);

 // Adapted from:
 // https://github.com/tensorflow/mlir/blob/5cb42c914fed14cebbbe5c170b4e2784d2628304/lib/Dialect/AffineOps/AffineOps.cpp#L1272
 // Main difference is that it advances the iterator `boundOperandsBeg` as it
 // prints bound.
 void printBound(mlir::AffineMapAttr boundMap,
-    mlir::Operation::operand_iterator& boundOperandsBeg, const char* prefix,
-    mlir::OpAsmPrinter& p);
+    mlir::Operation::operand_iterator &boundOperandsBeg, const char *prefix,
+    mlir::OpAsmPrinter &p);
 } // namespace onnf

 namespace mlir {
@@ -88,7 +87,7 @@ struct KrnlIterateOperandPack {

   size_t getNumInputLoops() const { return inputLoops.size(); }

- private:
+private:
   int _boundIdx = 0;

   llvm::SmallVector<mlir::Value, 8> _operands;
@@ -97,7 +96,99 @@ struct KrnlIterateOperandPack {

   llvm::ArrayRef<mlir::Value> inputLoops, optimizedLoops;

-  mlir::Builder& builder;
+  mlir::Builder &builder;
+};
+
+// Helper class to write krnl loops. It lets us build a single
+// define/optimize/iterate operation combo. We can then insert optimizations in
+// the body of the optimization operation, and operations in the body of the
+// iterate operation.
+//
+// The sequence is as follows:
+//
+//   1) Create an object, giving it the rewriter, the location, and the number
+//   of loops in the original (non-optimized) loop nest.
+//
+//   2) Create define & optimize ops (currently paired). Optimizations can then
+//   be added to the inner block of the optimize operation. Make sure to set the
+//   insertion point to that block for optimizations to go in the right place.
+//
+//   3) Push the bounds for each of the original loops. Bounds are pushed in
+//   pairs (lower & upper bounds). There are a few methods to do it depending on
+//   the type of the bounds. When pushing bounds, the method returns a number
+//   that represents the index associated with that iteration (induction
+//   variable and bounds). That index can be used later to extract the
+//   induction variable for reference in computation and/or index calculations
+//   of mem refs.
+//
+//   4) Once all the bounds are pushed, create the iterate operation. Once this
+//   is done, we can add operations within the iterate blocks by setting the
+//   insertion point to it. Values of the induction variables can be retrieved
+//   using the proper index (determined when pushing the bounds).
+
+class BuildKrnlLoop {
+public:
+  // Create a build krnl loop helper for the given location and loop number.
+  BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
+  // Do the same, but where the loop number corresponds to the dimensionality of
+  // the mem ref operand.
+  BuildKrnlLoop(
+      ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
+  ~BuildKrnlLoop();
+
+  // Create define and optimize loops with loopNum original loops. If
+  // withEmptyOptimization, the optimization is simply the identity function (no
+  // optimizations).
+  void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
+
+  // Push bounds (lower and upper) for each of the loops, in order. It returns
+  // the index associated with the loop iteration. This index is in the range
+  // from zero to the original loop number minus one, and is monotonically
+  // increasing from call to call. This index is later used in the
+  // getInductionVar call.
+  int pushBounds(int64_t lowerBound, int64_t upperBound);
+  int pushBounds(int64_t lowerBound, Value upperBound);
+  int pushBounds(Value lowerBound, Value upperBound);
+  // Same, but the lower bound is an integer and the upper bound is given by
+  // the size of the mem ref operand along the upperBoundMemRefIndex dimension.
+  int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
+      int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
+
+  // Create an iterate op.
+  void createIterateOp();
+  // Create define, optimize and iterate ops in one call, with the same loop
+  // number as the rank of memRefOperand. The lower bound of each loop is zero;
+  // the upper bound of each loop is the corresponding dimension of the mem ref.
+  void createDefineOptimizeAndIterateOp(
+      Value memRefOperand, bool withEmptyOptimization = true);
+
+  // Get the (original loop) induction variable associated with the given
+  // index. Use the index returned when pushing the bounds.
+  BlockArgument &getInductionVar(int originalLoopIndex);
+
+  // Get blocks. These allow us to set the insertion point to the inner block
+  // of the optimize and the iterate operations.
+  Block *getOptimizationBlock() { return optBlock; }
+  Block *getIterateBlock() { return iterBlock; }
+
+  // Get original or optimized loops.
+  std::vector<Value> &getOriginalLoops() { return originalLoops; }
+  std::vector<Value> &getOptimizedLoops() { return optLoops; }
+
+private:
+  // Inputs.
+  ConversionPatternRewriter &rewriter;
+  Location loc;
+  int originalLoopNum;
+  // Track loops and bounds.
+  std::vector<Value> originalLoops;
+  std::vector<Value> optLoops;
+  KrnlIterateOperandPack *pack;
+  int pushCount;
+  bool createdDefineOp;
+  bool createdOptimizeOp;
+  bool createdIterateOp;
+  // Insertion points (opt block, iterate).
+  Block *optBlock;
+  Block *iterBlock;
 };

 } // namespace mlir