[NFC] Categorize ONNX ops lowering (#80)
* Create two categories: elementwise and tensor * typos * Create directories for categories * Edit comments * Extract a function that creates a KrnlIterateOp * Add comments * Extract some common parts * Revise softmax * Add reduction.inc * Move lower-frontend to lib/conversion * Move directory to directory * Change file/directory names * Comment format * Add matmul.inc
This commit is contained in:
		
							parent
							
								
									3c505ae31d
								
							
						
					
					
						commit
						b9f2f25b56
					
				| 
						 | 
					@ -57,7 +57,7 @@ target_include_directories(onnf_shape_inference
 | 
				
			||||||
target_link_libraries(onnf_shape_inference ${MLIRLibs})
 | 
					target_link_libraries(onnf_shape_inference ${MLIRLibs})
 | 
				
			||||||
add_dependencies(onnf_shape_inference gen_krnl_ops)
 | 
					add_dependencies(onnf_shape_inference gen_krnl_ops)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
add_library(onnf_lower_frontend pass/lower_frontend_to_krnl.cpp)
 | 
					add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 | 
				
			||||||
target_include_directories(onnf_lower_frontend
 | 
					target_include_directories(onnf_lower_frontend
 | 
				
			||||||
        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
					        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
				
			||||||
        ${ONNF_SRC_ROOT})
 | 
					        ${ONNF_SRC_ROOT})
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,529 @@
 | 
				
			||||||
 | 
					//====- convert_onnx_to_krnl.cpp - ONNX dialects to Krnl lowering ---------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file implements the lowering of frontend operations to a combination of
 | 
				
			||||||
 | 
					// Krnl IR and standard operations.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir/Dialect/AffineOps/AffineOps.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/StandardOps/Ops.h"
 | 
				
			||||||
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
 | 
					#include "mlir/Transforms/DialectConversion.h"
 | 
				
			||||||
 | 
					#include "llvm/ADT/ArrayRef.h"
 | 
				
			||||||
 | 
					#include "llvm/ADT/Sequence.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/dialect/krnl/krnl_helper.hpp"
 | 
				
			||||||
 | 
					#include "src/dialect/krnl/krnl_ops.hpp"
 | 
				
			||||||
 | 
					#include "src/dialect/onnx/onnx_ops.hpp"
 | 
				
			||||||
 | 
					#include "src/pass/passes.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// FrontendToAffine RewritePatterns
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Check is all dimensions are known at compile time.
 | 
				
			||||||
 | 
					static bool hasAllConstantDimensions(MemRefType type) {
 | 
				
			||||||
 | 
					  auto memRefShape = type.getShape();
 | 
				
			||||||
 | 
					  for (int i = 0; i < memRefShape.size(); ++i)
 | 
				
			||||||
 | 
					    if (memRefShape[i] < 0)
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					  return true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Convert the given TensorType into the corresponding MemRefType.
 | 
				
			||||||
 | 
					static MemRefType convertTensorToMemRef(TensorType type) {
 | 
				
			||||||
 | 
					  assert(type.hasRank() && "expected only ranked shapes");
 | 
				
			||||||
 | 
					  return MemRefType::get(type.getShape(), type.getElementType());
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Insert an allocation and deallocation for the given MemRefType.
 | 
				
			||||||
 | 
					static Value insertAllocAndDealloc(MemRefType type, Location loc,
 | 
				
			||||||
 | 
					                                   PatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                                   bool insertDealloc,
 | 
				
			||||||
 | 
					                                   ArrayRef<Value> operands = {}) {
 | 
				
			||||||
 | 
					  // Put together alloc operands for any dynamic dimensions of the memref.
 | 
				
			||||||
 | 
					  AllocOp alloc;
 | 
				
			||||||
 | 
					  if (!operands.empty()) {
 | 
				
			||||||
 | 
					    auto memRefShape = type.getShape();
 | 
				
			||||||
 | 
					    auto rank = memRefShape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::map<int, Value> fromOperands;
 | 
				
			||||||
 | 
					    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					      int memRefDimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 | 
				
			||||||
 | 
					        Value maxDim = nullptr;
 | 
				
			||||||
 | 
					        for (int i = 0; i < operands.size(); i++) {
 | 
				
			||||||
 | 
					          auto operandShape =
 | 
				
			||||||
 | 
					              operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          if (operandDimIdx < 0)
 | 
				
			||||||
 | 
					            continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // In case of operations with broadcasting, the dimension of the
 | 
				
			||||||
 | 
					          // alloc result is the maximum size along each dimension of the
 | 
				
			||||||
 | 
					          // operands.
 | 
				
			||||||
 | 
					          auto operandDim =
 | 
				
			||||||
 | 
					              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
 | 
				
			||||||
 | 
					          if (maxDim) {
 | 
				
			||||||
 | 
					            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
 | 
				
			||||||
 | 
					                                                        operandDim, maxDim);
 | 
				
			||||||
 | 
					            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
 | 
				
			||||||
 | 
					                                               maxDim);
 | 
				
			||||||
 | 
					          } else {
 | 
				
			||||||
 | 
					            maxDim = operandDim;
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> allocOperands;
 | 
				
			||||||
 | 
					    for (int i = 0; i < rank; ++i)
 | 
				
			||||||
 | 
					      if (memRefShape[i] < 0)
 | 
				
			||||||
 | 
					        allocOperands.push_back(fromOperands[i]);
 | 
				
			||||||
 | 
					    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    alloc = rewriter.create<AllocOp>(loc, type);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Make sure to allocate at the beginning of the block if
 | 
				
			||||||
 | 
					  // all dimensions are known.
 | 
				
			||||||
 | 
					  auto *parentBlock = alloc.getOperation()->getBlock();
 | 
				
			||||||
 | 
					  if (hasAllConstantDimensions(type))
 | 
				
			||||||
 | 
					    alloc.getOperation()->moveBefore(&parentBlock->front());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (insertDealloc) {
 | 
				
			||||||
 | 
					    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
				
			||||||
 | 
					    dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return alloc;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Determine if current function returns the result value of the
 | 
				
			||||||
 | 
					// current op being lowered. If it does then dealloc should not be
 | 
				
			||||||
 | 
					// inserted.
 | 
				
			||||||
 | 
					static bool checkInsertDealloc(Operation *currentOp) {
 | 
				
			||||||
 | 
					  auto parentBlock = currentOp->getBlock();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool insertDealloc = true;
 | 
				
			||||||
 | 
					  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
 | 
				
			||||||
 | 
					    assert(currentOp->getNumResults() < 2 &&
 | 
				
			||||||
 | 
					           "No more than one result supported (for now).");
 | 
				
			||||||
 | 
					    // If there is at least one result to investigate.
 | 
				
			||||||
 | 
					    if (currentOp->getNumResults() > 0) {
 | 
				
			||||||
 | 
					      auto result = currentOp->getResult(0);
 | 
				
			||||||
 | 
					      for (const auto &operand : op.getOperands())
 | 
				
			||||||
 | 
					        if (operand == result)
 | 
				
			||||||
 | 
					          insertDealloc = false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return insertDealloc;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a mapping from result type's dimensions to input type's dimensions,
 | 
				
			||||||
 | 
					// given that the result type is the result of a reduction op over the input
 | 
				
			||||||
 | 
					// type.
 | 
				
			||||||
 | 
					std::map<int64_t, int64_t>
 | 
				
			||||||
 | 
					getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 | 
				
			||||||
 | 
					  std::map<int64_t, int64_t> OutInDimMap;
 | 
				
			||||||
 | 
					  int64_t rank = inputTy.getRank();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Mark reduction axes.
 | 
				
			||||||
 | 
					  std::vector<bool> isReductionAxis;
 | 
				
			||||||
 | 
					  for (decltype(rank) i = 0; i < rank; ++i) {
 | 
				
			||||||
 | 
					    if (std::find(axes.begin(), axes.end(), i) != axes.end())
 | 
				
			||||||
 | 
					      isReductionAxis.push_back(true);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      isReductionAxis.push_back(false);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
 | 
				
			||||||
 | 
					    // If it is a reduction axis, there is no relationship among dimensions.
 | 
				
			||||||
 | 
					    if (isReductionAxis[inIndex]) {
 | 
				
			||||||
 | 
					      if (keepdims)
 | 
				
			||||||
 | 
					        outIndex++;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
 | 
				
			||||||
 | 
					      outIndex++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return OutInDimMap;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add bounds associated with the op operand to the KRNL iteration pack.
 | 
				
			||||||
 | 
					// Dynamic dimenions are supported.
 | 
				
			||||||
 | 
					static void addDimensionToPack(ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                               Location loc, KrnlIterateOperandPack &pack,
 | 
				
			||||||
 | 
					                               Value operand, int index) {
 | 
				
			||||||
 | 
					  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					  if (shape[index] < 0) {
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    pack.pushOperandBound(
 | 
				
			||||||
 | 
					        rewriter.create<DimOp>(loc, operand, index).getResult());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    pack.pushConstantBound(shape[index]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that defines the KRNL dialect loops and their respective
 | 
				
			||||||
 | 
					// optimized version.
 | 
				
			||||||
 | 
					static KrnlOptimizeLoopsOp
 | 
				
			||||||
 | 
					emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                   std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
 | 
				
			||||||
 | 
					  // Define loops.
 | 
				
			||||||
 | 
					  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
 | 
				
			||||||
 | 
					  loops.reserve(numLoops);
 | 
				
			||||||
 | 
					  for (auto result : loopsOp.getResults())
 | 
				
			||||||
 | 
					    loops.push_back(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Define optimized version of the loops.
 | 
				
			||||||
 | 
					  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
 | 
				
			||||||
 | 
					  optimizedLoops.reserve(numLoops);
 | 
				
			||||||
 | 
					  for (auto result : optimizedLoopsOp.getResults())
 | 
				
			||||||
 | 
					    optimizedLoops.push_back(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return optimizedLoopsOp;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that emits the loops and their optimized version.
 | 
				
			||||||
 | 
					// The function returns a reference to the inner optimization block.
 | 
				
			||||||
 | 
					static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                          std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                          std::vector<Value> &optimizedLoops,
 | 
				
			||||||
 | 
					                          int64_t numLoops) {
 | 
				
			||||||
 | 
					  KrnlOptimizeLoopsOp optimizedLoopsOp =
 | 
				
			||||||
 | 
					      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
 | 
				
			||||||
 | 
					  return &optimizedLoopsOp.region().front();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function which emits a basic set of loops and optimized loops
 | 
				
			||||||
 | 
					// for a given operation argument. A reference to the loop optimization
 | 
				
			||||||
 | 
					// block is returned in the last argument of the function.
 | 
				
			||||||
 | 
					static void emitKrnlLoopsAndIterationForOperand(
 | 
				
			||||||
 | 
					    ConversionPatternRewriter &rewriter, Location loc, Value operand,
 | 
				
			||||||
 | 
					    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
 | 
				
			||||||
 | 
					    KrnlIterateOp &iterateOp) {
 | 
				
			||||||
 | 
					  // Operand shape.
 | 
				
			||||||
 | 
					  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Number of loops.
 | 
				
			||||||
 | 
					  int64_t rank = shape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Define loops and optimized loops.
 | 
				
			||||||
 | 
					  std::vector<Value> optimizedLoops;
 | 
				
			||||||
 | 
					  optimizedLoopsOp =
 | 
				
			||||||
 | 
					      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
				
			||||||
 | 
					  // Iterate over the loop nest.
 | 
				
			||||||
 | 
					  for (int i = 0; i < rank; ++i)
 | 
				
			||||||
 | 
					    addDimensionToPack(rewriter, loc, pack, operand, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 | 
				
			||||||
 | 
					  auto elementType = memRefType.getElementType();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  unsigned sizeInBits;
 | 
				
			||||||
 | 
					  if (elementType.isIntOrFloat()) {
 | 
				
			||||||
 | 
					    sizeInBits = elementType.getIntOrFloatBitWidth();
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto vectorType = elementType.cast<VectorType>();
 | 
				
			||||||
 | 
					    sizeInBits =
 | 
				
			||||||
 | 
					        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return llvm::divideCeil(sizeInBits, 8);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get run-time dimension information for unknown dimensions used for
 | 
				
			||||||
 | 
					// broadcasting.
 | 
				
			||||||
 | 
					std::map<int, std::map<int, Value>>
 | 
				
			||||||
 | 
					getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                      MemRefType memRefType, ArrayRef<Value> operands) {
 | 
				
			||||||
 | 
					  auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					  int64_t rank = memRefShape.size();
 | 
				
			||||||
 | 
					  // For unknown dimensions, we need to get dimension values at runtime in
 | 
				
			||||||
 | 
					  // order to do broadcasting.
 | 
				
			||||||
 | 
					  std::map<int, std::map<int, Value>> DimInfo;
 | 
				
			||||||
 | 
					  // For each result dimension, compute the number of sharing operands.
 | 
				
			||||||
 | 
					  // Sharing operands are operands sharing the same index (counting from the
 | 
				
			||||||
 | 
					  // rightmost to the leftmost) for a given dimension.
 | 
				
			||||||
 | 
					  std::map<int, int> sharedDimCount;
 | 
				
			||||||
 | 
					  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					    int dimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    sharedDimCount[dimIdx] = 0;
 | 
				
			||||||
 | 
					    for (int i = 0; i < operands.size(); ++i) {
 | 
				
			||||||
 | 
					      auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					      if (reversedIdx <= shape.size() - 1)
 | 
				
			||||||
 | 
					        sharedDimCount[dimIdx]++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // An unknown dimension can have a value of 1 or N (N > 1).
 | 
				
			||||||
 | 
					  // If its value is 1, it is broadcasted dimension.
 | 
				
			||||||
 | 
					  // Otherwise, non-broadcasted dimension.
 | 
				
			||||||
 | 
					  // We only care about unknown dimensions whose number of sharing operands is
 | 
				
			||||||
 | 
					  // more than one, since they are potentially broadcasted dimensions.
 | 
				
			||||||
 | 
					  for (int i = 0; i < operands.size(); ++i) {
 | 
				
			||||||
 | 
					    std::map<int, Value> broadcastedDims;
 | 
				
			||||||
 | 
					    auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    int size = shape.size();
 | 
				
			||||||
 | 
					    for (int j = 0; j < shape.size(); ++j) {
 | 
				
			||||||
 | 
					      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
 | 
				
			||||||
 | 
					        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
 | 
				
			||||||
 | 
					        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
 | 
				
			||||||
 | 
					        auto isBroadcasted =
 | 
				
			||||||
 | 
					            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
 | 
				
			||||||
 | 
					        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    DimInfo.insert(std::make_pair(i, broadcastedDims));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return DimInfo;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Extract induction variables that are used for broadcasting values of a
 | 
				
			||||||
 | 
					// given operand.
 | 
				
			||||||
 | 
					std::vector<Value>
 | 
				
			||||||
 | 
					getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                          ArrayRef<Value> loopIVs, Value operand,
 | 
				
			||||||
 | 
					                          std::map<int, Value> broadcastedDims) {
 | 
				
			||||||
 | 
					  // `operand` must has a ranked type. This should have been checked by the
 | 
				
			||||||
 | 
					  // shape inference pass.
 | 
				
			||||||
 | 
					  auto operandShape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					  auto rank = operandShape.size();
 | 
				
			||||||
 | 
					  auto loopCount = loopIVs.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<Value> newLoopIVs;
 | 
				
			||||||
 | 
					  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					    auto dimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    auto loopIdx = loopCount - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    if (operandShape[dimIdx] == 1) {
 | 
				
			||||||
 | 
					      // Broadcasted dimension
 | 
				
			||||||
 | 
					      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), zero);
 | 
				
			||||||
 | 
					    } else if ((operandShape[dimIdx] == -1) &&
 | 
				
			||||||
 | 
					               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
 | 
				
			||||||
 | 
					      // Unknown dimension, it can have a value of 1 or N (N > 1).
 | 
				
			||||||
 | 
					      // If its value is 1, it is broadcasted dimension.
 | 
				
			||||||
 | 
					      // Otherwise, non-broadcasted dimension.
 | 
				
			||||||
 | 
					      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
 | 
				
			||||||
 | 
					                                           loopIVs[loopIdx]);
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), idx);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // Non-broadcasted dimension
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return newLoopIVs;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
namespace {

// This is to get a scalar operation of a given type for a specific operation.
// The primary template maps both variants to `void`; per-op template
// specializations supply the real float/integer scalar op types.
template <typename Op>
struct ScalarOp {
  using FOp = void;
  using IOp = void;
};

// Shorthand for the float scalar op associated with an ONNX op.
template <typename FOp>
using ScalarFOp = typename ScalarOp<FOp>::FOp;
// Shorthand for the integer scalar op associated with an ONNX op.
template <typename IOp>
using ScalarIOp = typename ScalarOp<IOp>::IOp;

// Get the identity element of a operation.
// Return NULL if the function does not have identity.
// NOTE(review): NULL decays to 0 for arithmetic DataType; specializations
// presumably override this for ops with a nontrivial identity — confirm.
template <typename DataType, typename Op>
DataType getIdentityValue() {
  return NULL;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// This is used in the innermost loop of a KrnlIterateOp to insert computation
 | 
				
			||||||
 | 
					// composed of one or many scalar ops.
 | 
				
			||||||
 | 
					// Use template specialization for each of different ONNX operations.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <typename Op>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                         ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                         ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Type element_type = operands.front().getType();
 | 
				
			||||||
 | 
					  if (element_type.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
 | 
				
			||||||
 | 
					                                          mlir::None);
 | 
				
			||||||
 | 
					  } else if (element_type.isa<FloatType>()) {
 | 
				
			||||||
 | 
					    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
 | 
				
			||||||
 | 
					                                          mlir::None);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    return nullptr;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// We divide the operator lowering into different categories.
 | 
				
			||||||
 | 
					// These categories are mostly similar to the operator categories in ONNX:
 | 
				
			||||||
 | 
					// https://github.com/onnx/onnx/tree/master/onnx/defs.
 | 
				
			||||||
 | 
					// Besides, it is better to put operators with the same computation pattern into
 | 
				
			||||||
 | 
					// the same category, e.g. element-wise operators will belong to the elementwise
 | 
				
			||||||
 | 
					// category.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Math
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
 | 
				
			||||||
 | 
					// Tensor
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
 | 
				
			||||||
 | 
					// Neural network
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// EntryPoint Op lowering to Krnl Entry Point.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ONNXEntryPointLowering : public OpRewritePattern<ONNXEntryPointOp> {
 | 
				
			||||||
 | 
					public:
 | 
				
			||||||
 | 
					  using OpRewritePattern<ONNXEntryPointOp>::OpRewritePattern;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult matchAndRewrite(ONNXEntryPointOp op,
 | 
				
			||||||
 | 
					                                     PatternRewriter &rewriter) const override {
 | 
				
			||||||
 | 
					    rewriter.replaceOpWithNewOp<KrnlEntryPointOp>(
 | 
				
			||||||
 | 
					        op,
 | 
				
			||||||
 | 
					        op.getAttrOfType<SymbolRefAttr>(
 | 
				
			||||||
 | 
					            ONNXEntryPointOp::getEntryPointFuncAttrName()),
 | 
				
			||||||
 | 
					        op.getAttrOfType<IntegerAttr>(ONNXEntryPointOp::getNumInputsAttrName()),
 | 
				
			||||||
 | 
					        op.getAttrOfType<IntegerAttr>(
 | 
				
			||||||
 | 
					            ONNXEntryPointOp::getNumOutputsAttrName()));
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Conversion from Tensor type to the Standard dialect MemRef type.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct TensorTypeConverter : public TypeConverter {
 | 
				
			||||||
 | 
					  using TypeConverter::TypeConverter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) override {
 | 
				
			||||||
 | 
					    if (auto tensor_type = t.dyn_cast<TensorType>()) {
 | 
				
			||||||
 | 
					      results.push_back(convertTensorToMemRef(tensor_type));
 | 
				
			||||||
 | 
					      return success();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    results.push_back(t);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return true if the inputs and outputs of the given function type are
 | 
				
			||||||
 | 
					  /// legal. [Taken from MLIR and adapted to only check the legality of the
 | 
				
			||||||
 | 
					  /// inputs. Once unranked results can be handled gracefully this
 | 
				
			||||||
 | 
					  /// override needs to be removed in favour of the original MLIR one.]
 | 
				
			||||||
 | 
					  bool isSignatureLegal(FunctionType funcType) {
 | 
				
			||||||
 | 
					    return llvm::all_of(funcType.getInputs(),
 | 
				
			||||||
 | 
					                        [this](Type type) { return isLegal(type); });
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					} // end anonymous namespace.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Frontend to Krnl Dialect lowering pass
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
/// This is a partial lowering to Krnl loops of the ONNX operations.
namespace {
struct FrontendToKrnlLoweringPass
    : public ModulePass<FrontendToKrnlLoweringPass> {
  // Pass entry point; the implementation follows below.
  void runOnModule() final;
};
} // end anonymous namespace.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
/// Run the ONNX-to-Krnl lowering on the whole module: set up the conversion
/// target, collect the rewrite patterns for each op category, and apply a
/// partial conversion (unmatched ops survive untouched).
void FrontendToKrnlLoweringPass::runOnModule() {
  auto module = getModule();

  // The first thing to define is the conversion target. This will define the
  // final target for this lowering.
  ConversionTarget target(getContext());

  // We define the specific operations, or dialects, that are legal targets for
  // this lowering.
  target
      .addLegalDialect<KrnlOpsDialect, AffineOpsDialect, StandardOpsDialect>();

  // TODO: enable this once more ops are supported.
  // We also define the ONNX dialect as Illegal so that the conversion will fail
  // if any of these operations are *not* converted.
  // target.addIllegalDialect<mlir::ONNXOpsDialect>();

  // TODO: add any other ops which are considered legal.
  // Some operations can be marked as being still legal.
  // Example: target.addLegalOp<mlir::OpName>();

  // Now that the conversion target has been defined, we just need to provide
  // the set of patterns that will lower the frontend operations.
  OwningRewritePatternList patterns;

  // Convert TensorType to MemRef: FuncOps are only legal once their
  // signatures no longer mention tensor types.
  TensorTypeConverter tensor_to_memref_converter;
  target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
    // FuncOp is legal only if types have been converted to Std types.
    return tensor_to_memref_converter.isSignatureLegal(op.getType());
  });

  // Type conversion for function signatures.
  // Call MLIR FuncOp signature conversion when result type is
  // a ranked tensor.
  populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                      tensor_to_memref_converter);

  // Frontend operation lowering, grouped by op category.
  // Math
  populateLoweringONNXElementwiseOpPattern(patterns, &getContext());
  populateLoweringONNXGemmOpPattern(patterns, &getContext());
  populateLoweringONNXReductionOpPattern(patterns, &getContext());
  populateLoweringONNXSoftmaxOpPattern(patterns, &getContext());
  populateLoweringONNXMatMulOpPattern(patterns, &getContext());
  // Tensor
  populateLoweringONNXReshapeOpPattern(patterns, &getContext());
  populateLoweringONNXUnsqueezeOpPattern(patterns, &getContext());
  populateLoweringONNXTransposeOpPattern(patterns, &getContext());
  populateLoweringONNXIdentityOpPattern(patterns, &getContext());
  // Neural network
  populateLoweringONNXConvOpPattern(patterns, &getContext());
  // Entry point
  patterns.insert<ONNXEntryPointLowering>(&getContext());

  // With the target and rewrite patterns defined, we can now attempt the
  // conversion. The conversion will signal failure if any of our `illegal`
  // operations were not converted successfully.
  if (failed(applyPartialConversion(module, target, patterns)))
    signalPassFailure();
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
/// Factory for the frontend-to-Krnl lowering pass; used by pass pipelines
/// that build the pass programmatically (the CLI uses the registration below).
std::unique_ptr<Pass> mlir::createLowerToKrnlPass() {
  return std::make_unique<FrontendToKrnlLoweringPass>();
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
// Register the pass so it is available on the command line as
// "--lower-frontend".
static PassRegistration<FrontendToKrnlLoweringPass>
    pass("lower-frontend", "Lower frontend ops to Krnl dialect.");
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,646 @@
 | 
				
			||||||
 | 
					//===----- elementwise.inc - Elementwise Ops ------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers ONNX element-wise operators to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
// ScalarOp<Op> maps each element-wise ONNX op to the scalar op that computes
// one element: FOp is instantiated for floating-point element types, IOp for
// integer element types. When an op only makes sense for one numeric kind,
// the other alias is never instantiated and is marked "unused" below.
template <>
struct ScalarOp<ONNXAddOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXMulOp> {
  using FOp = MulFOp;
  using IOp = MulIOp;
};

template <>
struct ScalarOp<ONNXDivOp> {
  using FOp = DivFOp;
  using IOp = SignedDivIOp;
};

template <>
struct ScalarOp<ONNXSubOp> {
  using FOp = SubFOp;
  using IOp = SubIOp;
};

template <>
struct ScalarOp<ONNXAndOp> {
  using FOp = AndOp; // unused: logical op, integer form only
  using IOp = AndOp;
};

template <>
struct ScalarOp<ONNXOrOp> {
  using FOp = OrOp; // unused: logical op, integer form only
  using IOp = OrOp;
};

template <>
struct ScalarOp<ONNXXorOp> {
  using FOp = XOrOp; // unused: logical op, integer form only
  using IOp = XOrOp;
};

template <>
struct ScalarOp<ONNXExpOp> {
  using FOp = ExpOp;
  using IOp = ExpOp; // unused: float-only op
};

template <>
struct ScalarOp<ONNXSumOp> {
  using FOp = AddFOp;
  using IOp = AddIOp;
};

template <>
struct ScalarOp<ONNXTanhOp> {
  using FOp = TanhOp;
  using IOp = TanhOp; // unused: float-only op
};

template <>
struct ScalarOp<ONNXCosOp> {
  using FOp = CosOp;
  using IOp = CosOp; // unused: float-only op
};

template <>
struct ScalarOp<ONNXLogOp> {
  using FOp = LogOp;
  using IOp = LogOp; // unused: float-only op
};

template <>
struct ScalarOp<ONNXSqrtOp> {
  using FOp = KrnlSqrtOp;
  using IOp = KrnlSqrtOp; // unused: float-only op
};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXSinhOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXSinhOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                     ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                     ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXSinhOp(%X) = DivFOp(SubFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | 
				
			||||||
 | 
					  //                         ConstantOp 2)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value operand = operands[0];
 | 
				
			||||||
 | 
					  auto elementType = result_types[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
 | 
				
			||||||
 | 
					  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
 | 
				
			||||||
 | 
					  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
 | 
				
			||||||
 | 
					  auto exp = rewriter.create<ExpOp>(loc, operand);
 | 
				
			||||||
 | 
					  auto negExp = rewriter.create<ExpOp>(loc, neg);
 | 
				
			||||||
 | 
					  auto result = rewriter.create<DivFOp>(
 | 
				
			||||||
 | 
					      loc, rewriter.create<SubFOp>(loc, exp, negExp), two);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXCoshOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXCoshOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                     ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                     ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXCoshOp(%X) = DivFOp(AddFOp(ExpOp(%X), ExpOp(NegFOp(%X))),
 | 
				
			||||||
 | 
					  //                         ConstantOp 2)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value operand = operands[0];
 | 
				
			||||||
 | 
					  auto elementType = result_types[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
 | 
				
			||||||
 | 
					  auto two = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 2));
 | 
				
			||||||
 | 
					  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
 | 
				
			||||||
 | 
					  auto exp = rewriter.create<ExpOp>(loc, operand);
 | 
				
			||||||
 | 
					  auto negExp = rewriter.create<ExpOp>(loc, neg);
 | 
				
			||||||
 | 
					  auto result = rewriter.create<DivFOp>(
 | 
				
			||||||
 | 
					      loc, rewriter.create<AddFOp>(loc, exp, negExp), two);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSigmoidOp>(Operation *op,
                                        ArrayRef<Type> result_types,
                                        ArrayRef<Value> operands,
                                        ConversionPatternRewriter &rewriter) {
  // sigmoid(x) = 1 / (1 + exp(-x))
  // ONNXSigmoidOp(%X) = DivFOp(ConstantOp 1,
  //                            AddFOp(ConstantOp 1, ExpOp(NegFOp(%X))))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  // -x is materialized as (0 - x); there is no dedicated float negate op.
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto neg = rewriter.create<SubFOp>(loc, zero, operand);
  auto negExp = rewriter.create<ExpOp>(loc, neg);
  auto result = rewriter.create<DivFOp>(
      loc, one, rewriter.create<AddFOp>(loc, one, negExp));

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXHardSigmoidOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXHardSigmoidOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // hardsigmoid(x) = min(max(alpha*x + beta, 0), 1), built as:
  // %Y = AddFOp(MulFOp(alpha, %X), beta)
  // %Z = SelectOp(CmpFOp(OGT, %Y, Constant 0),
  //               %Y,
  //               Constant 0)
  // ONNXHardSigmoidOp(%X) = SelectOp(CmpFOp(OLT, %Z, Constant 1),
  //                                  %Z,
  //                                  Constant 1)
  auto loc = op->getLoc();
  Value operand = operands[0];
  // alpha/beta are read from the op's attributes as f32 constants.
  // NOTE(review): the constants are built with the f32 type rather than
  // elementType — assumes f32 element types; confirm for other float widths.
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).alpha().convertToFloat());
  auto betaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXHardSigmoidOp>(op).beta().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto beta = rewriter.create<ConstantOp>(loc, betaAttribute);

  // Clamp alpha*x + beta into [0, 1] with two compare/select pairs.
  auto add = rewriter.create<AddFOp>(
      loc, rewriter.create<MulFOp>(loc, alpha, operand), beta);
  auto maxPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, add, zero);
  auto max = rewriter.create<SelectOp>(loc, maxPredicate, add, zero);
  auto minPredicate =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, max, one);
  auto result = rewriter.create<SelectOp>(loc, minPredicate, max, one);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXEluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXEluOp>(Operation *op, ArrayRef<Type> result_types,
                                    ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) {
  // elu(x) = x                      if x >= 0
  //          alpha * (exp(x) - 1)   if x <  0
  // ONNXEluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                          MulFOp(alpha, SubFOp(ExpOp(%X), 1)),
  //                          %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  // alpha comes from the op's attribute as an f32 constant.
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXEluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(
      loc, lessThanZero,
      rewriter.create<MulFOp>(loc, alpha,
                              rewriter.create<SubFOp>(loc, exp, one)),
      operand);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXReluOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXReluOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                     ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                     ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
 | 
				
			||||||
 | 
					  //                           ConstantOp 0,
 | 
				
			||||||
 | 
					  //                           %X)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value operand = operands[0];
 | 
				
			||||||
 | 
					  auto elementType = result_types[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
 | 
				
			||||||
 | 
					  auto lessThanZero =
 | 
				
			||||||
 | 
					      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
 | 
				
			||||||
 | 
					  auto result = rewriter.create<SelectOp>(loc, lessThanZero, zero, operand);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXLeakyReluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXLeakyReluOp>(Operation *op,
                                          ArrayRef<Type> result_types,
                                          ArrayRef<Value> operands,
                                          ConversionPatternRewriter &rewriter) {
  // leakyrelu(x) = x >= 0 ? x : alpha * x
  // ONNXLeakyReluOp(%X) = SelectOp(CmpFOp(OLT, %X, ConstantOp 0),
  //                                MulFOp(alpha, %X),
  //                                %X)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  // alpha comes from the op's attribute as an f32 constant.
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXLeakyReluOp>(op).alpha().convertToFloat());
  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute)
  auto lessThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, operand, zero);
  auto result = rewriter.create<SelectOp>(
      loc, lessThanZero, rewriter.create<MulFOp>(loc, alpha, operand), operand);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSeluOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSeluOp>(Operation *op, ArrayRef<Type> result_types,
                                     ArrayRef<Value> operands,
                                     ConversionPatternRewriter &rewriter) {
  // selu(x) = gamma * x                         if x > 0
  //           gamma * (alpha*exp(x) - alpha)    otherwise
  // ONNXSeluOp(%X) = SelectOp(CmpFOp(OGT, %X, ConstantOp 0),
  //                           MulFOp(gamma, %X),
  //                           MulFOp(gamma,
  //                                  SubFOp(MulFOp(alpha, ExpOp(%X)),
  //                                         alpha)))
  // Note: the code selects the un-scaled branch first and multiplies the
  // selected value by gamma once, which is equivalent.
  auto loc = op->getLoc();
  Value operand = operands[0];
  // alpha/gamma come from the op's attributes as f32 constants.
  auto alphaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).alpha().convertToFloat());
  auto gammaAttribute = FloatAttr::get(rewriter.getF32Type(),
      llvm::dyn_cast<ONNXSeluOp>(op).gamma().convertToFloat());
  auto elementType = result_types[0];

  auto zero = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
  auto alpha = rewriter.create<ConstantOp>(loc, alphaAttribute);
  auto gamma = rewriter.create<ConstantOp>(loc, gammaAttribute);
  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto greaterThanZero =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
  auto select = rewriter.create<SelectOp>(
      loc, greaterThanZero, operand,
      rewriter.create<SubFOp>(loc, rewriter.create<MulFOp>(loc, alpha, exp),
                              alpha));
  auto result = rewriter.create<MulFOp>(loc, gamma, select);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXReciprocalOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXReciprocalOp>(
 | 
				
			||||||
 | 
					    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					    ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXReciprocalOp(%X) = DivFOp(ConstantOp 1, %X)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value operand = operands[0];
 | 
				
			||||||
 | 
					  auto elementType = result_types[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
 | 
				
			||||||
 | 
					  auto result = rewriter.create<DivFOp>(loc, one, operand);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftplusOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftplusOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // softplus(x) = log(exp(x) + 1)
  // ONNXSoftplusOp(%X) = LogOp(AddFOp(ExpOp(%X), ConstantOp 1))
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto exp = rewriter.create<ExpOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, exp, one);
  auto result = rewriter.create<LogOp>(loc, add);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Scalar unary ops for lowering ONNXSoftsignOp
//===----------------------------------------------------------------------===//
template <>
Value mapToLowerScalarOp<ONNXSoftsignOp>(
    Operation *op, ArrayRef<Type> result_types, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) {
  // softsign(x) = x / (1 + |x|)
  // ONNXSoftsignOp(%X) = DivFOp(%X, AddFOp(AbsFOp(%X), ConstantOp 1))
  // (The previous comment, DivFOp(ConstantOp 1, %X), did not match the code.)
  auto loc = op->getLoc();
  Value operand = operands[0];
  auto elementType = result_types[0];

  auto abs = rewriter.create<AbsFOp>(loc, operand);
  auto one = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 1));
  auto add = rewriter.create<AddFOp>(loc, abs, one);
  auto result = rewriter.create<DivFOp>(loc, operand, add);

  return result;
}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXSignOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXSignOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                     ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                     ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value operand = operands[0];
 | 
				
			||||||
 | 
					  Type element_type = operands.front().getType();
 | 
				
			||||||
 | 
					  // TODO: unsigned int should be supported separately?
 | 
				
			||||||
 | 
					  if (element_type.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					    // %Y = SelectOP(CmpIOp(GT, %X, ConstantOp 0),
 | 
				
			||||||
 | 
					    //               ConstantOp 1,
 | 
				
			||||||
 | 
					    //               COnstantOp -1)
 | 
				
			||||||
 | 
					    // ONNXSignOp(%X) = SelectOP(CmpIOp(EQ, %X, ConstantOp 0),
 | 
				
			||||||
 | 
					    //                           ConstantOp 0,
 | 
				
			||||||
 | 
					    //                           %Y)
 | 
				
			||||||
 | 
					    auto zero = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
 | 
				
			||||||
 | 
					    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
 | 
				
			||||||
 | 
					    auto minusOne =
 | 
				
			||||||
 | 
					        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(-1));
 | 
				
			||||||
 | 
					    auto plusPredicate =
 | 
				
			||||||
 | 
					        rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, operand, zero);
 | 
				
			||||||
 | 
					    auto plusSelect =
 | 
				
			||||||
 | 
					        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
 | 
				
			||||||
 | 
					    auto zeroPredicate =
 | 
				
			||||||
 | 
					        rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, operand, zero);
 | 
				
			||||||
 | 
					    auto result =
 | 
				
			||||||
 | 
					        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else if (element_type.isa<FloatType>()) {
 | 
				
			||||||
 | 
					    // %Y = SelectOP(CmpFOp(OGT, %X, ConstantOp 0),
 | 
				
			||||||
 | 
					    //               ConstantOp 1,
 | 
				
			||||||
 | 
					    //               ConstantOp -1)
 | 
				
			||||||
 | 
					    // ONNXSignOp(%X) = SelectOP(CmpFOp(OEQ, %X, ConstantOp 0),
 | 
				
			||||||
 | 
					    //                           ConstantOp 0,
 | 
				
			||||||
 | 
					    //                           %Y)
 | 
				
			||||||
 | 
					    auto zero =
 | 
				
			||||||
 | 
					        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
 | 
				
			||||||
 | 
					    auto one = rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0f));
 | 
				
			||||||
 | 
					    auto minusOne =
 | 
				
			||||||
 | 
					        rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0f));
 | 
				
			||||||
 | 
					    auto plusPredicate =
 | 
				
			||||||
 | 
					        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, operand, zero);
 | 
				
			||||||
 | 
					    auto plusSelect =
 | 
				
			||||||
 | 
					        rewriter.create<SelectOp>(loc, plusPredicate, one, minusOne);
 | 
				
			||||||
 | 
					    auto zeroPredicate =
 | 
				
			||||||
 | 
					        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, operand, zero);
 | 
				
			||||||
 | 
					    auto result =
 | 
				
			||||||
 | 
					        rewriter.create<SelectOp>(loc, zeroPredicate, zero, plusSelect);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXMaxOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXMaxOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                    ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                    ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXMaxOp(%X, %Y) = SelectOp(CmpFOp(OGT, %X, %Y),
 | 
				
			||||||
 | 
					  //                              %X,
 | 
				
			||||||
 | 
					  //                              %Y)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value lhs = operands[0];
 | 
				
			||||||
 | 
					  Value rhs = operands[1];
 | 
				
			||||||
 | 
					  auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
 | 
				
			||||||
 | 
					  auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXMinOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXMinOp>(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                    ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                    ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  // ONNXMinOp(%X, %Y) = SelectOp(CmpFOp(OLT, %X, %Y),
 | 
				
			||||||
 | 
					  //                              %X,
 | 
				
			||||||
 | 
					  //                              %Y)
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value lhs = operands[0];
 | 
				
			||||||
 | 
					  Value rhs = operands[1];
 | 
				
			||||||
 | 
					  auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
 | 
				
			||||||
 | 
					  auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
 | 
				
			||||||
 | 
					  return result;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Element-wise unary ops lowering to Krnl dialect.
//
// Generic pattern that lowers any single-operand element-wise ONNX op to a
// krnl loop nest: allocate the result buffer, iterate over every element of
// the input, apply the op's scalar lowering (mapToLowerScalarOp) and store
// the result. The op-specific scalar computation is supplied by the
// ElementwiseUnaryOp template parameter.
//===----------------------------------------------------------------------===//
template <typename ElementwiseUnaryOp>
struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
  ONNXElementwiseUnaryOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseUnaryOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise unary operation must have all operands and the result of
    // the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    // If the output has a dynamic dimension, pass the operands required for
    // each dynamic dimension to the AllocOp. The first operand of the
    // operation is used. The operands of the op need to match in terms of
    // dimensions with the result at this pre-optimization phase.
    // TODO: verify that dimensions match.
    // TODO: can the dimension of the result differ after optimizations?
    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);

    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    {operands[0]});

    // Create the krnl loop nest (one loop per rank of the input) together
    // with its optimization region and iterate op.
    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, operands[0], originalLoops,
        optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // NOTE: the order of the setInsertionPoint* calls below matters — each
    // region must be populated at the right position.
    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops
    // unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

    // 2. Insert instructions inside the KernelIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    // Collect the loop induction variables; they index both input and output.
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Load the scalar input element, lower the op on it, and store the result.
    auto loadedVal = rewriter.create<LoadOp>(loc, operands[0], loopIVs);
    auto loweredOpResult = mapToLowerScalarOp<ElementwiseUnaryOp>(
        op, memRefType.getElementType(), {loadedVal}, rewriter);
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, loweredOpResult, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
//===----------------------------------------------------------------------===//
// Element-wise variadic ops lowering to Krnl dialect.
//
// Generic pattern that lowers an n-ary element-wise ONNX op (Add, Mul, Max,
// ...) to a krnl loop nest over the output. Within the loop body the result
// is computed by folding mapToLowerScalarOp over the operands left-to-right,
// with per-operand index remapping to support multidirectional broadcasting.
//===----------------------------------------------------------------------===//
template <typename ElementwiseVariadicOp>
struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
  ONNXElementwiseVariadicOpLowering(MLIRContext *ctx)
      : ConversionPattern(ElementwiseVariadicOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // TODO: Check that the types are valid.
    // An element-wise variadic operation must have all operands and the result
    // of the same type. This should have been verified by the verifier.
    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
    auto loc = op->getLoc();
    auto numArgs = op->getNumOperands();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    // If the output has a dynamic dimension, we compute its dimension at
    // runtime by using dimensions from the operands.
    // In particular, we need to know from which operand a result dimension
    // comes from.
    // TODO: can the dimension of the result differ after optimizations?
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands);

    // Get run-time dimension information for unknown dimensions used for
    // broadcasting.
    // Maps operand index -> (dimension index -> "is broadcasted" i1 Value).
    std::map<int, std::map<int, Value>> broadcastedDimInfo =
        getBroadcastedDimInfo(loc, rewriter, memRefType, operands);

    // Create the krnl loop nest over the OUTPUT buffer (not an input, since
    // inputs may be broadcast and therefore smaller than the result).
    std::vector<Value> originalLoops;
    KrnlOptimizeLoopsOp optimizedLoopsOp;
    KrnlIterateOp iterateOp;
    emitKrnlLoopsAndIterationForOperand(
        rewriter, loc, alloc, originalLoops,
        optimizedLoopsOp, iterateOp);
    Block &optimizationBlock = optimizedLoopsOp.region().front();
    Block &iterationBlock = iterateOp.bodyRegion().front();

    // NOTE: the order of the setInsertionPoint* calls below matters — each
    // region must be populated at the right position.
    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
    rewriter.setInsertionPointToEnd(&optimizationBlock);
    // Return from KrnlOptimizeLoopsOp body.
    // When no optimizations are present we just return the loops unchanged.
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops)
;
    // 2. Insert instructions inside the KernelIterateOp body.
    rewriter.setInsertionPointToStart(&iterationBlock);

    // Handle the operation:
    // Collect the loop induction variables indexing the output element.
    SmallVector<Value, 4> loopIVs;
    for (auto arg : iterationBlock.getArguments())
      loopIVs.push_back(arg);

    // Fold over operands for each of their scalar values
    // (left fold: acc = op(acc, operand[i]) for i = 1..numArgs-1).
    Value accumulated, next;
    // Remap the output IVs to the (possibly broadcast) first operand.
    auto accumulatedLoopIVs = getLoopIVsForBroadcasting(
        loc, rewriter, loopIVs, operands[0], broadcastedDimInfo[0]);
    accumulated = rewriter.create<LoadOp>(loc, operands[0], accumulatedLoopIVs);
    for (unsigned i = 1; i < numArgs; i++) {
      auto nextLoopIVs = getLoopIVsForBroadcasting(
          loc, rewriter, loopIVs, operands[i], broadcastedDimInfo[i]);
      next = rewriter.create<LoadOp>(loc, operands[i], nextLoopIVs);
      accumulated = mapToLowerScalarOp<ElementwiseVariadicOp>(
          op, memRefType.getElementType(), {accumulated, next}, rewriter);
    }
    // Store result in the resulting array.
    rewriter.create<StoreOp>(loc, accumulated, alloc, loopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXElementwiseOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXElementwiseVariadicOpLowering<mlir::ONNXAddOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXAndOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCosOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXCoshOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXDivOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXEluOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXExpOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXHardSigmoidOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLeakyReluOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXLogOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMaxOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMinOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXMulOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXOrOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReciprocalOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXReluOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSeluOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSigmoidOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSignOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSinhOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftplusOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSoftsignOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXSqrtOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSubOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXSumOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseUnaryOpLowering<mlir::ONNXTanhOp>,
 | 
				
			||||||
 | 
					                  ONNXElementwiseVariadicOpLowering<mlir::ONNXXorOp>>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,209 @@
 | 
				
			||||||
 | 
					//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Gemm Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXGemmOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXGemmOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXGemmOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Value A, B, C;
 | 
				
			||||||
 | 
					    A = operands[0];
 | 
				
			||||||
 | 
					    B = operands[1];
 | 
				
			||||||
 | 
					    C = operands[2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
 | 
				
			||||||
 | 
					        llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
 | 
				
			||||||
 | 
					    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
 | 
				
			||||||
 | 
					        llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
 | 
				
			||||||
 | 
					    auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
 | 
				
			||||||
 | 
					    auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
 | 
				
			||||||
 | 
					    bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Result type
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType))
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    else {
 | 
				
			||||||
 | 
					      auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					      SmallVector<Value, 2> allocOperands;
 | 
				
			||||||
 | 
					      if (memRefShape[0] < 0) {
 | 
				
			||||||
 | 
					        auto dim = rewriter.create<DimOp>(loc, A, (isTransA) ? 1 : 0);
 | 
				
			||||||
 | 
					        allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      if (memRefShape[1] < 0) {
 | 
				
			||||||
 | 
					        auto dim = rewriter.create<DimOp>(loc, B, (isTransB) ? 0 : 1);
 | 
				
			||||||
 | 
					        allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
 | 
				
			||||||
 | 
					      if (insertDealloc) {
 | 
				
			||||||
 | 
					        auto *parentBlock = alloc.getDefiningOp()->getBlock();
 | 
				
			||||||
 | 
					        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
				
			||||||
 | 
					        dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Number of loops
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					    int64_t numLoops = 3;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Define loops.
 | 
				
			||||||
 | 
					    std::vector<Value> originalLoops;
 | 
				
			||||||
 | 
					    std::vector<Value> optimizedLoops;
 | 
				
			||||||
 | 
					    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
 | 
				
			||||||
 | 
					            optimizedLoops, numLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // We have two Krnl loops:
 | 
				
			||||||
 | 
					    // - Outer loop iterates over the output matrix dimensions, and
 | 
				
			||||||
 | 
					    // - Reduction loop iterates over the reduction dimension.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Outer loop
 | 
				
			||||||
 | 
					    std::vector<Value> outerLoops, optimizedOuterLoops;
 | 
				
			||||||
 | 
					    outerLoops.reserve(2);
 | 
				
			||||||
 | 
					    optimizedOuterLoops.reserve(2);
 | 
				
			||||||
 | 
					    for (int i = 0; i < 2; ++i) {
 | 
				
			||||||
 | 
					      outerLoops.push_back(originalLoops[i]);
 | 
				
			||||||
 | 
					      optimizedOuterLoops.push_back(optimizedLoops[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack outerPack(rewriter, outerLoops,
 | 
				
			||||||
 | 
					                                      optimizedOuterLoops);
 | 
				
			||||||
 | 
					    // Induction variables for the outer loops
 | 
				
			||||||
 | 
					    for (int i = 0; i < 2; ++i)
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, outerPack, alloc, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Reduction loop
 | 
				
			||||||
 | 
					    std::vector<Value> reductionLoops, optimizedReductionLoops;
 | 
				
			||||||
 | 
					    reductionLoops.reserve(1);
 | 
				
			||||||
 | 
					    optimizedReductionLoops.reserve(1);
 | 
				
			||||||
 | 
					    reductionLoops.push_back(originalLoops[2]);
 | 
				
			||||||
 | 
					    optimizedReductionLoops.push_back(optimizedLoops[2]);
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack reductionPack(rewriter, reductionLoops,
 | 
				
			||||||
 | 
					                                         optimizedReductionLoops);
 | 
				
			||||||
 | 
					    // Induction variable for the reduction dimension
 | 
				
			||||||
 | 
					    // Try to find and use a static value from A or B first.
 | 
				
			||||||
 | 
					    // If it failed then use a dynamic value.
 | 
				
			||||||
 | 
					    auto ATy = A.getType().cast<MemRefType>();
 | 
				
			||||||
 | 
					    auto BTy = B.getType().cast<MemRefType>();
 | 
				
			||||||
 | 
					    int64_t K_A_Idx = (isTransA) ? 0 : 1;
 | 
				
			||||||
 | 
					    int64_t K_B_Idx = (isTransB) ? 1 : 0;
 | 
				
			||||||
 | 
					    reductionPack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    if (ATy.getShape()[K_A_Idx] != -1)
 | 
				
			||||||
 | 
					        reductionPack.pushConstantBound(ATy.getShape()[K_A_Idx]);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      if (BTy.getShape()[K_B_Idx] != -1)
 | 
				
			||||||
 | 
					        reductionPack.pushConstantBound(BTy.getShape()[K_B_Idx]);
 | 
				
			||||||
 | 
					      else
 | 
				
			||||||
 | 
					        reductionPack.pushOperandBound(
 | 
				
			||||||
 | 
					            rewriter.create<DimOp>(loc, B, K_B_Idx).getResult());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get run-time dimension information for unknown dimensions used for
 | 
				
			||||||
 | 
					    // broadcasting.
 | 
				
			||||||
 | 
					    // GemmOp supports unidirectional broadcasting from C to A*B.
 | 
				
			||||||
 | 
					    // Hence, it must be enough to get broadcasting information for C only.
 | 
				
			||||||
 | 
					    std::map<int, Value> broadcastedDimInfo;
 | 
				
			||||||
 | 
					    auto shape = C.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    for (int i = 0; i < shape.size(); ++i) {
 | 
				
			||||||
 | 
					      if (shape[i] < 0) {
 | 
				
			||||||
 | 
					        auto dim = rewriter.create<DimOp>(loc, C, i).getResult();
 | 
				
			||||||
 | 
					        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
 | 
				
			||||||
 | 
					        auto isBroadcasted =
 | 
				
			||||||
 | 
					          rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
 | 
				
			||||||
 | 
					        broadcastedDimInfo.insert(std::make_pair(i, isBroadcasted));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Now perform the insertions into the body of the
 | 
				
			||||||
 | 
					    // just generated instructions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // No optimization
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert instructions inside the outer loop.
 | 
				
			||||||
 | 
					    Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&outerIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Induction variables
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> loopMNIVs;
 | 
				
			||||||
 | 
					    for (auto arg : outerIterationBlock.getArguments()) {
 | 
				
			||||||
 | 
					      loopMNIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Initialize the output of A*B
 | 
				
			||||||
 | 
					    auto zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					        loc, FloatAttr::get(memRefType.getElementType(), 0));
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, zero, alloc, loopMNIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Compute A*B
 | 
				
			||||||
 | 
					    auto matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, reductionPack);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Compute beta*C, and add up to alpha*A*B (unidirectional broadcasting)
 | 
				
			||||||
 | 
					    auto loopCIVs = getLoopIVsForBroadcasting(
 | 
				
			||||||
 | 
					        loc, rewriter, loopMNIVs, C, broadcastedDimInfo);
 | 
				
			||||||
 | 
					    auto loadedC = rewriter.create<LoadOp>(loc, C, loopCIVs);
 | 
				
			||||||
 | 
					    auto loadedAB = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
 | 
				
			||||||
 | 
					    auto alphaAB = rewriter.create<MulFOp>(loc, alpha, loadedAB);
 | 
				
			||||||
 | 
					    auto betaC = rewriter.create<MulFOp>(loc, beta, loadedC);
 | 
				
			||||||
 | 
					    auto Y = rewriter.create<AddFOp>(loc, alphaAB, betaC);
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, Y, alloc, loopMNIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert instructions to do matrix multiplication: A*B
 | 
				
			||||||
 | 
					    Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&matmulIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Induction variables
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> loopKIVs, loopAIVs, loopBIVs;
 | 
				
			||||||
 | 
					    for (auto arg : matmulIterationBlock.getArguments())
 | 
				
			||||||
 | 
					      loopKIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					    if (isTransA) {
 | 
				
			||||||
 | 
					      loopAIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					      loopAIVs.emplace_back(loopMNIVs[0]);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      loopAIVs.emplace_back(loopMNIVs[0]);
 | 
				
			||||||
 | 
					      loopAIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    if (isTransB) {
 | 
				
			||||||
 | 
					      loopBIVs.emplace_back(loopMNIVs[1]);
 | 
				
			||||||
 | 
					      loopBIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      loopBIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					      loopBIVs.emplace_back(loopMNIVs[1]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Matmul computation
 | 
				
			||||||
 | 
					    auto loadedA = rewriter.create<LoadOp>(loc, A, loopAIVs);
 | 
				
			||||||
 | 
					    auto loadedB = rewriter.create<LoadOp>(loc, B, loopBIVs);
 | 
				
			||||||
 | 
					    auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopMNIVs);
 | 
				
			||||||
 | 
					    auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
 | 
				
			||||||
 | 
					    auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, accumulated, alloc, loopMNIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXGemmOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXGemmOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,345 @@
 | 
				
			||||||
 | 
					//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Matmul Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXMatMulOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXMatMulOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Value A = operands[0];
 | 
				
			||||||
 | 
					    Value B = operands[1];
 | 
				
			||||||
 | 
					    auto AShape = A.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    auto BShape = B.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // There are three cases related to the shapes of the two arguments:
 | 
				
			||||||
 | 
					    // - Both arguments are N-D, N >= 2
 | 
				
			||||||
 | 
					    // - Either argument is 1-D, the other is N-D, N >= 2
 | 
				
			||||||
 | 
					    // - Both arguments are 1-D
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Result type
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					    auto elementType = memRefType.getElementType();
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // A value zero
 | 
				
			||||||
 | 
					    Value zero;
 | 
				
			||||||
 | 
					    if (elementType.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					      zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, IntegerAttr::get(memRefType.getElementType(), 0));
 | 
				
			||||||
 | 
					    } else if (elementType.isa<FloatType>()) {
 | 
				
			||||||
 | 
					      zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, FloatAttr::get(memRefType.getElementType(), 0));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType))
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    else {
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> allocOperands;
 | 
				
			||||||
 | 
					      if (AShape.size() >= 2 && BShape.size() >= 2) {
 | 
				
			||||||
 | 
					        // Both arguments are N-D, N >= 2
 | 
				
			||||||
 | 
					        // (s1 x s2 x... x sK x M x K) MATMUL (K x N)
 | 
				
			||||||
 | 
					        // =>
 | 
				
			||||||
 | 
					        // (s1 x s2 x... x sK x M x N)
 | 
				
			||||||
 | 
					        for (int i = 0; i < memRefShape.size() - 2; ++i) {
 | 
				
			||||||
 | 
					          if (memRefShape[i] < 0) {
 | 
				
			||||||
 | 
					            if ((AShape.size() == 2) && (BShape.size() > 2))
 | 
				
			||||||
 | 
					              allocOperands.emplace_back(rewriter.create<DimOp>(loc, B, i));
 | 
				
			||||||
 | 
					            else if ((AShape.size() > 2) && (BShape.size() == 2))
 | 
				
			||||||
 | 
					              allocOperands.emplace_back(rewriter.create<DimOp>(loc, A, i));
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (memRefShape[memRefShape.size() - 2] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, A, memRefShape.size() - 2);
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (memRefShape[memRefShape.size() - 1] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, B, memRefShape.size() - 1);
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else if (AShape.size() == 1 && BShape.size() >= 2) {
 | 
				
			||||||
 | 
					        // Either argument is 1-D
 | 
				
			||||||
 | 
					        // K MATMUL (s1 x s2 x... x sK x K x N)
 | 
				
			||||||
 | 
					        // =>
 | 
				
			||||||
 | 
					        // (s1 x s2 x... x sK x N)
 | 
				
			||||||
 | 
					        for (int i = 0; i < memRefShape.size() - 1; ++i) {
 | 
				
			||||||
 | 
					          if (memRefShape[i] < 0) {
 | 
				
			||||||
 | 
					            auto dim = rewriter.create<DimOp>(loc, B, i);
 | 
				
			||||||
 | 
					            allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (memRefShape[memRefShape.size() - 1] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, B, BShape.size() - 1);
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else if (AShape.size() >= 2 && BShape.size() == 1) {
 | 
				
			||||||
 | 
					        // Either argument is 1-D
 | 
				
			||||||
 | 
					        // (s1 x s2 x... x sK x M x K) MATMUL K
 | 
				
			||||||
 | 
					        // =>
 | 
				
			||||||
 | 
					        // (s1 x s2 x... x sK x M)
 | 
				
			||||||
 | 
					        for (int i = 0; i < memRefShape.size() - 1; ++i) {
 | 
				
			||||||
 | 
					          if (memRefShape[i] < 0) {
 | 
				
			||||||
 | 
					            auto dim = rewriter.create<DimOp>(loc, A, i);
 | 
				
			||||||
 | 
					            allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        if (memRefShape[memRefShape.size() - 1] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, A, AShape.size() - 2);
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else if (AShape.size() == 1 && BShape.size() == 1) {
 | 
				
			||||||
 | 
					        // Both arguments are 1-D
 | 
				
			||||||
 | 
					        if (memRefShape[0] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, A, 0);
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        emitError(loc, "Invalid shapes");
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (AShape.size() >= 2 || BShape.size() >= 2) {
 | 
				
			||||||
 | 
					      // Cases 1 and 2:
 | 
				
			||||||
 | 
					      // - Both arguments are N-D, N >= 2
 | 
				
			||||||
 | 
					      // - Either argument is 1-D, the other is N-D, N >= 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Define loops for batch dimensions.
 | 
				
			||||||
 | 
					      std::vector<Value> originalLoops;
 | 
				
			||||||
 | 
					      std::vector<Value> optimizedLoops;
 | 
				
			||||||
 | 
					      Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
 | 
				
			||||||
 | 
					            optimizedLoops, memRefShape.size());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Outer KrnlIterateOp
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> loopBatchIVs;
 | 
				
			||||||
 | 
					      bool hasBatchLoop = false;
 | 
				
			||||||
 | 
					      if (AShape.size() > 2 || BShape.size() > 2) {
 | 
				
			||||||
 | 
					        SmallVector<int, 4> batchAxes;
 | 
				
			||||||
 | 
					        int matmulResultDims =
 | 
				
			||||||
 | 
					            ((AShape.size() == 1 || BShape.size() == 1)) ? 1 : 2;
 | 
				
			||||||
 | 
					        for (int i = 0; i < memRefShape.size() - matmulResultDims; ++i)
 | 
				
			||||||
 | 
					          batchAxes.emplace_back(i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        std::vector<Value> outerLoops, optimizedOuterLoops;
 | 
				
			||||||
 | 
					        outerLoops.reserve(batchAxes.size());
 | 
				
			||||||
 | 
					        optimizedOuterLoops.reserve(batchAxes.size());
 | 
				
			||||||
 | 
					        for (int i = 0; i < batchAxes.size(); ++i) {
 | 
				
			||||||
 | 
					          outerLoops.push_back(originalLoops[i]);
 | 
				
			||||||
 | 
					          optimizedOuterLoops.push_back(optimizedLoops[i]);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        KrnlIterateOperandPack outerPack(rewriter, outerLoops,
 | 
				
			||||||
 | 
					                                         optimizedOuterLoops);
 | 
				
			||||||
 | 
					        for (int i = 0; i < batchAxes.size(); ++i) {
 | 
				
			||||||
 | 
					          addDimensionToPack(rewriter, loc, outerPack, alloc, i);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        auto outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // No optimization
 | 
				
			||||||
 | 
					        rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Insert instructions into the outer KrnlIterateOp.
 | 
				
			||||||
 | 
					        Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					        rewriter.setInsertionPointToStart(&outerIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Induction variables: non-matrix-multiplication variables.
 | 
				
			||||||
 | 
					        for (auto arg : outerIterationBlock.getArguments()) {
 | 
				
			||||||
 | 
					          loopBatchIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        hasBatchLoop = true;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Now, we define loops for matrix multiplication.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Create a KrnlIterateOp for matrix multiplication.
 | 
				
			||||||
 | 
					      KrnlIterateOp matmulIterateOp;
 | 
				
			||||||
 | 
					      std::vector<Value> matmulLoops, optimizedMatmulLoops;
 | 
				
			||||||
 | 
					      if (AShape.size() >= 2 && BShape.size() >= 2) {
 | 
				
			||||||
 | 
					        // 2-D x 2-D. Result has two dimensions.
 | 
				
			||||||
 | 
					        matmulLoops.reserve(2);
 | 
				
			||||||
 | 
					        optimizedMatmulLoops.reserve(2);
 | 
				
			||||||
 | 
					        for (int i = 2; i > 0; --i) {
 | 
				
			||||||
 | 
					          matmulLoops.emplace_back(originalLoops[memRefShape.size() - i]);
 | 
				
			||||||
 | 
					          optimizedMatmulLoops.emplace_back(
 | 
				
			||||||
 | 
					              optimizedLoops[memRefShape.size() - i]);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
 | 
				
			||||||
 | 
					                                          optimizedMatmulLoops);
 | 
				
			||||||
 | 
					        for (int i = 2; i > 0; --i) {
 | 
				
			||||||
 | 
					          addDimensionToPack(rewriter, loc, matmulPack, alloc,
 | 
				
			||||||
 | 
					                             memRefShape.size() - i);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        // 1-D x 2-D, and vice versa. Result has one dimension.
 | 
				
			||||||
 | 
					        matmulLoops.reserve(1);
 | 
				
			||||||
 | 
					        optimizedMatmulLoops.reserve(1);
 | 
				
			||||||
 | 
					        matmulLoops.emplace_back(originalLoops[memRefShape.size() - 1]);
 | 
				
			||||||
 | 
					        optimizedMatmulLoops.emplace_back(
 | 
				
			||||||
 | 
					            optimizedLoops[memRefShape.size() - 1]);
 | 
				
			||||||
 | 
					        KrnlIterateOperandPack matmulPack(rewriter, matmulLoops,
 | 
				
			||||||
 | 
					                                          optimizedMatmulLoops);
 | 
				
			||||||
 | 
					        addDimensionToPack(rewriter, loc, matmulPack, alloc,
 | 
				
			||||||
 | 
					                           memRefShape.size() - 1);
 | 
				
			||||||
 | 
					        matmulIterateOp = rewriter.create<KrnlIterateOp>(loc, matmulPack);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      if (!hasBatchLoop) {
 | 
				
			||||||
 | 
					        // No optimization
 | 
				
			||||||
 | 
					        rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					        rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Insert instructions into the matmul KrnlIterateOp.
 | 
				
			||||||
 | 
					      Block &matmulIterationBlock = matmulIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToStart(&matmulIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Induction variables: M, N
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> loopMNIVs;
 | 
				
			||||||
 | 
					      for (auto arg : matmulIterationBlock.getArguments()) {
 | 
				
			||||||
 | 
					        loopMNIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      // Induction variables for the final result.
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> loopBatchMNIVs;
 | 
				
			||||||
 | 
					      for (auto arg : loopBatchIVs) {
 | 
				
			||||||
 | 
					        loopBatchMNIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      for (auto arg : loopMNIVs) {
 | 
				
			||||||
 | 
					        loopBatchMNIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Fill the output with value 0.
 | 
				
			||||||
 | 
					      rewriter.create<StoreOp>(loc, zero, alloc, loopBatchMNIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      //  Iterate along the reduction dimension.
 | 
				
			||||||
 | 
					      //  Use a value from A.
 | 
				
			||||||
 | 
					      std::vector<Value> reduceLoops;
 | 
				
			||||||
 | 
					      std::vector<Value> optimizedReduceLoops;
 | 
				
			||||||
 | 
					      Block *optimizationReduceBlock =
 | 
				
			||||||
 | 
					          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
 | 
				
			||||||
 | 
					      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
 | 
				
			||||||
 | 
					                                        optimizedReduceLoops);
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, reducePack, A, AShape.size() - 1);
 | 
				
			||||||
 | 
					      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // No optimization
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
 | 
				
			||||||
 | 
					      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Insert instructions into the reduction KrnlIterateOp.
 | 
				
			||||||
 | 
					      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToStart(&reduceIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Induction variables
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> loopKIVs, loopBatchMKIVs, loopBatchKNIVs;
 | 
				
			||||||
 | 
					      // K
 | 
				
			||||||
 | 
					      loopKIVs.emplace_back(reduceIterationBlock.getArguments()[0]);
 | 
				
			||||||
 | 
					      // MK
 | 
				
			||||||
 | 
					      if (AShape.size() > 2)
 | 
				
			||||||
 | 
					        for (auto arg : loopBatchIVs)
 | 
				
			||||||
 | 
					          loopBatchMKIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					      if (AShape.size() >= 2)
 | 
				
			||||||
 | 
					        loopBatchMKIVs.emplace_back(loopMNIVs[0]);
 | 
				
			||||||
 | 
					      loopBatchMKIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					      // KN
 | 
				
			||||||
 | 
					      if (BShape.size() > 2)
 | 
				
			||||||
 | 
					        for (auto arg : loopBatchIVs)
 | 
				
			||||||
 | 
					          loopBatchKNIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					      loopBatchKNIVs.emplace_back(loopKIVs[0]);
 | 
				
			||||||
 | 
					      if (BShape.size() >= 2)
 | 
				
			||||||
 | 
					        if (AShape.size() >= 2)
 | 
				
			||||||
 | 
					          loopBatchKNIVs.emplace_back(loopMNIVs[1]);
 | 
				
			||||||
 | 
					        else
 | 
				
			||||||
 | 
					          loopBatchKNIVs.emplace_back(loopMNIVs[0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Matmul computation
 | 
				
			||||||
 | 
					      auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
 | 
				
			||||||
 | 
					      auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
 | 
				
			||||||
 | 
					      auto loadedY = rewriter.create<LoadOp>(loc, alloc, loopBatchMNIVs);
 | 
				
			||||||
 | 
					      if (elementType.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
 | 
				
			||||||
 | 
					        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
 | 
				
			||||||
 | 
					        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
 | 
				
			||||||
 | 
					      } else if (elementType.isa<FloatType>()) {
 | 
				
			||||||
 | 
					        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
 | 
				
			||||||
 | 
					        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
 | 
				
			||||||
 | 
					        rewriter.create<StoreOp>(loc, accumulated, alloc, loopBatchMNIVs);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } else if ((AShape.size() == 1) && (BShape.size() == 1)) {
 | 
				
			||||||
 | 
					      // Case 3:
 | 
				
			||||||
 | 
					      // - Both arguments are 1-D
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Fill the output with value 0.
 | 
				
			||||||
 | 
					      Value zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					      rewriter.create<StoreOp>(loc, zero, alloc, zeroIndex);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      //  Iterate along the reduction dimension.
 | 
				
			||||||
 | 
					      //  Use a value from A.
 | 
				
			||||||
 | 
					      std::vector<Value> reduceLoops;
 | 
				
			||||||
 | 
					      std::vector<Value> optimizedReduceLoops;
 | 
				
			||||||
 | 
					      Block *optimizationReduceBlock =
 | 
				
			||||||
 | 
					          defineLoops(rewriter, loc, reduceLoops, optimizedReduceLoops, 1);
 | 
				
			||||||
 | 
					      KrnlIterateOperandPack reducePack(rewriter, reduceLoops,
 | 
				
			||||||
 | 
					                                        optimizedReduceLoops);
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, reducePack, A, 0);
 | 
				
			||||||
 | 
					      auto reduceIterateOp = rewriter.create<KrnlIterateOp>(loc, reducePack);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // No optimization
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToEnd(optimizationReduceBlock);
 | 
				
			||||||
 | 
					      rewriter.create<KrnlReturnLoopsOp>(loc, reduceLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Insert instructions into the reduction KrnlIterateOp.
 | 
				
			||||||
 | 
					      Block &reduceIterationBlock = reduceIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToStart(&reduceIterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Induction variables
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> loopKIVs;
 | 
				
			||||||
 | 
					      // K
 | 
				
			||||||
 | 
					      loopKIVs.emplace_back(reduceIterationBlock.getArgument(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Matmul computation
 | 
				
			||||||
 | 
					      auto loadedA = rewriter.create<LoadOp>(loc, A, loopKIVs);
 | 
				
			||||||
 | 
					      auto loadedB = rewriter.create<LoadOp>(loc, B, loopKIVs);
 | 
				
			||||||
 | 
					      auto loadedY = rewriter.create<LoadOp>(loc, alloc, zeroIndex);
 | 
				
			||||||
 | 
					      if (elementType.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					        auto AB = rewriter.create<MulIOp>(loc, loadedA, loadedB);
 | 
				
			||||||
 | 
					        auto accumulated = rewriter.create<AddIOp>(loc, loadedY, AB);
 | 
				
			||||||
 | 
					        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
 | 
				
			||||||
 | 
					      } else if (elementType.isa<FloatType>()) {
 | 
				
			||||||
 | 
					        auto AB = rewriter.create<MulFOp>(loc, loadedA, loadedB);
 | 
				
			||||||
 | 
					        auto accumulated = rewriter.create<AddFOp>(loc, loadedY, AB);
 | 
				
			||||||
 | 
					        rewriter.create<StoreOp>(loc, accumulated, alloc, zeroIndex);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // No scalar matrix multiplication.
 | 
				
			||||||
 | 
					      llvm_unreachable("Unsupported scalar matrix multiplication.");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXMatMulOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXMatMulOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,307 @@
 | 
				
			||||||
 | 
					//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Reduction Operators to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Identity values
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					float getIdentityValue<float, ONNXReduceMaxOp>(){
 | 
				
			||||||
 | 
					  return (float)-std::numeric_limits<float>::infinity();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					int getIdentityValue<int, ONNXReduceMaxOp>(){
 | 
				
			||||||
 | 
					  return std::numeric_limits<int>::min();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					float getIdentityValue<float, ONNXReduceMinOp>(){
 | 
				
			||||||
 | 
					  return (float)std::numeric_limits<float>::infinity();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					int getIdentityValue<int, ONNXReduceMinOp>(){
 | 
				
			||||||
 | 
					  return std::numeric_limits<int>::max();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					float getIdentityValue<float, ONNXReduceProdOp>(){
 | 
				
			||||||
 | 
					  return (float)1.0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					int getIdentityValue<int, ONNXReduceProdOp>(){
 | 
				
			||||||
 | 
					  return 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					float getIdentityValue<float, ONNXReduceSumOp>(){
 | 
				
			||||||
 | 
					  return (float)0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					int getIdentityValue<int, ONNXReduceSumOp>(){
 | 
				
			||||||
 | 
					  return 0;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Scalar ops
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					struct ScalarOp<ONNXReduceProdOp> {
 | 
				
			||||||
 | 
					  using FOp = MulFOp;
 | 
				
			||||||
 | 
					  using IOp = MulIOp;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					struct ScalarOp<ONNXReduceSumOp> {
 | 
				
			||||||
 | 
					  using FOp = AddFOp;
 | 
				
			||||||
 | 
					  using IOp = AddIOp;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXReduceMaxOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXReduceMaxOp>(Operation *op,
 | 
				
			||||||
 | 
					                                          ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                          ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                          ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value lhs = operands[0];
 | 
				
			||||||
 | 
					  Value rhs = operands[1];
 | 
				
			||||||
 | 
					  Type element_type = lhs.getType();
 | 
				
			||||||
 | 
					  if (element_type.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					    auto max = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt, lhs, rhs);
 | 
				
			||||||
 | 
					    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else if (element_type.isa<FloatType>()) {
 | 
				
			||||||
 | 
					    auto max = rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, lhs, rhs);
 | 
				
			||||||
 | 
					    auto result = rewriter.create<SelectOp>(loc, max, lhs, rhs);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    return nullptr;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Scalar unary ops for lowering ONNXReduceMinOp
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp<ONNXReduceMinOp>(Operation *op,
 | 
				
			||||||
 | 
					                                          ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                                          ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                                          ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Value lhs = operands[0];
 | 
				
			||||||
 | 
					  Value rhs = operands[1];
 | 
				
			||||||
 | 
					  Type element_type = lhs.getType();
 | 
				
			||||||
 | 
					  if (element_type.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					    auto min = rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, lhs, rhs);
 | 
				
			||||||
 | 
					    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else if (element_type.isa<FloatType>()) {
 | 
				
			||||||
 | 
					    auto min = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT, lhs, rhs);
 | 
				
			||||||
 | 
					    auto result = rewriter.create<SelectOp>(loc, min, lhs, rhs);
 | 
				
			||||||
 | 
					    return result;
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    return nullptr;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename ONNXReductionOp>
 | 
				
			||||||
 | 
					struct ONNXReductionOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXReductionOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(ONNXReductionOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    /*
 | 
				
			||||||
 | 
					     * Condition: reduction function must be associative and commutative.
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Example 1 (here, reduction function is `+`):
 | 
				
			||||||
 | 
					     * Induction variables: (i0, i1, i2)
 | 
				
			||||||
 | 
					     * axes = [0, 2]
 | 
				
			||||||
 | 
					     * keepdims = true
 | 
				
			||||||
 | 
					     * krnl.iterate() with (i0, i1, i2) {
 | 
				
			||||||
 | 
					     *   Y(0, i1, 0) += X(i0, i1, i2)
 | 
				
			||||||
 | 
					     * }
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					     * Example 2 (here, reduction function is `+`):
 | 
				
			||||||
 | 
					     * Induction variables: (i0, i1, i2)
 | 
				
			||||||
 | 
					     * axes = [0, 2]
 | 
				
			||||||
 | 
					     * keepdims = false
 | 
				
			||||||
 | 
					     * krnl.iterate() with (i0, i1, i2) {
 | 
				
			||||||
 | 
					     *   Y(i1) += X(i0, i1, i2)
 | 
				
			||||||
 | 
					     * }
 | 
				
			||||||
 | 
					     *
 | 
				
			||||||
 | 
					    */
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					    auto memRefInType = operands[0].getType().cast<MemRefType>();
 | 
				
			||||||
 | 
					    auto memRefInShape = memRefInType.getShape();
 | 
				
			||||||
 | 
					    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    int64_t inRank = memRefInType.getRank();
 | 
				
			||||||
 | 
					    int64_t outRank = tensorOutType.getRank();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get attributes
 | 
				
			||||||
 | 
					    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
 | 
				
			||||||
 | 
					    std::vector<int64_t> axes;
 | 
				
			||||||
 | 
					    if (axisAttrs) {
 | 
				
			||||||
 | 
					      for (auto axisAttr : axisAttrs.getValue()) {
 | 
				
			||||||
 | 
					        int64_t axis = axisAttr.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					        axis = axis >= 0 ? axis : (inRank + axis);
 | 
				
			||||||
 | 
					        assert(axis >= -inRank && axis <= inRank - 1);
 | 
				
			||||||
 | 
					        if (std::find(axes.begin(), axes.end(), axis) == axes.end())
 | 
				
			||||||
 | 
					          axes.push_back(axis);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      for (decltype(inRank) i = 0; i < inRank; ++i) {
 | 
				
			||||||
 | 
					        axes.push_back(i);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    // KeepDims
 | 
				
			||||||
 | 
					    auto keepdims =
 | 
				
			||||||
 | 
					        llvm::dyn_cast<ONNXReductionOp>(op).keepdims();
 | 
				
			||||||
 | 
					    bool isKeepdims = (keepdims == 1) ? true : false;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Get type information
 | 
				
			||||||
 | 
					    auto memRefOutType = convertTensorToMemRef(tensorOutType);
 | 
				
			||||||
 | 
					    auto memRefOutShape = memRefOutType.getShape();
 | 
				
			||||||
 | 
					    auto elementOutType = memRefOutType.getElementType();
 | 
				
			||||||
 | 
					    std::map<int64_t, int64_t> outInDimMap =
 | 
				
			||||||
 | 
					        getReductionMapping(memRefInType, axes, isKeepdims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefOutType)) {
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefOutType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      SmallVector<Value, 2> allocOperands;
 | 
				
			||||||
 | 
					      for (decltype(outRank) i = 0; i < outRank; ++i) {
 | 
				
			||||||
 | 
					        if (memRefOutShape[i] < 0) {
 | 
				
			||||||
 | 
					          auto dim = rewriter.create<DimOp>(loc, operands[0], outInDimMap[i]);
 | 
				
			||||||
 | 
					          allocOperands.push_back(dim);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      alloc = rewriter.create<AllocOp>(loc, memRefOutType, allocOperands);
 | 
				
			||||||
 | 
					      if (insertDealloc) {
 | 
				
			||||||
 | 
					        auto *parentBlock = alloc.getDefiningOp()->getBlock();
 | 
				
			||||||
 | 
					        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
				
			||||||
 | 
					        dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // There are two Krnl loops:
 | 
				
			||||||
 | 
					    // - One to initialize the result memref, and
 | 
				
			||||||
 | 
					    // - One to do reduction
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Define loops to initialize the result.
 | 
				
			||||||
 | 
					    std::vector<Value> originalLoopsInit;
 | 
				
			||||||
 | 
					    std::vector<Value> optimizedLoopsInit;
 | 
				
			||||||
 | 
					    Block *optimizationBlockInit = defineLoops(rewriter, loc, originalLoopsInit,
 | 
				
			||||||
 | 
					            optimizedLoopsInit, outRank);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Iteration information
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack packInit(rewriter, originalLoopsInit,
 | 
				
			||||||
 | 
					        optimizedLoopsInit);
 | 
				
			||||||
 | 
					    for (decltype(outRank) i = 0; i < outRank; ++i) {
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, packInit, alloc, i);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    auto iterateOpInit = rewriter.create<KrnlIterateOp>(loc, packInit);
 | 
				
			||||||
 | 
					    Block &iterationBlockInit = iterateOpInit.bodyRegion().front();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Perform the insertions into the body of the initialization loop.
 | 
				
			||||||
 | 
					    // No optimization
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToEnd(optimizationBlockInit);
 | 
				
			||||||
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoopsInit);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert instructions inside the KernelIterateOp body.
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&iterationBlockInit);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Handle the operation:
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> loopIVs;
 | 
				
			||||||
 | 
					    for (auto arg : iterationBlockInit.getArguments()) {
 | 
				
			||||||
 | 
					      loopIVs.push_back(arg);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Value identity;
 | 
				
			||||||
 | 
					    if (elementOutType.isa<FloatType>()) {
 | 
				
			||||||
 | 
					      identity = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, FloatAttr::get(elementOutType,
 | 
				
			||||||
 | 
					                              getIdentityValue<float, ONNXReductionOp>()));
 | 
				
			||||||
 | 
					    } else if (elementOutType.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					      identity = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, IntegerAttr::get(elementOutType,
 | 
				
			||||||
 | 
					                                getIdentityValue<int, ONNXReductionOp>()));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, identity, alloc, loopIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Define an Krnl loop to do reduction.
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointAfter(iterateOpInit);
 | 
				
			||||||
 | 
					    std::vector<Value> originalLoops, optimizedLoops;
 | 
				
			||||||
 | 
					    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
 | 
				
			||||||
 | 
					            optimizedLoops, inRank);
 | 
				
			||||||
 | 
					    // Iteration information
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
				
			||||||
 | 
					    for (decltype(inRank) i = 0; i < inRank; ++i) {
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, pack, operands[0], i);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
				
			||||||
 | 
					    Block &iterationBlock = iterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Perform the insertions into the body of the reduction loop.
 | 
				
			||||||
 | 
					    // No optimization
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert instructions inside the KernelIterateOp body.
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&iterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Handle the operation:
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> inLoopIVs, outLoopIVs;
 | 
				
			||||||
 | 
					    auto args = iterationBlock.getArguments();
 | 
				
			||||||
 | 
					    for (int i = 0; i < args.size(); ++i) {
 | 
				
			||||||
 | 
					      inLoopIVs.push_back(args[i]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    Value zeroIndex = nullptr;
 | 
				
			||||||
 | 
					    for (decltype(inRank) i = 0; i < outRank; ++i) {
 | 
				
			||||||
 | 
					      if (outInDimMap.find(i) != outInDimMap.end()) {
 | 
				
			||||||
 | 
					        outLoopIVs.push_back(inLoopIVs[outInDimMap[i]]);
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        if (zeroIndex) {
 | 
				
			||||||
 | 
					          outLoopIVs.push_back(zeroIndex);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          zeroIndex = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					          outLoopIVs.push_back(zeroIndex);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Value next, accumulated;
 | 
				
			||||||
 | 
					    next = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
 | 
				
			||||||
 | 
					    accumulated = rewriter.create<LoadOp>(loc, alloc, outLoopIVs);
 | 
				
			||||||
 | 
					    accumulated = mapToLowerScalarOp<ONNXReductionOp>(
 | 
				
			||||||
 | 
					        op, memRefOutType.getElementType(), {accumulated, next}, rewriter);
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, accumulated, alloc, outLoopIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXReductionOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXReductionOpLowering<mlir::ONNXReduceMaxOp>,
 | 
				
			||||||
 | 
					                  ONNXReductionOpLowering<mlir::ONNXReduceMinOp>,
 | 
				
			||||||
 | 
					                  ONNXReductionOpLowering<mlir::ONNXReduceProdOp>,
 | 
				
			||||||
 | 
					                  ONNXReductionOpLowering<mlir::ONNXReduceSumOp>>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,205 @@
 | 
				
			||||||
 | 
					//===----- softmax.inc - Softmax Op ---------------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers ONNX softmax operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
// Lowers ONNX Softmax to Krnl loops. The input is treated as a 2-D view
// coerced at `axis`: an outer loop nest over dims [0, axis) and, inside it,
// three sequential inner loop nests over dims [axis, rank) that compute the
// running max, the sum of exponentials, and the final normalized values.
struct ONNXSoftmaxOpLowering : public ConversionPattern {
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    // softmax(x) = let max_x = max(x) in
    //                let exp_x = exp(x - max_x) in
    //                  let sum = sum(exp_x) in
    //                    exp_x / sum
    // Subtracting max_x first keeps exp() numerically stable.
    auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
    int64_t rank = tensorType.getRank();
    int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
    // Negative axis counts from the back, per the ONNX specification.
    axis = axis >= 0 ? axis : rank + axis;
    // NOTE(review): after normalization `axis >= 0` would be the tighter
    // lower bound; the check below mirrors the reduction lowering.
    assert(axis >= -rank && axis <= rank - 1);

    auto loc = op->getLoc();

    // Insert an allocation and deallocation for the result of this operation.
    auto memRefType = convertTensorToMemRef(tensorType);
    auto elementType = memRefType.getElementType();

    Value alloc;
    bool insertDealloc = checkInsertDealloc(op);
    if (hasAllConstantDimensions(memRefType))
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
    else
      // Dynamic dimensions are taken from the input operand.
      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
                                    operands[0]);

    // Shape of the result.
    // NOTE(review): memRefShape is currently unused in this function.
    auto memRefShape = memRefType.getShape();

    // Insert allocations and deallocations for the scalar (rank-0) sum and
    // max accumulators, plus the constants used to reset them.
    MemRefType scalarMemRefType = MemRefType::get({}, elementType, {}, 0);
    Value sumOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value maxOp = insertAllocAndDealloc(scalarMemRefType, loc, rewriter, true);
    Value zero =
        rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0));
    Value negInfinity = rewriter.create<ConstantOp>(
        loc,
        FloatAttr::get(elementType, -std::numeric_limits<float>::infinity()));

    // Define loops covering the full rank of the input.
    std::vector<Value> originalLoops;
    std::vector<Value> optimizedLoops;
    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
            optimizedLoops, rank);

    // Coerce the input into a 2-D tensor. `axis` will be the coercing point.
    // This coercing follows the softmax definition in ONNX:
    // https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
    // Here, we create an outer loop and inner loop for handling the two
    // dimensions. The outer loop is only created once `axis` is not zero.

    // Define an outer loop with respect to axis: dims [0, axis).
    std::vector<Value> outerLoops, optimizedOuterLoops;
    outerLoops.reserve(axis);
    optimizedOuterLoops.reserve(axis);
    for (int i = 0; i < axis; ++i) {
      outerLoops.push_back(originalLoops[i]);
      optimizedOuterLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack outerPack(rewriter, outerLoops, optimizedOuterLoops);
    for (int i = 0; i < axis; ++i)
      addDimensionToPack(rewriter, loc, outerPack, operands[0], i);

    // Define an inner loop with respect to axis: dims [axis, rank).
    std::vector<Value> innerLoops, optimizedInnerLoops;
    innerLoops.reserve(rank - axis);
    optimizedInnerLoops.reserve(rank - axis);
    for (int i = axis; i < rank; ++i) {
      innerLoops.push_back(originalLoops[i]);
      optimizedInnerLoops.push_back(optimizedLoops[i]);
    }
    KrnlIterateOperandPack innerPack(rewriter, innerLoops, optimizedInnerLoops);
    for (int i = axis; i < rank; ++i)
      addDimensionToPack(rewriter, loc, innerPack, operands[0], i);

    KrnlIterateOp outerIterateOp, maxIterateOp, sumIterateOp, softmaxIterateOp;
    SmallVector<Value, 4> outerLoopIVs;
    if (axis != 0) {
      // With an outer loop, the accumulator resets and the three inner
      // iterate ops live inside the outer loop's body.
      outerIterateOp = rewriter.create<KrnlIterateOp>(loc, outerPack);

      // No optimization: return the original loops unchanged.
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);

      // Insert instructions inside the outer loop.
      Block &outerIterationBlock = outerIterateOp.bodyRegion().front();
      rewriter.setInsertionPointToStart(&outerIterationBlock);
      for (auto arg : outerIterationBlock.getArguments())
        outerLoopIVs.push_back(arg);

      // Reset accumulators for this outer iteration.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
    } else {
      // No outer loop: the accumulators are reset once, and the three inner
      // iterate ops are emitted at the top level.
      // Reset accumulators.
      rewriter.create<StoreOp>(loc, zero, sumOp);
      rewriter.create<StoreOp>(loc, negInfinity, maxOp);

      // Create an inner loop to compute max.
      maxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute sum.
      sumIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);
      // Create an inner loop to compute softmax.
      softmaxIterateOp = rewriter.create<KrnlIterateOp>(loc, innerPack);

      // No optimization: return the original loops unchanged.
      rewriter.setInsertionPointToEnd(optimizationBlock);
      rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
    }

    // Insert instructions inside the max loop.
    Block &maxIterationBlock = maxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&maxIterationBlock);

    // Get induction variables: outer IVs followed by this loop's inner IVs.
    SmallVector<Value, 4> maxLoopIVs;
    for (auto arg : outerLoopIVs)
      maxLoopIVs.push_back(arg);
    for (auto arg : maxIterationBlock.getArguments())
      maxLoopIVs.push_back(arg);

    // Compute the running max value: max = max(max, x[ivs]).
    Value max = rewriter.create<LoadOp>(loc, maxOp);
    Value nextMax = rewriter.create<LoadOp>(loc, operands[0], maxLoopIVs);
    auto maxCond =
        rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, max, nextMax);
    max = rewriter.create<SelectOp>(loc, maxCond, max, nextMax);
    rewriter.create<StoreOp>(loc, max, maxOp);

    // Reload the finished max before the sum loop runs.
    rewriter.setInsertionPoint(sumIterateOp);
    max = rewriter.create<LoadOp>(loc, maxOp);

    // Insert instructions inside the sum loop.
    Block &sumIterationBlock = sumIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&sumIterationBlock);

    // Get induction variables: outer IVs followed by this loop's inner IVs.
    SmallVector<Value, 4> sumLoopIVs;
    for (auto arg : outerLoopIVs)
      sumLoopIVs.push_back(arg);
    for (auto arg : sumIterationBlock.getArguments())
      sumLoopIVs.push_back(arg);

    // Sum up values: sum += exp(x[ivs] - max).
    Value sum = rewriter.create<LoadOp>(loc, sumOp);
    Value next = rewriter.create<LoadOp>(loc, operands[0], sumLoopIVs);
    Value sub = rewriter.create<SubFOp>(loc, next, max);
    Value exp = rewriter.create<ExpOp>(loc, sub);
    sum = rewriter.create<AddFOp>(loc, sum, exp);
    rewriter.create<StoreOp>(loc, sum, sumOp);
    // Store intermediate values in the result to avoid recomputation.
    rewriter.create<StoreOp>(loc, exp, alloc, sumLoopIVs);

    // Reload the finished sum before the softmax loop runs.
    rewriter.setInsertionPoint(softmaxIterateOp);
    sum = rewriter.create<LoadOp>(loc, sumOp);

    // Insert instructions inside the softmax loop.
    Block &softmaxIterationBlock = softmaxIterateOp.bodyRegion().front();
    rewriter.setInsertionPointToStart(&softmaxIterationBlock);

    // Get induction variables: outer IVs followed by this loop's inner IVs.
    SmallVector<Value, 4> softmaxLoopIVs;
    for (auto arg : outerLoopIVs)
      softmaxLoopIVs.push_back(arg);
    for (auto arg : softmaxIterationBlock.getArguments())
      softmaxLoopIVs.push_back(arg);

    // Compute softmax: result = exp_x / sum, reusing the stored exp values.
    Value expLoadedVal = rewriter.create<LoadOp>(loc, alloc, softmaxLoopIVs);
    Value result = rewriter.create<DivFOp>(loc, expLoadedVal, sum);
    rewriter.create<StoreOp>(loc, result, alloc, softmaxLoopIVs);

    rewriter.replaceOp(op, alloc);

    return matchSuccess();
  }
};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
/// Registers the lowering pattern for the ONNX Softmax operation.
void populateLoweringONNXSoftmaxOpPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx) {
  patterns.insert<ONNXSoftmaxOpLowering>(ctx);
}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,282 @@
 | 
				
			||||||
 | 
					//===----- conv.inc - Lowering Convolution Op -----------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Convolution Operators to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXConvNoBiasOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType))
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
 | 
				
			||||||
 | 
					                                    {operands[0]});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto resultShape = memRefType.getShape();
 | 
				
			||||||
 | 
					    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    auto kernelShape = operands[1].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // R = ConvNoBias(D, K)
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // The input/output shapes will look like this:
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // D (NxCxHxW) x K (MxC/groupxKHxKW) -> R (NxMxRHxRW)
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // M is a multiple of the number of groups:
 | 
				
			||||||
 | 
					    //   M = group * kernelsPerGroup
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // The loop nest will look as follows:
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // strides = [s1, s2]
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // kernelsPerGroup = M / group;
 | 
				
			||||||
 | 
					    // for n = 0 .. N:
 | 
				
			||||||
 | 
					    //   for g = 0 .. group:
 | 
				
			||||||
 | 
					    //     for m = 0 .. kernelsPerGroup:
 | 
				
			||||||
 | 
					    //       kernel = g * kernelsPerGroup + m;
 | 
				
			||||||
 | 
					    //       for r1 = 0 .. RH:
 | 
				
			||||||
 | 
					    //         for r2 = 0 .. RW:
 | 
				
			||||||
 | 
					    //           R[n][kernel][r1][r2] = 0;
 | 
				
			||||||
 | 
					    //           for c = 0 .. C/group:
 | 
				
			||||||
 | 
					    //             for k1 = 0 .. KH:
 | 
				
			||||||
 | 
					    //               for k2 = 0 .. KW:
 | 
				
			||||||
 | 
					    //                 R[n][kernel][r1][r2] =
 | 
				
			||||||
 | 
					    //                   D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
 | 
				
			||||||
 | 
					    //                   K[kernel][c][k1][k2];
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // Naming:
 | 
				
			||||||
 | 
					    //   n, g, m: outer loop nest indices
 | 
				
			||||||
 | 
					    //   r1, r2: spatial loop nest indices
 | 
				
			||||||
 | 
					    //   c, k1, k2: inner loop nest indices
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // TODO: handle padding.
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // In the general case:
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // D (NxCxD1xD2x...xDdim) x K (MxC/groupxK1xK2x...xKdim)
 | 
				
			||||||
 | 
					    //     -> R (NxMxR1xR2x...xRdim)
 | 
				
			||||||
 | 
					    //
 | 
				
			||||||
 | 
					    // The above loop nest can be adapted by increasing the number
 | 
				
			||||||
 | 
					    // of r- and k-index loop i.e. r1 r2 and k1 k2 loops.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Set up outermost loops: n g m r1 r2 ... rdim
 | 
				
			||||||
 | 
					    // Skip g if group is 1.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Before we start the iteration we need to compute the number of
 | 
				
			||||||
 | 
					    // unsplit kernels and fetch the number of groups from the attribute
 | 
				
			||||||
 | 
					    // list. Group is always a compilation constant.
 | 
				
			||||||
 | 
					    int64_t group = convOp.group().getSExtValue();
 | 
				
			||||||
 | 
					    // Compute the number of unsplit kernels. The number of kernels
 | 
				
			||||||
 | 
					    // must be a multiple of the number of groups.
 | 
				
			||||||
 | 
					    int64_t kernelsPerGroup = floor(kernelShape[0] / group);
 | 
				
			||||||
 | 
					    auto kernelsPerGroupValue =
 | 
				
			||||||
 | 
					        rewriter.create<ConstantIndexOp>(loc, kernelsPerGroup);
 | 
				
			||||||
 | 
					    auto zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					        loc, FloatAttr::get(memRefType.getElementType(), 0));
 | 
				
			||||||
 | 
					    Value subchannels;
 | 
				
			||||||
 | 
					    if (kernelShape[1] < 0) {
 | 
				
			||||||
 | 
					      subchannels =
 | 
				
			||||||
 | 
					          rewriter.create<DimOp>(loc, operands[1], 1).getResult();
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      subchannels = rewriter.create<ConstantIndexOp>(
 | 
				
			||||||
 | 
					          loc, kernelShape[1]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // 1. Define outer loops and emit empty optimization block:
 | 
				
			||||||
 | 
					    int64_t nOuterLoops = (group > 1) ? 3 : 2;
 | 
				
			||||||
 | 
					    std::vector<Value> outerLoops;
 | 
				
			||||||
 | 
					    std::vector<Value> optimizedOuterLoops;
 | 
				
			||||||
 | 
					    Block *optimizationBlock = defineLoops(rewriter, loc, outerLoops,
 | 
				
			||||||
 | 
					        optimizedOuterLoops, nOuterLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Prepare iteration arguments over outer loop nest.
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack pack(
 | 
				
			||||||
 | 
					        rewriter, outerLoops, optimizedOuterLoops);
 | 
				
			||||||
 | 
					    //   for n = 0 .. N:
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    if (inputShape[0] < 0)
 | 
				
			||||||
 | 
					      pack.pushOperandBound(
 | 
				
			||||||
 | 
					          rewriter.create<DimOp>(loc, operands[0], 0).getResult());
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      pack.pushConstantBound(inputShape[0]);
 | 
				
			||||||
 | 
					    //   for g = 0 .. N:
 | 
				
			||||||
 | 
					    if (group > 1) {
 | 
				
			||||||
 | 
					      pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					      pack.pushConstantBound(group);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    //   for m = 0 .. kernelsPerGroup:
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    pack.pushConstantBound(kernelsPerGroup);
 | 
				
			||||||
 | 
					    // Outer loop iteration.
 | 
				
			||||||
 | 
					    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
				
			||||||
 | 
					    Block &outerIterationBlock = iterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					    // Emit optimizations for outer loops:
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, outerLoops);
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&outerIterationBlock);
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					      // 2. Emit the body of the outer loop nest.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // 2.1 Compute kernel order number: kernel = g * kernelsPerGroup + m;
 | 
				
			||||||
 | 
					      // If group is not set then the value of the kernel ID is
 | 
				
			||||||
 | 
					      // identical to that of the loop over kernels.
 | 
				
			||||||
 | 
					      Value kernel = outerIterationBlock.getArguments()[1];
 | 
				
			||||||
 | 
					      if (group > 1) {
 | 
				
			||||||
 | 
					        // Middle loop is over groups and third loop is over the
 | 
				
			||||||
 | 
					        // kernel identifiers in the current group.
 | 
				
			||||||
 | 
					        auto kernelsOffset = rewriter.create<MulIOp>(loc,
 | 
				
			||||||
 | 
					            outerIterationBlock.getArguments()[1],
 | 
				
			||||||
 | 
					            kernelsPerGroupValue);
 | 
				
			||||||
 | 
					        kernel = rewriter.create<AddIOp>(loc, kernelsOffset,
 | 
				
			||||||
 | 
					            outerIterationBlock.getArguments()[2]);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // 2.2 Define spatial loops
 | 
				
			||||||
 | 
					      int64_t nSpatialLoops = resultShape.size() - 2;
 | 
				
			||||||
 | 
					      std::vector<Value> spatialLoops;
 | 
				
			||||||
 | 
					      std::vector<Value> optimizedSpatialLoops;
 | 
				
			||||||
 | 
					      Block *optSpatialLoopBlock = defineLoops(rewriter, loc, spatialLoops,
 | 
				
			||||||
 | 
					        optimizedSpatialLoops, nSpatialLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // 2.3 Prepare iteration arguments for spatial loop nest.
 | 
				
			||||||
 | 
					      KrnlIterateOperandPack spatialPack(
 | 
				
			||||||
 | 
					        rewriter, spatialLoops, optimizedSpatialLoops);
 | 
				
			||||||
 | 
					      for (int i = 2; i < resultShape.size(); ++i)
 | 
				
			||||||
 | 
					        addDimensionToPack(rewriter, loc, spatialPack, alloc, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // 2.4 Emit loop nest over output spatial dimensions.
 | 
				
			||||||
 | 
					      //   for rX = 0 .. RX
 | 
				
			||||||
 | 
					      auto spatialIterateOp =
 | 
				
			||||||
 | 
					          rewriter.create<KrnlIterateOp>(loc, spatialPack);
 | 
				
			||||||
 | 
					      Block &spatialIterationBlock = spatialIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					      // 2.5 Emit optimizations for outer loops:
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToEnd(optSpatialLoopBlock);
 | 
				
			||||||
 | 
					      rewriter.create<KrnlReturnLoopsOp>(loc, spatialLoops);
 | 
				
			||||||
 | 
					      rewriter.setInsertionPointToStart(&spatialIterationBlock);
 | 
				
			||||||
 | 
					      {
 | 
				
			||||||
 | 
					        // 3. Emit the body of the spatial loop nest.
 | 
				
			||||||
 | 
					        // 3.1 Emit: R[n][kernel][r1][r2] = 0;
 | 
				
			||||||
 | 
					        SmallVector<Value, 4> resultIndices;
 | 
				
			||||||
 | 
					        // n
 | 
				
			||||||
 | 
					        resultIndices.emplace_back(outerIterationBlock.getArguments()[0]);
 | 
				
			||||||
 | 
					        // kernel
 | 
				
			||||||
 | 
					        resultIndices.emplace_back(kernel);
 | 
				
			||||||
 | 
					        // rX
 | 
				
			||||||
 | 
					        for (auto arg : spatialIterationBlock.getArguments())
 | 
				
			||||||
 | 
					          resultIndices.emplace_back(arg);
 | 
				
			||||||
 | 
					        // Store initializer value into output location.
 | 
				
			||||||
 | 
					        rewriter.create<StoreOp>(loc, zero, alloc, resultIndices);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // 3.2 Define inner loops.
 | 
				
			||||||
 | 
					        int64_t nInnerLoops = 1 + (kernelShape.size() - 2);
 | 
				
			||||||
 | 
					        std::vector<Value> innerLoops;
 | 
				
			||||||
 | 
					        std::vector<Value> optimizedInnerLoops;
 | 
				
			||||||
 | 
					        Block *optInnerLoopBlock = defineLoops(rewriter, loc, innerLoops,
 | 
				
			||||||
 | 
					            optimizedInnerLoops, nInnerLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // 3.3 Prepare iteration arguments for inner loop nest.
 | 
				
			||||||
 | 
					        KrnlIterateOperandPack innerPack(
 | 
				
			||||||
 | 
					            rewriter, innerLoops, optimizedInnerLoops);
 | 
				
			||||||
 | 
					        //   for c = 0 .. C/group
 | 
				
			||||||
 | 
					        innerPack.pushConstantBound(0);
 | 
				
			||||||
 | 
					        innerPack.pushConstantBound(kernelShape[1]);
 | 
				
			||||||
 | 
					        //   for Kx = 0 .. KX
 | 
				
			||||||
 | 
					        for (int i = 2; i < kernelShape.size(); ++i)
 | 
				
			||||||
 | 
					          addDimensionToPack(rewriter, loc, innerPack, operands[1], i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // 3.4 Emit inner loop nest.
 | 
				
			||||||
 | 
					        auto innerIterateOp =
 | 
				
			||||||
 | 
					            rewriter.create<KrnlIterateOp>(loc, innerPack);
 | 
				
			||||||
 | 
					        Block &innerIterationBlock = innerIterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					        // 3.5 Emit optimizations for outer loops:
 | 
				
			||||||
 | 
					        rewriter.setInsertionPointToEnd(optInnerLoopBlock);
 | 
				
			||||||
 | 
					        rewriter.create<KrnlReturnLoopsOp>(loc, innerLoops);
 | 
				
			||||||
 | 
					        rewriter.setInsertionPointToStart(&innerIterationBlock);
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					          // 4. Emit inner loop body
 | 
				
			||||||
 | 
					          // R[n][kernel][r1][r2] =
 | 
				
			||||||
 | 
					          //   D[n][g * (C / group) + c][s1 * r1 + k1][s2 * r2 + k2] *
 | 
				
			||||||
 | 
					          //   K[kernel][c][k1][k2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // 4.1 Prepare indices for accesing the data tensor.
 | 
				
			||||||
 | 
					          SmallVector<Value, 4> dataIndices;
 | 
				
			||||||
 | 
					          // n
 | 
				
			||||||
 | 
					          dataIndices.emplace_back(outerIterationBlock.getArguments()[0]);
 | 
				
			||||||
 | 
					          // g * (C / group) + c
 | 
				
			||||||
 | 
					          Value channelDepth = innerIterationBlock.getArguments()[0];
 | 
				
			||||||
 | 
					          if (group > 1)
 | 
				
			||||||
 | 
					            channelDepth = rewriter.create<AddIOp>(loc, channelDepth,
 | 
				
			||||||
 | 
					                rewriter.create<MulIOp>(loc, subchannels,
 | 
				
			||||||
 | 
					                    outerIterationBlock.getArguments()[1]));
 | 
				
			||||||
 | 
					          dataIndices.emplace_back(channelDepth);
 | 
				
			||||||
 | 
					          // sX * rX + kX
 | 
				
			||||||
 | 
					          auto stridesAttribute = convOp.stridesAttr();
 | 
				
			||||||
 | 
					          // Read strides attribute
 | 
				
			||||||
 | 
					          SmallVector<int, 4> strides;
 | 
				
			||||||
 | 
					          if (stridesAttribute)
 | 
				
			||||||
 | 
					            for (auto stride : stridesAttribute.getValue())
 | 
				
			||||||
 | 
					              strides.emplace_back(stride.cast<IntegerAttr>().getInt());
 | 
				
			||||||
 | 
					          for (int i = 0; i < kernelShape.size() - 2; ++i) {
 | 
				
			||||||
 | 
					            Value spatialIndex = spatialIterationBlock.getArguments()[i];
 | 
				
			||||||
 | 
					            // If strides are present then emit the correct access index.
 | 
				
			||||||
 | 
					            if (stridesAttribute && strides[i] > 1)
 | 
				
			||||||
 | 
					              spatialIndex = rewriter.create<MulIOp>(loc,
 | 
				
			||||||
 | 
					                  rewriter.create<ConstantIndexOp>(loc, strides[i]),
 | 
				
			||||||
 | 
					                  spatialIterationBlock.getArguments()[i]);
 | 
				
			||||||
 | 
					            dataIndices.emplace_back(
 | 
				
			||||||
 | 
					                rewriter.create<AddIOp>(loc, spatialIndex,
 | 
				
			||||||
 | 
					                    innerIterationBlock.getArguments()[i+1]));
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // 4.2 Prepare indices for accessing the kernel tensor.
 | 
				
			||||||
 | 
					          SmallVector<Value, 4> kernelIndices;
 | 
				
			||||||
 | 
					          // kernel
 | 
				
			||||||
 | 
					          kernelIndices.emplace_back(kernel);
 | 
				
			||||||
 | 
					          // c
 | 
				
			||||||
 | 
					          kernelIndices.emplace_back(innerIterationBlock.getArguments()[0]);
 | 
				
			||||||
 | 
					          // kX
 | 
				
			||||||
 | 
					          for (int i = 0; i < kernelShape.size() - 2; ++i)
 | 
				
			||||||
 | 
					            kernelIndices.emplace_back(
 | 
				
			||||||
 | 
					                innerIterationBlock.getArguments()[i+1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // 4.3 Compute convolution.
 | 
				
			||||||
 | 
					          auto loadData =
 | 
				
			||||||
 | 
					              rewriter.create<LoadOp>(loc, operands[0], dataIndices);
 | 
				
			||||||
 | 
					          auto loadKernel =
 | 
				
			||||||
 | 
					              rewriter.create<LoadOp>(loc, operands[1], kernelIndices);
 | 
				
			||||||
 | 
					          auto loadPartialSum =
 | 
				
			||||||
 | 
					              rewriter.create<LoadOp>(loc, alloc, resultIndices);
 | 
				
			||||||
 | 
					          Value result = rewriter.create<AddFOp>(loc, loadPartialSum,
 | 
				
			||||||
 | 
					              rewriter.create<MulFOp>(loc, loadData, loadKernel));
 | 
				
			||||||
 | 
					          // 4.4 Store computed value into output location.
 | 
				
			||||||
 | 
					          rewriter.create<StoreOp>(loc, result, alloc, resultIndices);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXConvOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXConvNoBiasOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,26 @@
 | 
				
			||||||
 | 
					//===----- identity.inc - Lowering Identity Op ----------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Identity Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXIdentityOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXIdentityOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, operands[0]);
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXIdentityOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXIdentityOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,151 @@
 | 
				
			||||||
 | 
					//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Reshape Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXReshapeOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXReshapeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Compute size in bytes using the input tensor.
 | 
				
			||||||
 | 
					    Value tensorSize = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                     getMemRefEltSizeInBytes(memRefType)));
 | 
				
			||||||
 | 
					    for (int i = 0; i < inputShape.size(); ++i) {
 | 
				
			||||||
 | 
					      Value dimVal;
 | 
				
			||||||
 | 
					      if (inputShape[i] < 0) {
 | 
				
			||||||
 | 
					        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
 | 
				
			||||||
 | 
					        dimVal =
 | 
				
			||||||
 | 
					            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
 | 
				
			||||||
 | 
					      } else {
 | 
				
			||||||
 | 
					        dimVal = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                         inputShape[i]));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType)) {
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // If a dimension is zero, the actual dimension value is taken from the
 | 
				
			||||||
 | 
					      // input tensor.
 | 
				
			||||||
 | 
					      //
 | 
				
			||||||
 | 
					      // If the shape array has a negative dimension (-1), we compute its actual
 | 
				
			||||||
 | 
					      // dimension value from the other dimensions. But we don't have enough
 | 
				
			||||||
 | 
					      // information about the other dimensions at this point. So, we need to
 | 
				
			||||||
 | 
					      // scan the shape first to calculate reduction of all of the dimensions.
 | 
				
			||||||
 | 
					      // If the reduction is negative, then the shape array contains a negative
 | 
				
			||||||
 | 
					      // dimension. Otherwise, the reduction is the same as the one computed
 | 
				
			||||||
 | 
					      // from the input tensor.
 | 
				
			||||||
 | 
					      Value tensorSizeFromShape = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                       getMemRefEltSizeInBytes(memRefType)));
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> DimInfo;
 | 
				
			||||||
 | 
					      for (int i = 0; i < memRefShape.size(); ++i) {
 | 
				
			||||||
 | 
					        Value index = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					            loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
 | 
				
			||||||
 | 
					        // Load index from array of indices.
 | 
				
			||||||
 | 
					        Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
 | 
				
			||||||
 | 
					        // If a dimension is zero, the actual dimension value is taken from the
 | 
				
			||||||
 | 
					        // input tensor.
 | 
				
			||||||
 | 
					        //
 | 
				
			||||||
 | 
					        // If a dimension is negative, it is computed from the other dimensions.
 | 
				
			||||||
 | 
					        // But we don't have enough information about the other dimensions at
 | 
				
			||||||
 | 
					        // this point. So, we let it as it is (-1), and compute it later.
 | 
				
			||||||
 | 
					        if (i < inputShape.size()) {
 | 
				
			||||||
 | 
					          Value dimVal;
 | 
				
			||||||
 | 
					          auto loadedValType = loadedVal.getType().cast<IntegerType>();
 | 
				
			||||||
 | 
					          if (inputShape[i] < 0) {
 | 
				
			||||||
 | 
					            Value dim = rewriter.create<DimOp>(loc, operands[0], i);
 | 
				
			||||||
 | 
					            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
 | 
				
			||||||
 | 
					          } else {
 | 
				
			||||||
 | 
					            dimVal = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					          auto zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					              loc, rewriter.getIntegerAttr(loadedValType, 0));
 | 
				
			||||||
 | 
					          auto isZero =
 | 
				
			||||||
 | 
					              rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
 | 
				
			||||||
 | 
					          loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        // Check if the loaded index is already the correct width of 64 bits.
 | 
				
			||||||
 | 
					        // Convert the value to a 64 bit integer if needed.
 | 
				
			||||||
 | 
					        Value int64LoadedVal = loadedVal;
 | 
				
			||||||
 | 
					        if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
 | 
				
			||||||
 | 
					          int64LoadedVal = rewriter.create<ZeroExtendIOp>(
 | 
				
			||||||
 | 
					              loc, loadedVal, rewriter.getIntegerType(64));
 | 
				
			||||||
 | 
					        tensorSizeFromShape =
 | 
				
			||||||
 | 
					            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
 | 
				
			||||||
 | 
					        // Store intermediate results to use later.
 | 
				
			||||||
 | 
					        DimInfo.emplace_back(int64LoadedVal);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      // Reverse tensorSizeFromShape since it is negative if the shape array has
 | 
				
			||||||
 | 
					      // a negative dimension. This is safe since we only use it to compute the
 | 
				
			||||||
 | 
					      // actual value for the negative dimension.
 | 
				
			||||||
 | 
					      auto zero = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
 | 
				
			||||||
 | 
					      tensorSizeFromShape =
 | 
				
			||||||
 | 
					          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Obtain operands for AllocOp.
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> allocOperands;
 | 
				
			||||||
 | 
					      auto negOne = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      for (int i = 0; i < memRefShape.size(); ++i) {
 | 
				
			||||||
 | 
					        auto dimVal = DimInfo[i];
 | 
				
			||||||
 | 
					        auto isNegOne =
 | 
				
			||||||
 | 
					            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
 | 
				
			||||||
 | 
					        // If dimension is negative, compute its value from the other
 | 
				
			||||||
 | 
					        // dimensions.
 | 
				
			||||||
 | 
					        auto actualDimVal =
 | 
				
			||||||
 | 
					            rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
 | 
				
			||||||
 | 
					        auto loadedVal =
 | 
				
			||||||
 | 
					            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
 | 
				
			||||||
 | 
					        allocOperands.push_back(rewriter.create<IndexCastOp>(
 | 
				
			||||||
 | 
					            loc, loadedVal, rewriter.getIndexType()));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      AllocOp allocateMemref =
 | 
				
			||||||
 | 
					          rewriter.create<AllocOp>(loc, memRefType, allocOperands);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // Make sure to allocate at the beginning of the block if
 | 
				
			||||||
 | 
					      // all dimensions are known.
 | 
				
			||||||
 | 
					      auto *parentBlock = allocateMemref.getOperation()->getBlock();
 | 
				
			||||||
 | 
					      if (insertDealloc) {
 | 
				
			||||||
 | 
					        auto dealloc = rewriter.create<DeallocOp>(loc, allocateMemref);
 | 
				
			||||||
 | 
					        dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      alloc = allocateMemref;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXReshapeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXReshapeOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,99 @@
 | 
				
			||||||
 | 
					//===----- transpose.inc - Lowering Transpose Op --------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Transpose Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXTransposeOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXTransposeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType))
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc,
 | 
				
			||||||
 | 
					                                    {operands[0]});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Number of loops
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					    int64_t rank = memRefShape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Define loops.
 | 
				
			||||||
 | 
					    std::vector<Value> originalLoops;
 | 
				
			||||||
 | 
					    std::vector<Value> optimizedLoops;
 | 
				
			||||||
 | 
					    Block *optimizationBlock = defineLoops(rewriter, loc, originalLoops,
 | 
				
			||||||
 | 
					        optimizedLoops, rank);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
				
			||||||
 | 
					    // Iterate over the loop nest using the input shape.
 | 
				
			||||||
 | 
					    for (int i = 0; i < rank; ++i)
 | 
				
			||||||
 | 
					      addDimensionToPack(rewriter, loc, pack, operands[0], i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
				
			||||||
 | 
					    Block &iterationBlock = iterateOp.bodyRegion().front();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Now perform the insertions into the body of the
 | 
				
			||||||
 | 
					    // just generated instructions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // 1. Insert any optimizations in the KrnlOptimizeLoopsOp body.
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToEnd(optimizationBlock);
 | 
				
			||||||
 | 
					    // Return from KrnlOptimizeLoopsOp body.
 | 
				
			||||||
 | 
					    // When no optimizations are present we just return the loops
 | 
				
			||||||
 | 
					    // unchaged.
 | 
				
			||||||
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // 2. Insert instructions inside the KernelIterateOp body.
 | 
				
			||||||
 | 
					    rewriter.setInsertionPointToStart(&iterationBlock);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Handle the operation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Read perm attribute.
 | 
				
			||||||
 | 
					    SmallVector<int, 4> perm;
 | 
				
			||||||
 | 
					    auto permAttribute = llvm::dyn_cast<ONNXTransposeOp>(op).permAttr();
 | 
				
			||||||
 | 
					    if (permAttribute) {
 | 
				
			||||||
 | 
					      for (auto permVal : permAttribute.getValue())
 | 
				
			||||||
 | 
					        perm.emplace_back(permVal.cast<IntegerAttr>().getInt());
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // TODO: Remove when perm is guaranteed to be present (even for
 | 
				
			||||||
 | 
					      // the default case). This means that perm was added by shape
 | 
				
			||||||
 | 
					      // inference or another pass to contain the values corresponding
 | 
				
			||||||
 | 
					      // to the default behavior of Transpose. 
 | 
				
			||||||
 | 
					      for (int i = iterationBlock.getArguments().size()-1; i >= 0; i--)
 | 
				
			||||||
 | 
					        perm.emplace_back(i);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> inLoopIVs;
 | 
				
			||||||
 | 
					    for (auto arg : iterationBlock.getArguments())
 | 
				
			||||||
 | 
					      inLoopIVs.emplace_back(arg);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> outLoopIVs;
 | 
				
			||||||
 | 
					    for (int i=0; i<iterationBlock.getArguments().size(); ++i)
 | 
				
			||||||
 | 
					      outLoopIVs.emplace_back(iterationBlock.getArguments()[perm[i]]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto inVal = rewriter.create<LoadOp>(loc, operands[0], inLoopIVs);
 | 
				
			||||||
 | 
					    rewriter.create<StoreOp>(loc, inVal, alloc, outLoopIVs);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXTransposeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXTransposeOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,86 @@
 | 
				
			||||||
 | 
					//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file lowers the ONNX Unsqueeze Operator to Krnl dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct ONNXUnsqueezeOpLowering : public ConversionPattern {
 | 
				
			||||||
 | 
					  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
 | 
					      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  PatternMatchResult
 | 
				
			||||||
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
 | 
					    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
 | 
				
			||||||
 | 
					    int outRank = tensorType.getRank();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Assume that `axes` has been validated by shape inference.
 | 
				
			||||||
 | 
					    // So, here we just get it.
 | 
				
			||||||
 | 
					    ArrayAttr axisAttrs = llvm::dyn_cast<ONNXUnsqueezeOp>(op).axesAttr();
 | 
				
			||||||
 | 
					    SmallVector<int, 4> axes;
 | 
				
			||||||
 | 
					    for (auto axisAttr : axisAttrs.getValue()) {
 | 
				
			||||||
 | 
					      int axis = axisAttr.cast<IntegerAttr>().getInt();
 | 
				
			||||||
 | 
					      axis = axis >= 0 ? axis : (outRank + axis);
 | 
				
			||||||
 | 
					      axes.emplace_back(axis);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Insert an allocation and deallocation for the result of this operation.
 | 
				
			||||||
 | 
					    auto memRefType = convertTensorToMemRef(tensorType);
 | 
				
			||||||
 | 
					    Value alloc;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Compute size in bytes.
 | 
				
			||||||
 | 
					    Value tensorSize = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                     getMemRefEltSizeInBytes(memRefType)));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    bool insertDealloc = checkInsertDealloc(op);
 | 
				
			||||||
 | 
					    auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					    if (hasAllConstantDimensions(memRefType)) {
 | 
				
			||||||
 | 
					      alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
 | 
				
			||||||
 | 
					      for (int i = 0; i < memRefShape.size(); ++i) {
 | 
				
			||||||
 | 
					        Value dimVal = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                         memRefShape[i]));
 | 
				
			||||||
 | 
					        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // Unknown dimensions are always the operand's dimensions.
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> allocOperands;
 | 
				
			||||||
 | 
					      for (int outIdx = 0, inIdx = 0; outIdx < memRefShape.size(); ++outIdx) {
 | 
				
			||||||
 | 
					        Value dimVal = nullptr;
 | 
				
			||||||
 | 
					        if (memRefShape[outIdx] < 0) {
 | 
				
			||||||
 | 
					          Value index = rewriter.create<DimOp>(loc, operands[0], inIdx);
 | 
				
			||||||
 | 
					          dimVal = rewriter.create<IndexCastOp>(
 | 
				
			||||||
 | 
					              loc, index, rewriter.getIntegerType(64));
 | 
				
			||||||
 | 
					          allocOperands.emplace_back(index);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					          dimVal = rewriter.create<ConstantOp>(
 | 
				
			||||||
 | 
					              loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
 | 
				
			||||||
 | 
					                                           memRefShape[outIdx]));
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
 | 
				
			||||||
 | 
					        if (std::find(axes.begin(), axes.end(), outIdx) == axes.end())
 | 
				
			||||||
 | 
					          inIdx++;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					      alloc = rewriter.create<AllocOp>(loc, memRefType, allocOperands);
 | 
				
			||||||
 | 
					      auto *parentBlock = alloc.getDefiningOp()->getBlock();
 | 
				
			||||||
 | 
					      if (insertDealloc) {
 | 
				
			||||||
 | 
					        auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
				
			||||||
 | 
					        dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    rewriter.create<KrnlMemcpyOp>(loc, alloc, operands[0], tensorSize);
 | 
				
			||||||
 | 
					    rewriter.replaceOp(op, alloc);
 | 
				
			||||||
 | 
					    return matchSuccess();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXUnsqueezeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx) {
 | 
				
			||||||
 | 
					  patterns.insert<ONNXUnsqueezeOpLowering>(ctx);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Loading…
	
		Reference in New Issue