Merge branch 'master' into shapeinference-pad
This commit is contained in:
		
						commit
						907104d7e8
					
				| 
						 | 
					@ -38,7 +38,7 @@ jobs:
 | 
				
			||||||
      - run:
 | 
					      - run:
 | 
				
			||||||
          name: Run End-To-End Tests
 | 
					          name: Run End-To-End Tests
 | 
				
			||||||
          command: |
 | 
					          command: |
 | 
				
			||||||
            sudo pip install -q onnx
 | 
					            sudo pip install -q -e ./ONNF/third_party/onnx
 | 
				
			||||||
            cd ONNF/build
 | 
					            cd ONNF/build
 | 
				
			||||||
            cmake --build . --target run-onnx-backend-test
 | 
					            cmake --build . --target run-onnx-backend-test
 | 
				
			||||||
      - run:
 | 
					      - run:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1558,33 +1558,6 @@ ONNX Gather operation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
1. `output`: memref of any type values or tensor of any type values
 | 
					1. `output`: memref of any type values or tensor of any type values
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### onnx.GemmNoBias (ONNXGemmNoBiasOp)
 | 
					 | 
				
			||||||
ONNX general matrix multiply operation without bias.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Description:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
The "onnx.Gemm" generic matrix multiplication without bias.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Operands:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
1. `A`: memref of any type values or tensor of any type values
 | 
					 | 
				
			||||||
1. `B`: memref of any type values or tensor of any type values
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Attributes:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
| Attribute | MLIR Type | Description |
 | 
					 | 
				
			||||||
| :-------: | :-------: | ----------- |
 | 
					 | 
				
			||||||
| `alpha` | `FloatAttr` | 32-bit float attribute attribute |
 | 
					 | 
				
			||||||
| `beta` | `FloatAttr` | 32-bit float attribute attribute |
 | 
					 | 
				
			||||||
| `transA` | `IntegerAttr` | 64-bit integer attribute attribute |
 | 
					 | 
				
			||||||
| `transB` | `IntegerAttr` | 64-bit integer attribute attribute |
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Results:
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
1. `o_Y`: memref of any type values or tensor of any type values
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
### onnx.Gemm (ONNXGemmOp)
 | 
					### onnx.Gemm (ONNXGemmOp)
 | 
				
			||||||
ONNX Gemm operation
 | 
					ONNX Gemm operation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -62,7 +62,21 @@ target_include_directories(onnf_shape_inference
 | 
				
			||||||
target_link_libraries(onnf_shape_inference ${MLIRLibs})
 | 
					target_link_libraries(onnf_shape_inference ${MLIRLibs})
 | 
				
			||||||
add_dependencies(onnf_shape_inference gen_krnl_ops)
 | 
					add_dependencies(onnf_shape_inference gen_krnl_ops)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
add_library(onnf_lower_frontend conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 | 
					add_library(onnf_lower_frontend
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/onnx_to_krnl_common.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/onnx_to_krnl_common.hpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/math/elementwise.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/math/gemm.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/math/matmul.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/math/reduction.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/math/softmax.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/nn/conv.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/nn/normalization.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/tensor/identity.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/tensor/reshape.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/tensor/transpose.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/tensor/unsqueeze.cpp
 | 
				
			||||||
 | 
					        conversion/onnx_to_krnl/convert_onnx_to_krnl.cpp)
 | 
				
			||||||
target_include_directories(onnf_lower_frontend
 | 
					target_include_directories(onnf_lower_frontend
 | 
				
			||||||
        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
					        PRIVATE ${ONNF_SRC_ROOT} ${ONNF_BIN_ROOT}
 | 
				
			||||||
        ${ONNF_SRC_ROOT})
 | 
					        ${ONNF_SRC_ROOT})
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -189,8 +189,9 @@ private:
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    mlir::Type elementType =
 | 
					    auto elementOnnxType =
 | 
				
			||||||
        convertONNXTypeToMLIRType(input.type().tensor_type().elem_type());
 | 
					        (onnx::TensorProto_DataType)input.type().tensor_type().elem_type();
 | 
				
			||||||
 | 
					    mlir::Type elementType = convertONNXTypeToMLIRType(elementOnnxType);
 | 
				
			||||||
    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 | 
					    llvm::ArrayRef<int64_t> tensor_dims(dims.data(), dims.size());
 | 
				
			||||||
    arg_types.emplace_back(
 | 
					    arg_types.emplace_back(
 | 
				
			||||||
        mlir::RankedTensorType::get(tensor_dims, elementType));
 | 
					        mlir::RankedTensorType::get(tensor_dims, elementType));
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,404 +8,11 @@
 | 
				
			||||||
// Krnl IR and standard operations.
 | 
					// Krnl IR and standard operations.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
#include <map>
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "mlir/Dialect/AffineOps/AffineOps.h"
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
#include "mlir/Dialect/StandardOps/Ops.h"
 | 
					 | 
				
			||||||
#include "mlir/Pass/Pass.h"
 | 
					 | 
				
			||||||
#include "mlir/Transforms/DialectConversion.h"
 | 
					 | 
				
			||||||
#include "llvm/ADT/ArrayRef.h"
 | 
					 | 
				
			||||||
#include "llvm/ADT/Sequence.h"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "src/dialect/krnl/krnl_helper.hpp"
 | 
					 | 
				
			||||||
#include "src/dialect/krnl/krnl_ops.hpp"
 | 
					 | 
				
			||||||
#include "src/dialect/onnx/onnx_ops.hpp"
 | 
					 | 
				
			||||||
#include "src/pass/passes.hpp"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
using namespace mlir;
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
// FrontendToAffine RewritePatterns
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/// Check is all dimensions are known at compile time.
 | 
					 | 
				
			||||||
static bool hasAllConstantDimensions(MemRefType type) {
 | 
					 | 
				
			||||||
  auto memRefShape = type.getShape();
 | 
					 | 
				
			||||||
  for (int i = 0; i < memRefShape.size(); ++i)
 | 
					 | 
				
			||||||
    if (memRefShape[i] < 0)
 | 
					 | 
				
			||||||
      return false;
 | 
					 | 
				
			||||||
  return true;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
 | 
					 | 
				
			||||||
static MemRefType convertToMemRefType(Type type) {
 | 
					 | 
				
			||||||
  MemRefType memRefType;
 | 
					 | 
				
			||||||
  auto tensorType = type.dyn_cast<TensorType>();
 | 
					 | 
				
			||||||
  if (tensorType) {
 | 
					 | 
				
			||||||
    assert(tensorType.hasRank() && "expected only ranked shapes");
 | 
					 | 
				
			||||||
    memRefType =
 | 
					 | 
				
			||||||
        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    memRefType = type.dyn_cast<MemRefType>();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return memRefType;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/// Insert an allocation and deallocation for the given MemRefType.
 | 
					 | 
				
			||||||
static Value insertAllocAndDealloc(MemRefType type, Location loc,
 | 
					 | 
				
			||||||
                                   PatternRewriter &rewriter,
 | 
					 | 
				
			||||||
                                   bool insertDealloc,
 | 
					 | 
				
			||||||
                                   ArrayRef<Value> operands = {}) {
 | 
					 | 
				
			||||||
  // Put together alloc operands for any dynamic dimensions of the memref.
 | 
					 | 
				
			||||||
  AllocOp alloc;
 | 
					 | 
				
			||||||
  if (!operands.empty()) {
 | 
					 | 
				
			||||||
    auto memRefShape = type.getShape();
 | 
					 | 
				
			||||||
    auto rank = memRefShape.size();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    std::map<int, Value> fromOperands;
 | 
					 | 
				
			||||||
    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
					 | 
				
			||||||
      int memRefDimIdx = rank - 1 - reversedIdx;
 | 
					 | 
				
			||||||
      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 | 
					 | 
				
			||||||
        Value maxDim = nullptr;
 | 
					 | 
				
			||||||
        for (int i = 0; i < operands.size(); i++) {
 | 
					 | 
				
			||||||
          auto operandShape =
 | 
					 | 
				
			||||||
              operands[i].getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          if (operandDimIdx < 0)
 | 
					 | 
				
			||||||
            continue;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          // In case of operations with broadcasting, the dimension of the
 | 
					 | 
				
			||||||
          // alloc result is the maximum size along each dimension of the
 | 
					 | 
				
			||||||
          // operands.
 | 
					 | 
				
			||||||
          auto operandDim =
 | 
					 | 
				
			||||||
              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
 | 
					 | 
				
			||||||
          if (maxDim) {
 | 
					 | 
				
			||||||
            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
 | 
					 | 
				
			||||||
                                                        operandDim, maxDim);
 | 
					 | 
				
			||||||
            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
 | 
					 | 
				
			||||||
                                               maxDim);
 | 
					 | 
				
			||||||
          } else {
 | 
					 | 
				
			||||||
            maxDim = operandDim;
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    SmallVector<Value, 4> allocOperands;
 | 
					 | 
				
			||||||
    for (int i = 0; i < rank; ++i)
 | 
					 | 
				
			||||||
      if (memRefShape[i] < 0)
 | 
					 | 
				
			||||||
        allocOperands.push_back(fromOperands[i]);
 | 
					 | 
				
			||||||
    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    alloc = rewriter.create<AllocOp>(loc, type);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Make sure to allocate at the beginning of the block if
 | 
					 | 
				
			||||||
  // all dimensions are known.
 | 
					 | 
				
			||||||
  auto *parentBlock = alloc.getOperation()->getBlock();
 | 
					 | 
				
			||||||
  if (hasAllConstantDimensions(type))
 | 
					 | 
				
			||||||
    alloc.getOperation()->moveBefore(&parentBlock->front());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if (insertDealloc) {
 | 
					 | 
				
			||||||
    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
					 | 
				
			||||||
    dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return alloc;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Determine if current function returns the result value of the
 | 
					 | 
				
			||||||
// current op being lowered. If it does then dealloc should not be
 | 
					 | 
				
			||||||
// inserted.
 | 
					 | 
				
			||||||
static bool checkInsertDealloc(Operation *currentOp) {
 | 
					 | 
				
			||||||
  auto parentBlock = currentOp->getBlock();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  bool insertDealloc = true;
 | 
					 | 
				
			||||||
  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
 | 
					 | 
				
			||||||
    assert(currentOp->getNumResults() < 2 &&
 | 
					 | 
				
			||||||
           "No more than one result supported (for now).");
 | 
					 | 
				
			||||||
    // If there is at least one result to investigate.
 | 
					 | 
				
			||||||
    if (currentOp->getNumResults() > 0) {
 | 
					 | 
				
			||||||
      auto result = currentOp->getResult(0);
 | 
					 | 
				
			||||||
      for (const auto &operand : op.getOperands())
 | 
					 | 
				
			||||||
        if (operand == result)
 | 
					 | 
				
			||||||
          insertDealloc = false;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return insertDealloc;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Create a mapping from result type's dimensions to input type's dimensions,
 | 
					 | 
				
			||||||
// given that the result type is the result of a reduction op over the input
 | 
					 | 
				
			||||||
// type.
 | 
					 | 
				
			||||||
std::map<int64_t, int64_t>
 | 
					 | 
				
			||||||
getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 | 
					 | 
				
			||||||
  std::map<int64_t, int64_t> OutInDimMap;
 | 
					 | 
				
			||||||
  int64_t rank = inputTy.getRank();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Mark reduction axes.
 | 
					 | 
				
			||||||
  std::vector<bool> isReductionAxis;
 | 
					 | 
				
			||||||
  for (decltype(rank) i = 0; i < rank; ++i) {
 | 
					 | 
				
			||||||
    if (std::find(axes.begin(), axes.end(), i) != axes.end())
 | 
					 | 
				
			||||||
      isReductionAxis.push_back(true);
 | 
					 | 
				
			||||||
    else
 | 
					 | 
				
			||||||
      isReductionAxis.push_back(false);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
 | 
					 | 
				
			||||||
    // If it is a reduction axis, there is no relationship among dimensions.
 | 
					 | 
				
			||||||
    if (isReductionAxis[inIndex]) {
 | 
					 | 
				
			||||||
      if (keepdims)
 | 
					 | 
				
			||||||
        outIndex++;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
 | 
					 | 
				
			||||||
      outIndex++;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return OutInDimMap;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Add bounds associated with the op operand to the KRNL iteration pack.
 | 
					 | 
				
			||||||
// Dynamic dimenions are supported.
 | 
					 | 
				
			||||||
static void addDimensionToPack(ConversionPatternRewriter &rewriter,
 | 
					 | 
				
			||||||
                               Location loc, KrnlIterateOperandPack &pack,
 | 
					 | 
				
			||||||
                               Value operand, int index) {
 | 
					 | 
				
			||||||
  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
  if (shape[index] < 0) {
 | 
					 | 
				
			||||||
    pack.pushConstantBound(0);
 | 
					 | 
				
			||||||
    pack.pushOperandBound(
 | 
					 | 
				
			||||||
        rewriter.create<DimOp>(loc, operand, index).getResult());
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    pack.pushConstantBound(0);
 | 
					 | 
				
			||||||
    pack.pushConstantBound(shape[index]);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function that defines the KRNL dialect loops and their respective
 | 
					 | 
				
			||||||
// optimized version.
 | 
					 | 
				
			||||||
static KrnlOptimizeLoopsOp
 | 
					 | 
				
			||||||
emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
					 | 
				
			||||||
                   std::vector<Value> &loops,
 | 
					 | 
				
			||||||
                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
 | 
					 | 
				
			||||||
  // Define loops.
 | 
					 | 
				
			||||||
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
 | 
					 | 
				
			||||||
  loops.reserve(numLoops);
 | 
					 | 
				
			||||||
  for (auto result : loopsOp.getResults())
 | 
					 | 
				
			||||||
    loops.push_back(result);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Define optimized version of the loops.
 | 
					 | 
				
			||||||
  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
 | 
					 | 
				
			||||||
  optimizedLoops.reserve(numLoops);
 | 
					 | 
				
			||||||
  for (auto result : optimizedLoopsOp.getResults())
 | 
					 | 
				
			||||||
    optimizedLoops.push_back(result);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  return optimizedLoopsOp;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function that emits the loops and their optimized version.
 | 
					 | 
				
			||||||
// The function returns a reference to the inner optimization block.
 | 
					 | 
				
			||||||
static Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
					 | 
				
			||||||
                          std::vector<Value> &loops,
 | 
					 | 
				
			||||||
                          std::vector<Value> &optimizedLoops,
 | 
					 | 
				
			||||||
                          int64_t numLoops) {
 | 
					 | 
				
			||||||
  KrnlOptimizeLoopsOp optimizedLoopsOp =
 | 
					 | 
				
			||||||
      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
 | 
					 | 
				
			||||||
  return &optimizedLoopsOp.region().front();
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Function which emits a basic set of loops and optimized loops
 | 
					 | 
				
			||||||
// for a given operation argument. A reference to the loop optimization
 | 
					 | 
				
			||||||
// block is returned in the last argument of the function.
 | 
					 | 
				
			||||||
static void emitKrnlLoopsAndIterationForOperand(
 | 
					 | 
				
			||||||
    ConversionPatternRewriter &rewriter, Location loc, Value operand,
 | 
					 | 
				
			||||||
    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
 | 
					 | 
				
			||||||
    KrnlIterateOp &iterateOp) {
 | 
					 | 
				
			||||||
  // Operand shape.
 | 
					 | 
				
			||||||
  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Number of loops.
 | 
					 | 
				
			||||||
  int64_t rank = shape.size();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Define loops and optimized loops.
 | 
					 | 
				
			||||||
  std::vector<Value> optimizedLoops;
 | 
					 | 
				
			||||||
  optimizedLoopsOp =
 | 
					 | 
				
			||||||
      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
					 | 
				
			||||||
  // Iterate over the loop nest.
 | 
					 | 
				
			||||||
  for (int i = 0; i < rank; ++i)
 | 
					 | 
				
			||||||
    addDimensionToPack(rewriter, loc, pack, operand, i);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 | 
					 | 
				
			||||||
  auto elementType = memRefType.getElementType();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  unsigned sizeInBits;
 | 
					 | 
				
			||||||
  if (elementType.isIntOrFloat()) {
 | 
					 | 
				
			||||||
    sizeInBits = elementType.getIntOrFloatBitWidth();
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    auto vectorType = elementType.cast<VectorType>();
 | 
					 | 
				
			||||||
    sizeInBits =
 | 
					 | 
				
			||||||
        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return llvm::divideCeil(sizeInBits, 8);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Get run-time dimension information for unknown dimensions used for
 | 
					 | 
				
			||||||
// broadcasting.
 | 
					 | 
				
			||||||
std::map<int, std::map<int, Value>>
 | 
					 | 
				
			||||||
getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 | 
					 | 
				
			||||||
                      MemRefType memRefType, ArrayRef<Value> operands) {
 | 
					 | 
				
			||||||
  auto memRefShape = memRefType.getShape();
 | 
					 | 
				
			||||||
  int64_t rank = memRefShape.size();
 | 
					 | 
				
			||||||
  // For unknown dimensions, we need to get dimension values at runtime in
 | 
					 | 
				
			||||||
  // order to do broadcasting.
 | 
					 | 
				
			||||||
  std::map<int, std::map<int, Value>> DimInfo;
 | 
					 | 
				
			||||||
  // For each result dimension, compute the number of sharing operands.
 | 
					 | 
				
			||||||
  // Sharing operands are operands sharing the same index (counting from the
 | 
					 | 
				
			||||||
  // rightmost to the leftmost) for a given dimension.
 | 
					 | 
				
			||||||
  std::map<int, int> sharedDimCount;
 | 
					 | 
				
			||||||
  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
					 | 
				
			||||||
    int dimIdx = rank - 1 - reversedIdx;
 | 
					 | 
				
			||||||
    sharedDimCount[dimIdx] = 0;
 | 
					 | 
				
			||||||
    for (int i = 0; i < operands.size(); ++i) {
 | 
					 | 
				
			||||||
      auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
      if (reversedIdx <= shape.size() - 1)
 | 
					 | 
				
			||||||
        sharedDimCount[dimIdx]++;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  // An unknown dimension can have a value of 1 or N (N > 1).
 | 
					 | 
				
			||||||
  // If its value is 1, it is broadcasted dimension.
 | 
					 | 
				
			||||||
  // Otherwise, non-broadcasted dimension.
 | 
					 | 
				
			||||||
  // We only care about unknown dimensions whose number of sharing operands is
 | 
					 | 
				
			||||||
  // more than one, since they are potentially broadcasted dimensions.
 | 
					 | 
				
			||||||
  for (int i = 0; i < operands.size(); ++i) {
 | 
					 | 
				
			||||||
    std::map<int, Value> broadcastedDims;
 | 
					 | 
				
			||||||
    auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
    int size = shape.size();
 | 
					 | 
				
			||||||
    for (int j = 0; j < shape.size(); ++j) {
 | 
					 | 
				
			||||||
      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
 | 
					 | 
				
			||||||
        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
 | 
					 | 
				
			||||||
        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
 | 
					 | 
				
			||||||
        auto isBroadcasted =
 | 
					 | 
				
			||||||
            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
 | 
					 | 
				
			||||||
        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    DimInfo.insert(std::make_pair(i, broadcastedDims));
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return DimInfo;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Extract induction variables that are used for broadcasting values of a
 | 
					 | 
				
			||||||
// given operand.
 | 
					 | 
				
			||||||
std::vector<Value>
 | 
					 | 
				
			||||||
getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
 | 
					 | 
				
			||||||
                          ArrayRef<Value> loopIVs, Value operand,
 | 
					 | 
				
			||||||
                          std::map<int, Value> broadcastedDims) {
 | 
					 | 
				
			||||||
  // `operand` must has a ranked type. This should have been checked by the
 | 
					 | 
				
			||||||
  // shape inference pass.
 | 
					 | 
				
			||||||
  auto operandShape = operand.getType().cast<MemRefType>().getShape();
 | 
					 | 
				
			||||||
  auto rank = operandShape.size();
 | 
					 | 
				
			||||||
  auto loopCount = loopIVs.size();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  std::vector<Value> newLoopIVs;
 | 
					 | 
				
			||||||
  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
					 | 
				
			||||||
    auto dimIdx = rank - 1 - reversedIdx;
 | 
					 | 
				
			||||||
    auto loopIdx = loopCount - 1 - reversedIdx;
 | 
					 | 
				
			||||||
    if (operandShape[dimIdx] == 1) {
 | 
					 | 
				
			||||||
      // Broadcasted dimension
 | 
					 | 
				
			||||||
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
					 | 
				
			||||||
      newLoopIVs.insert(newLoopIVs.begin(), zero);
 | 
					 | 
				
			||||||
    } else if ((operandShape[dimIdx] == -1) &&
 | 
					 | 
				
			||||||
               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
 | 
					 | 
				
			||||||
      // Unknown dimension, it can have a value of 1 or N (N > 1).
 | 
					 | 
				
			||||||
      // If its value is 1, it is broadcasted dimension.
 | 
					 | 
				
			||||||
      // Otherwise, non-broadcasted dimension.
 | 
					 | 
				
			||||||
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
					 | 
				
			||||||
      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
 | 
					 | 
				
			||||||
                                           loopIVs[loopIdx]);
 | 
					 | 
				
			||||||
      newLoopIVs.insert(newLoopIVs.begin(), idx);
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      // Non-broadcasted dimension
 | 
					 | 
				
			||||||
      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
  return newLoopIVs;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
namespace {
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// This is to get a scalar operation of a given type for a specific operation.
 | 
					 | 
				
			||||||
template <typename Op>
 | 
					 | 
				
			||||||
struct ScalarOp {
 | 
					 | 
				
			||||||
  using FOp = void;
 | 
					 | 
				
			||||||
  using IOp = void;
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
template <typename FOp>
 | 
					 | 
				
			||||||
using ScalarFOp = typename ScalarOp<FOp>::FOp;
 | 
					 | 
				
			||||||
template <typename IOp>
 | 
					 | 
				
			||||||
using ScalarIOp = typename ScalarOp<IOp>::IOp;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Get the identity element of a operation.
 | 
					 | 
				
			||||||
// Return NULL if the function does not have identity.
 | 
					 | 
				
			||||||
template <typename DataType, typename Op>
 | 
					 | 
				
			||||||
DataType getIdentityValue() {
 | 
					 | 
				
			||||||
  return NULL;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
// This is used in the innermost loop of a KrnlIterateOp to insert computation
 | 
					 | 
				
			||||||
// composed of one or many scalar ops.
 | 
					 | 
				
			||||||
// Use template specialization for each of different ONNX operations.
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
template <typename Op>
 | 
					 | 
				
			||||||
Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
 | 
					 | 
				
			||||||
                         ArrayRef<Value> operands,
 | 
					 | 
				
			||||||
                         ConversionPatternRewriter &rewriter) {
 | 
					 | 
				
			||||||
  auto loc = op->getLoc();
 | 
					 | 
				
			||||||
  Type element_type = operands.front().getType();
 | 
					 | 
				
			||||||
  if (element_type.isa<IntegerType>()) {
 | 
					 | 
				
			||||||
    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
 | 
					 | 
				
			||||||
                                          mlir::None);
 | 
					 | 
				
			||||||
  } else if (element_type.isa<FloatType>()) {
 | 
					 | 
				
			||||||
    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
 | 
					 | 
				
			||||||
                                          mlir::None);
 | 
					 | 
				
			||||||
  } else {
 | 
					 | 
				
			||||||
    emitError(loc, "unsupported element type");
 | 
					 | 
				
			||||||
    return nullptr;
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// We divide the operator lowering into different categories.
 | 
					 | 
				
			||||||
// These categories are mostly similar to the operator categories in ONNX:
 | 
					 | 
				
			||||||
// https://github.com/onnx/onnx/tree/master/onnx/defs.
 | 
					 | 
				
			||||||
// Besides, it is better to put operators with the same computation pattern into
 | 
					 | 
				
			||||||
// the same category, e.g. element-wise operators will belong to the elementwise
 | 
					 | 
				
			||||||
// category.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Math
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/elementwise.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/gemm.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/reduction.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/softmax.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/math/matmul.inc"
 | 
					 | 
				
			||||||
// Tensor
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/identity.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/reshape.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/transpose.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/tensor/unsqueeze.inc"
 | 
					 | 
				
			||||||
// Neural network
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/conv.inc"
 | 
					 | 
				
			||||||
#include "src/conversion/onnx_to_krnl/rewrite_patterns/nn/normalization.inc"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// EntryPoint Op lowering to Krnl Entry Point.
 | 
					// EntryPoint Op lowering to Krnl Entry Point.
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					@ -427,39 +34,6 @@ public:
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
// Conversion from Tensor type to the Standard dialect MemRef type.
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
struct TensorTypeConverter : public TypeConverter {
 | 
					 | 
				
			||||||
  using TypeConverter::TypeConverter;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  TensorTypeConverter() {
 | 
					 | 
				
			||||||
    addConversion(convertType);
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
 | 
					 | 
				
			||||||
    if (auto type = convertToMemRefType(t)) {
 | 
					 | 
				
			||||||
      results.push_back(type);
 | 
					 | 
				
			||||||
      return success();
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    results.push_back(t);
 | 
					 | 
				
			||||||
    return success();
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  /// Return true if the inputs and outputs of the given function type are
 | 
					 | 
				
			||||||
  /// legal. [Taken from MLIR and adapted to only check the legality of the
 | 
					 | 
				
			||||||
  /// inputs. Once unranked results can be handled gracefully this
 | 
					 | 
				
			||||||
  /// override needs to be removed in favour of the original MLIR one.]
 | 
					 | 
				
			||||||
  bool isSignatureLegal(FunctionType funcType) {
 | 
					 | 
				
			||||||
    return llvm::all_of(funcType.getInputs(),
 | 
					 | 
				
			||||||
                        [this](Type type) { return isLegal(type); });
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
} // end anonymous namespace.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
// Frontend to Krnl Dialect lowering pass
 | 
					// Frontend to Krnl Dialect lowering pass
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- elementwise.inc - Elementwise Ops ------------------------------===//
 | 
					//===----- elementwise.cpp - Elementwise Ops ------------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
struct ScalarOp<ONNXAddOp> {
 | 
					struct ScalarOp<ONNXAddOp> {
 | 
				
			||||||
  using FOp = AddFOp;
 | 
					  using FOp = AddFOp;
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- gemm.inc - Lowering Gemm Op ------------------------------------===//
 | 
					//===----- gemm.cpp - Lowering Gemm Op ------------------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template <typename GemmOp>
 | 
					template <typename GemmOp>
 | 
				
			||||||
struct ONNXGemmOpLowering : public ConversionPattern {
 | 
					struct ONNXGemmOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXGemmOpLowering(MLIRContext *ctx)
 | 
					  ONNXGemmOpLowering(MLIRContext *ctx)
 | 
				
			||||||
| 
						 | 
					@ -17,9 +21,7 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 | 
				
			||||||
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
					  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
 | 
				
			||||||
                  ConversionPatternRewriter &rewriter) const final {
 | 
					                  ConversionPatternRewriter &rewriter) const final {
 | 
				
			||||||
    auto loc = op->getLoc();
 | 
					    auto loc = op->getLoc();
 | 
				
			||||||
    // The first predicate is unnecessary when we remove ONXGemmNoBiasOp.
 | 
					    bool hasBias = !op->getOperand(2).getType().isa<NoneType>();
 | 
				
			||||||
    bool hasBias = (operands.size() == 3) &&
 | 
					 | 
				
			||||||
                   (!op->getOperand(2).getType().isa<NoneType>());
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Value A, B, C;
 | 
					    Value A, B, C;
 | 
				
			||||||
    A = operands[0];
 | 
					    A = operands[0];
 | 
				
			||||||
| 
						 | 
					@ -215,5 +217,4 @@ struct ONNXGemmOpLowering : public ConversionPattern {
 | 
				
			||||||
void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
 | 
					void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
 | 
				
			||||||
                                       MLIRContext *ctx) {
 | 
					                                       MLIRContext *ctx) {
 | 
				
			||||||
  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
 | 
					  patterns.insert<ONNXGemmOpLowering<ONNXGemmOp>>(ctx);
 | 
				
			||||||
  patterns.insert<ONNXGemmOpLowering<ONNXGemmNoBiasOp>>(ctx);
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- matmul.inc - Lowering Matmul Op --------------------------------===//
 | 
					//===----- matmul.cpp - Lowering Matmul Op --------------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXMatMulOpLowering : public ConversionPattern {
 | 
					struct ONNXMatMulOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXMatMulOpLowering(MLIRContext *ctx)
 | 
					  ONNXMatMulOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXMatMulOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- reduction.inc - Lowering Reduction Ops -------------------------===//
 | 
					//===----- reduction.cpp - Lowering Reduction Ops -------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Identity values
 | 
					// Identity values
 | 
				
			||||||
template <>
 | 
					template <>
 | 
				
			||||||
float getIdentityValue<float, ONNXReduceMaxOp>(){
 | 
					float getIdentityValue<float, ONNXReduceMaxOp>(){
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- softmax.inc - Softmax Op ---------------------------------------===//
 | 
					//===----- softmax.cpp - Softmax Op ---------------------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXSoftmaxOpLowering : public ConversionPattern {
 | 
					struct ONNXSoftmaxOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXSoftmaxOpLowering(MLIRContext *ctx)
 | 
					  ONNXSoftmaxOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXSoftmaxOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- conv.inc - Lowering Convolution Op -----------------------------===//
 | 
					//===----- conv.cpp - Lowering Convolution Op -----------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXConvNoBiasOpLowering : public ConversionPattern {
 | 
					struct ONNXConvNoBiasOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
 | 
					  ONNXConvNoBiasOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXConvNoBiasOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- normalization.inc - Lowering Normalization Ops -----------------===//
 | 
					//===----- normalization.cpp - Lowering Normalization Ops -----------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
 | 
					struct ONNXBatchNormalizationTestModeOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
 | 
					  ONNXBatchNormalizationTestModeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(
 | 
					      : ConversionPattern(
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,324 @@
 | 
				
			||||||
 | 
					//====-- onnx_to_krnl_common.cpp - ONNX dialects to Krnl lowering ---------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file contains common code shared by the functions performing the
 | 
				
			||||||
 | 
					// lowering to the KRNL dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Check is all dimensions are known at compile time.
 | 
				
			||||||
 | 
					bool hasAllConstantDimensions(MemRefType type) {
 | 
				
			||||||
 | 
					  auto memRefShape = type.getShape();
 | 
				
			||||||
 | 
					  for (int i = 0; i < memRefShape.size(); ++i)
 | 
				
			||||||
 | 
					    if (memRefShape[i] < 0)
 | 
				
			||||||
 | 
					      return false;
 | 
				
			||||||
 | 
					  return true;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Get the corresponding MemRefType of a given TensorType/MemRefType.
 | 
				
			||||||
 | 
					MemRefType convertToMemRefType(Type type) {
 | 
				
			||||||
 | 
					  MemRefType memRefType;
 | 
				
			||||||
 | 
					  auto tensorType = type.dyn_cast<TensorType>();
 | 
				
			||||||
 | 
					  if (tensorType) {
 | 
				
			||||||
 | 
					    assert(tensorType.hasRank() && "expected only ranked shapes");
 | 
				
			||||||
 | 
					    memRefType =
 | 
				
			||||||
 | 
					        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    memRefType = type.dyn_cast<MemRefType>();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return memRefType;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Insert an allocation and deallocation for the given MemRefType.
 | 
				
			||||||
 | 
					Value insertAllocAndDealloc(MemRefType type, Location loc,
 | 
				
			||||||
 | 
					                                   PatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                                   bool insertDealloc,
 | 
				
			||||||
 | 
					                                   ArrayRef<Value> operands) {
 | 
				
			||||||
 | 
					  // Put together alloc operands for any dynamic dimensions of the memref.
 | 
				
			||||||
 | 
					  AllocOp alloc;
 | 
				
			||||||
 | 
					  if (!operands.empty()) {
 | 
				
			||||||
 | 
					    auto memRefShape = type.getShape();
 | 
				
			||||||
 | 
					    auto rank = memRefShape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    std::map<int, Value> fromOperands;
 | 
				
			||||||
 | 
					    for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					      int memRefDimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					      if (memRefShape[memRefDimIdx] < 0) { // unknown dimension
 | 
				
			||||||
 | 
					        Value maxDim = nullptr;
 | 
				
			||||||
 | 
					        for (int i = 0; i < operands.size(); i++) {
 | 
				
			||||||
 | 
					          auto operandShape =
 | 
				
			||||||
 | 
					              operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					          int operandDimIdx = operandShape.size() - 1 - reversedIdx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          if (operandDimIdx < 0)
 | 
				
			||||||
 | 
					            continue;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					          // In case of operations with broadcasting, the dimension of the
 | 
				
			||||||
 | 
					          // alloc result is the maximum size along each dimension of the
 | 
				
			||||||
 | 
					          // operands.
 | 
				
			||||||
 | 
					          auto operandDim =
 | 
				
			||||||
 | 
					              rewriter.create<DimOp>(loc, operands[i], operandDimIdx);
 | 
				
			||||||
 | 
					          if (maxDim) {
 | 
				
			||||||
 | 
					            auto maxCondition = rewriter.create<CmpIOp>(loc, CmpIPredicate::sgt,
 | 
				
			||||||
 | 
					                                                        operandDim, maxDim);
 | 
				
			||||||
 | 
					            maxDim = rewriter.create<SelectOp>(loc, maxCondition, operandDim,
 | 
				
			||||||
 | 
					                                               maxDim);
 | 
				
			||||||
 | 
					          } else {
 | 
				
			||||||
 | 
					            maxDim = operandDim;
 | 
				
			||||||
 | 
					          }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        fromOperands.insert(std::make_pair(memRefDimIdx, maxDim));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> allocOperands;
 | 
				
			||||||
 | 
					    for (int i = 0; i < rank; ++i)
 | 
				
			||||||
 | 
					      if (memRefShape[i] < 0)
 | 
				
			||||||
 | 
					        allocOperands.push_back(fromOperands[i]);
 | 
				
			||||||
 | 
					    alloc = rewriter.create<AllocOp>(loc, type, allocOperands);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    alloc = rewriter.create<AllocOp>(loc, type);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Make sure to allocate at the beginning of the block if
 | 
				
			||||||
 | 
					  // all dimensions are known.
 | 
				
			||||||
 | 
					  auto *parentBlock = alloc.getOperation()->getBlock();
 | 
				
			||||||
 | 
					  if (hasAllConstantDimensions(type))
 | 
				
			||||||
 | 
					    alloc.getOperation()->moveBefore(&parentBlock->front());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (insertDealloc) {
 | 
				
			||||||
 | 
					    auto dealloc = rewriter.create<DeallocOp>(loc, alloc);
 | 
				
			||||||
 | 
					    dealloc.getOperation()->moveBefore(&parentBlock->back());
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return alloc;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Determine if current function returns the result value of the
 | 
				
			||||||
 | 
					// current op being lowered. If it does then dealloc should not be
 | 
				
			||||||
 | 
					// inserted.
 | 
				
			||||||
 | 
					bool checkInsertDealloc(Operation *currentOp) {
 | 
				
			||||||
 | 
					  auto parentBlock = currentOp->getBlock();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  bool insertDealloc = true;
 | 
				
			||||||
 | 
					  parentBlock->walk([&insertDealloc, currentOp](ReturnOp op) {
 | 
				
			||||||
 | 
					    assert(currentOp->getNumResults() < 2 &&
 | 
				
			||||||
 | 
					           "No more than one result supported (for now).");
 | 
				
			||||||
 | 
					    // If there is at least one result to investigate.
 | 
				
			||||||
 | 
					    if (currentOp->getNumResults() > 0) {
 | 
				
			||||||
 | 
					      auto result = currentOp->getResult(0);
 | 
				
			||||||
 | 
					      for (const auto &operand : op.getOperands())
 | 
				
			||||||
 | 
					        if (operand == result)
 | 
				
			||||||
 | 
					          insertDealloc = false;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return insertDealloc;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a mapping from result type's dimensions to input type's dimensions,
 | 
				
			||||||
 | 
					// given that the result type is the result of a reduction op over the input
 | 
				
			||||||
 | 
					// type.
 | 
				
			||||||
 | 
					std::map<int64_t, int64_t>
 | 
				
			||||||
 | 
					getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims) {
 | 
				
			||||||
 | 
					  std::map<int64_t, int64_t> OutInDimMap;
 | 
				
			||||||
 | 
					  int64_t rank = inputTy.getRank();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Mark reduction axes.
 | 
				
			||||||
 | 
					  std::vector<bool> isReductionAxis;
 | 
				
			||||||
 | 
					  for (decltype(rank) i = 0; i < rank; ++i) {
 | 
				
			||||||
 | 
					    if (std::find(axes.begin(), axes.end(), i) != axes.end())
 | 
				
			||||||
 | 
					      isReductionAxis.push_back(true);
 | 
				
			||||||
 | 
					    else
 | 
				
			||||||
 | 
					      isReductionAxis.push_back(false);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for (decltype(rank) inIndex = 0, outIndex = 0; inIndex < rank; ++inIndex) {
 | 
				
			||||||
 | 
					    // If it is a reduction axis, there is no relationship among dimensions.
 | 
				
			||||||
 | 
					    if (isReductionAxis[inIndex]) {
 | 
				
			||||||
 | 
					      if (keepdims)
 | 
				
			||||||
 | 
					        outIndex++;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      OutInDimMap.insert(std::make_pair(outIndex, inIndex));
 | 
				
			||||||
 | 
					      outIndex++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return OutInDimMap;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add bounds associated with the op operand to the KRNL iteration pack.
 | 
				
			||||||
 | 
					// Dynamic dimenions are supported.
 | 
				
			||||||
 | 
					void addDimensionToPack(ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                               Location loc, KrnlIterateOperandPack &pack,
 | 
				
			||||||
 | 
					                               Value operand, int index) {
 | 
				
			||||||
 | 
					  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					  if (shape[index] < 0) {
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    pack.pushOperandBound(
 | 
				
			||||||
 | 
					        rewriter.create<DimOp>(loc, operand, index).getResult());
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    pack.pushConstantBound(0);
 | 
				
			||||||
 | 
					    pack.pushConstantBound(shape[index]);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that defines the KRNL dialect loops and their respective
 | 
				
			||||||
 | 
					// optimized version.
 | 
				
			||||||
 | 
					KrnlOptimizeLoopsOp
 | 
				
			||||||
 | 
					emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                   std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                   std::vector<Value> &optimizedLoops, int64_t numLoops) {
 | 
				
			||||||
 | 
					  // Define loops.
 | 
				
			||||||
 | 
					  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, numLoops);
 | 
				
			||||||
 | 
					  loops.reserve(numLoops);
 | 
				
			||||||
 | 
					  for (auto result : loopsOp.getResults())
 | 
				
			||||||
 | 
					    loops.push_back(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Define optimized version of the loops.
 | 
				
			||||||
 | 
					  auto optimizedLoopsOp = rewriter.create<KrnlOptimizeLoopsOp>(loc, numLoops);
 | 
				
			||||||
 | 
					  optimizedLoops.reserve(numLoops);
 | 
				
			||||||
 | 
					  for (auto result : optimizedLoopsOp.getResults())
 | 
				
			||||||
 | 
					    optimizedLoops.push_back(result);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return optimizedLoopsOp;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that emits the loops and their optimized version.
 | 
				
			||||||
 | 
					// The function returns a reference to the inner optimization block.
 | 
				
			||||||
 | 
					Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                          std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                          std::vector<Value> &optimizedLoops,
 | 
				
			||||||
 | 
					                          int64_t numLoops) {
 | 
				
			||||||
 | 
					  KrnlOptimizeLoopsOp optimizedLoopsOp =
 | 
				
			||||||
 | 
					      emitOptimizedLoops(rewriter, loc, loops, optimizedLoops, numLoops);
 | 
				
			||||||
 | 
					  return &optimizedLoopsOp.region().front();
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function which emits a basic set of loops and optimized loops
 | 
				
			||||||
 | 
					// for a given operation argument. A reference to the loop optimization
 | 
				
			||||||
 | 
					// block is returned in the last argument of the function.
 | 
				
			||||||
 | 
					void emitKrnlLoopsAndIterationForOperand(
 | 
				
			||||||
 | 
					    ConversionPatternRewriter &rewriter, Location loc, Value operand,
 | 
				
			||||||
 | 
					    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
 | 
				
			||||||
 | 
					    KrnlIterateOp &iterateOp) {
 | 
				
			||||||
 | 
					  // Operand shape.
 | 
				
			||||||
 | 
					  auto shape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Number of loops.
 | 
				
			||||||
 | 
					  int64_t rank = shape.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Define loops and optimized loops.
 | 
				
			||||||
 | 
					  std::vector<Value> optimizedLoops;
 | 
				
			||||||
 | 
					  optimizedLoopsOp =
 | 
				
			||||||
 | 
					      emitOptimizedLoops(rewriter, loc, originalLoops, optimizedLoops, rank);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  KrnlIterateOperandPack pack(rewriter, originalLoops, optimizedLoops);
 | 
				
			||||||
 | 
					  // Iterate over the loop nest.
 | 
				
			||||||
 | 
					  for (int i = 0; i < rank; ++i)
 | 
				
			||||||
 | 
					    addDimensionToPack(rewriter, loc, pack, operand, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  iterateOp = rewriter.create<KrnlIterateOp>(loc, pack);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					unsigned getMemRefEltSizeInBytes(MemRefType memRefType) {
 | 
				
			||||||
 | 
					  auto elementType = memRefType.getElementType();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  unsigned sizeInBits;
 | 
				
			||||||
 | 
					  if (elementType.isIntOrFloat()) {
 | 
				
			||||||
 | 
					    sizeInBits = elementType.getIntOrFloatBitWidth();
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    auto vectorType = elementType.cast<VectorType>();
 | 
				
			||||||
 | 
					    sizeInBits =
 | 
				
			||||||
 | 
					        vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return llvm::divideCeil(sizeInBits, 8);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get run-time dimension information for unknown dimensions used for
 | 
				
			||||||
 | 
					// broadcasting.
 | 
				
			||||||
 | 
					std::map<int, std::map<int, Value>>
 | 
				
			||||||
 | 
					getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                      MemRefType memRefType, ArrayRef<Value> operands) {
 | 
				
			||||||
 | 
					  auto memRefShape = memRefType.getShape();
 | 
				
			||||||
 | 
					  int64_t rank = memRefShape.size();
 | 
				
			||||||
 | 
					  // For unknown dimensions, we need to get dimension values at runtime in
 | 
				
			||||||
 | 
					  // order to do broadcasting.
 | 
				
			||||||
 | 
					  std::map<int, std::map<int, Value>> DimInfo;
 | 
				
			||||||
 | 
					  // For each result dimension, compute the number of sharing operands.
 | 
				
			||||||
 | 
					  // Sharing operands are operands sharing the same index (counting from the
 | 
				
			||||||
 | 
					  // rightmost to the leftmost) for a given dimension.
 | 
				
			||||||
 | 
					  std::map<int, int> sharedDimCount;
 | 
				
			||||||
 | 
					  for (int reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					    int dimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    sharedDimCount[dimIdx] = 0;
 | 
				
			||||||
 | 
					    for (int i = 0; i < operands.size(); ++i) {
 | 
				
			||||||
 | 
					      auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					      if (reversedIdx <= shape.size() - 1)
 | 
				
			||||||
 | 
					        sharedDimCount[dimIdx]++;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  // An unknown dimension can have a value of 1 or N (N > 1).
 | 
				
			||||||
 | 
					  // If its value is 1, it is broadcasted dimension.
 | 
				
			||||||
 | 
					  // Otherwise, non-broadcasted dimension.
 | 
				
			||||||
 | 
					  // We only care about unknown dimensions whose number of sharing operands is
 | 
				
			||||||
 | 
					  // more than one, since they are potentially broadcasted dimensions.
 | 
				
			||||||
 | 
					  for (int i = 0; i < operands.size(); ++i) {
 | 
				
			||||||
 | 
					    std::map<int, Value> broadcastedDims;
 | 
				
			||||||
 | 
					    auto shape = operands[i].getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					    int size = shape.size();
 | 
				
			||||||
 | 
					    for (int j = 0; j < shape.size(); ++j) {
 | 
				
			||||||
 | 
					      if (shape[j] < 0 and sharedDimCount[rank - size + j] > 1) {
 | 
				
			||||||
 | 
					        auto dim = rewriter.create<DimOp>(loc, operands[i], j).getResult();
 | 
				
			||||||
 | 
					        auto one = rewriter.create<ConstantIndexOp>(loc, 1);
 | 
				
			||||||
 | 
					        auto isBroadcasted =
 | 
				
			||||||
 | 
					            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dim, one);
 | 
				
			||||||
 | 
					        broadcastedDims.insert(std::make_pair(j, isBroadcasted));
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    DimInfo.insert(std::make_pair(i, broadcastedDims));
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return DimInfo;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Extract induction variables that are used for broadcasting values of a
 | 
				
			||||||
 | 
					// given operand.
 | 
				
			||||||
 | 
					std::vector<Value>
 | 
				
			||||||
 | 
					getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                          ArrayRef<Value> loopIVs, Value operand,
 | 
				
			||||||
 | 
					                          std::map<int, Value> broadcastedDims) {
 | 
				
			||||||
 | 
					  // `operand` must has a ranked type. This should have been checked by the
 | 
				
			||||||
 | 
					  // shape inference pass.
 | 
				
			||||||
 | 
					  auto operandShape = operand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
 | 
					  auto rank = operandShape.size();
 | 
				
			||||||
 | 
					  auto loopCount = loopIVs.size();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  std::vector<Value> newLoopIVs;
 | 
				
			||||||
 | 
					  for (unsigned reversedIdx = 0; reversedIdx < rank; ++reversedIdx) {
 | 
				
			||||||
 | 
					    auto dimIdx = rank - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    auto loopIdx = loopCount - 1 - reversedIdx;
 | 
				
			||||||
 | 
					    if (operandShape[dimIdx] == 1) {
 | 
				
			||||||
 | 
					      // Broadcasted dimension
 | 
				
			||||||
 | 
					      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), zero);
 | 
				
			||||||
 | 
					    } else if ((operandShape[dimIdx] == -1) &&
 | 
				
			||||||
 | 
					               (broadcastedDims.find(dimIdx) != broadcastedDims.end())) {
 | 
				
			||||||
 | 
					      // Unknown dimension, it can have a value of 1 or N (N > 1).
 | 
				
			||||||
 | 
					      // If its value is 1, it is broadcasted dimension.
 | 
				
			||||||
 | 
					      // Otherwise, non-broadcasted dimension.
 | 
				
			||||||
 | 
					      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
 | 
				
			||||||
 | 
					      auto idx = rewriter.create<SelectOp>(loc, broadcastedDims[dimIdx], zero,
 | 
				
			||||||
 | 
					                                           loopIVs[loopIdx]);
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), idx);
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // Non-broadcasted dimension
 | 
				
			||||||
 | 
					      newLoopIVs.insert(newLoopIVs.begin(), loopIVs[loopIdx]);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  return newLoopIVs;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1,217 @@
 | 
				
			||||||
 | 
					//====-- onnx_to_krnl_common.hpp - ONNX dialects to Krnl lowering ---------===//
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// =============================================================================
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// This file contains common code shared by the functions performing the
 | 
				
			||||||
 | 
					// lowering to the KRNL dialect.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#pragma once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include <map>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "mlir/Dialect/AffineOps/AffineOps.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/StandardOps/Ops.h"
 | 
				
			||||||
 | 
					#include "mlir/Pass/Pass.h"
 | 
				
			||||||
 | 
					#include "mlir/Transforms/DialectConversion.h"
 | 
				
			||||||
 | 
					#include "llvm/ADT/ArrayRef.h"
 | 
				
			||||||
 | 
					#include "llvm/ADT/Sequence.h"
 | 
				
			||||||
 | 
					#include "mlir/IR/PatternMatch.h"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/dialect/krnl/krnl_helper.hpp"
 | 
				
			||||||
 | 
					#include "src/dialect/krnl/krnl_ops.hpp"
 | 
				
			||||||
 | 
					#include "src/dialect/onnx/onnx_ops.hpp"
 | 
				
			||||||
 | 
					#include "src/pass/passes.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Common functions used when lowering the ONNX frontend dialect to KRNL.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Check is all dimensions are known at compile time.
 | 
				
			||||||
 | 
					bool hasAllConstantDimensions(MemRefType type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Get the corresponding MemRefType of a given TensorType/MemRefType.
 | 
				
			||||||
 | 
					MemRefType convertToMemRefType(Type type);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// Insert an allocation and deallocation for the given MemRefType.
 | 
				
			||||||
 | 
					Value insertAllocAndDealloc(MemRefType type, Location loc,
 | 
				
			||||||
 | 
					                                   PatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                                   bool insertDealloc,
 | 
				
			||||||
 | 
					                                   ArrayRef<Value> operands = {});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Determine if current function returns the result value of the
 | 
				
			||||||
 | 
					// current op being lowered. If it does then dealloc should not be
 | 
				
			||||||
 | 
					// inserted.
 | 
				
			||||||
 | 
					bool checkInsertDealloc(Operation *currentOp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Create a mapping from result type's dimensions to input type's dimensions,
 | 
				
			||||||
 | 
					// given that the result type is the result of a reduction op over the input
 | 
				
			||||||
 | 
					// type.
 | 
				
			||||||
 | 
					std::map<int64_t, int64_t>
 | 
				
			||||||
 | 
					getReductionMapping(MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Add bounds associated with the op operand to the KRNL iteration pack.
 | 
				
			||||||
 | 
					// Dynamic dimenions are supported.
 | 
				
			||||||
 | 
					void addDimensionToPack(ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                               Location loc, KrnlIterateOperandPack &pack,
 | 
				
			||||||
 | 
					                               Value operand, int index);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that defines the KRNL dialect loops and their respective
 | 
				
			||||||
 | 
					// optimized version.
 | 
				
			||||||
 | 
					KrnlOptimizeLoopsOp
 | 
				
			||||||
 | 
					emitOptimizedLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                   std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                   std::vector<Value> &optimizedLoops, int64_t numLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function that emits the loops and their optimized version.
 | 
				
			||||||
 | 
					// The function returns a reference to the inner optimization block.
 | 
				
			||||||
 | 
					Block *defineLoops(ConversionPatternRewriter &rewriter, Location loc,
 | 
				
			||||||
 | 
					                          std::vector<Value> &loops,
 | 
				
			||||||
 | 
					                          std::vector<Value> &optimizedLoops,
 | 
				
			||||||
 | 
					                          int64_t numLoops);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Function which emits a basic set of loops and optimized loops
 | 
				
			||||||
 | 
					// for a given operation argument. A reference to the loop optimization
 | 
				
			||||||
 | 
					// block is returned in the last argument of the function.
 | 
				
			||||||
 | 
					void emitKrnlLoopsAndIterationForOperand(
 | 
				
			||||||
 | 
					    ConversionPatternRewriter &rewriter, Location loc, Value operand,
 | 
				
			||||||
 | 
					    std::vector<Value> &originalLoops, KrnlOptimizeLoopsOp &optimizedLoopsOp,
 | 
				
			||||||
 | 
					    KrnlIterateOp &iterateOp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get run-time dimension information for unknown dimensions used for
 | 
				
			||||||
 | 
					// broadcasting.
 | 
				
			||||||
 | 
					std::map<int, std::map<int, Value>>
 | 
				
			||||||
 | 
					getBroadcastedDimInfo(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                      MemRefType memRefType, ArrayRef<Value> operands);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Extract induction variables that are used for broadcasting values of a
 | 
				
			||||||
 | 
					// given operand.
 | 
				
			||||||
 | 
					std::vector<Value>
 | 
				
			||||||
 | 
					getLoopIVsForBroadcasting(Location loc, ConversionPatternRewriter &rewriter,
 | 
				
			||||||
 | 
					                          ArrayRef<Value> loopIVs, Value operand,
 | 
				
			||||||
 | 
					                          std::map<int, Value> broadcastedDims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// This is to get a scalar operation of a given type for a specific operation.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <typename Op>
 | 
				
			||||||
 | 
					struct ScalarOp {
 | 
				
			||||||
 | 
					  using FOp = void;
 | 
				
			||||||
 | 
					  using IOp = void;
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					template <typename FOp>
 | 
				
			||||||
 | 
					using ScalarFOp = typename ScalarOp<FOp>::FOp;
 | 
				
			||||||
 | 
					template <typename IOp>
 | 
				
			||||||
 | 
					using ScalarIOp = typename ScalarOp<IOp>::IOp;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get the identity element of a operation.
 | 
				
			||||||
 | 
					// Return NULL if the function does not have identity.
 | 
				
			||||||
 | 
					template <typename DataType, typename Op>
 | 
				
			||||||
 | 
					DataType getIdentityValue() {
 | 
				
			||||||
 | 
					  return NULL;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// This is used in the innermost loop of a KrnlIterateOp to insert computation
 | 
				
			||||||
 | 
					// composed of one or many scalar ops.
 | 
				
			||||||
 | 
					// Use template specialization for each of different ONNX operations.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					template <typename Op>
 | 
				
			||||||
 | 
					Value mapToLowerScalarOp(Operation *op, ArrayRef<Type> result_types,
 | 
				
			||||||
 | 
					                         ArrayRef<Value> operands,
 | 
				
			||||||
 | 
					                         ConversionPatternRewriter &rewriter) {
 | 
				
			||||||
 | 
					  auto loc = op->getLoc();
 | 
				
			||||||
 | 
					  Type element_type = operands.front().getType();
 | 
				
			||||||
 | 
					  if (element_type.isa<IntegerType>()) {
 | 
				
			||||||
 | 
					    return rewriter.create<ScalarIOp<Op>>(loc, result_types, operands,
 | 
				
			||||||
 | 
					                                          mlir::None);
 | 
				
			||||||
 | 
					  } else if (element_type.isa<FloatType>()) {
 | 
				
			||||||
 | 
					    return rewriter.create<ScalarFOp<Op>>(loc, result_types, operands,
 | 
				
			||||||
 | 
					                                          mlir::None);
 | 
				
			||||||
 | 
					  } else {
 | 
				
			||||||
 | 
					    emitError(loc, "unsupported element type");
 | 
				
			||||||
 | 
					    return nullptr;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Conversion from Tensor type to the Standard dialect MemRef type.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					struct TensorTypeConverter : public TypeConverter {
 | 
				
			||||||
 | 
					  using TypeConverter::TypeConverter;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  TensorTypeConverter() {
 | 
				
			||||||
 | 
					    addConversion(convertType);
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
 | 
				
			||||||
 | 
					    if (auto type = convertToMemRefType(t)) {
 | 
				
			||||||
 | 
					      results.push_back(type);
 | 
				
			||||||
 | 
					      return success();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    results.push_back(t);
 | 
				
			||||||
 | 
					    return success();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  /// Return true if the inputs and outputs of the given function type are
 | 
				
			||||||
 | 
					  /// legal. [Taken from MLIR and adapted to only check the legality of the
 | 
				
			||||||
 | 
					  /// inputs. Once unranked results can be handled gracefully this
 | 
				
			||||||
 | 
					  /// override needs to be removed in favour of the original MLIR one.]
 | 
				
			||||||
 | 
					  bool isSignatureLegal(FunctionType funcType) {
 | 
				
			||||||
 | 
					    return llvm::all_of(funcType.getInputs(),
 | 
				
			||||||
 | 
					                        [this](Type type) { return isLegal(type); });
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					// Functions to add lowering patterns for frontend operations.
 | 
				
			||||||
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// `math` directory methods:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXElementwiseOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXGemmOpPattern(OwningRewritePatternList &patterns,
 | 
				
			||||||
 | 
					                                       MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXMatMulOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXReductionOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXSoftmaxOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// `nn` directory methods:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXConvOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXNormalizationOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// `tensor` directory methods:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXUnsqueezeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXTransposeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXReshapeOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void populateLoweringONNXIdentityOpPattern(
 | 
				
			||||||
 | 
					    OwningRewritePatternList &patterns, MLIRContext *ctx);
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- identity.inc - Lowering Identity Op ----------------------------===//
 | 
					//===----- identity.cpp - Lowering Identity Op ----------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXIdentityOpLowering : public ConversionPattern {
 | 
					struct ONNXIdentityOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXIdentityOpLowering(MLIRContext *ctx)
 | 
					  ONNXIdentityOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXIdentityOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- reshape.inc - Lowering Reshape Op ------------------------------===//
 | 
					//===----- reshape.cpp - Lowering Reshape Op ------------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXReshapeOpLowering : public ConversionPattern {
 | 
					struct ONNXReshapeOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXReshapeOpLowering(MLIRContext *ctx)
 | 
					  ONNXReshapeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXReshapeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- transpose.inc - Lowering Transpose Op --------------------------===//
 | 
					//===----- transpose.cpp - Lowering Transpose Op --------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXTransposeOpLowering : public ConversionPattern {
 | 
					struct ONNXTransposeOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXTransposeOpLowering(MLIRContext *ctx)
 | 
					  ONNXTransposeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXTransposeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -1,4 +1,4 @@
 | 
				
			||||||
//===----- unsqueeze.inc - Lowering Unsqueeze Op --------------------------===//
 | 
					//===----- unsqueeze.cpp - Lowering Unsqueeze Op --------------------------===//
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Copyright 2019 The IBM Research Authors.
 | 
					// Copyright 2019 The IBM Research Authors.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,10 @@
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//===----------------------------------------------------------------------===//
 | 
					//===----------------------------------------------------------------------===//
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "src/conversion/onnx_to_krnl/onnx_to_krnl_common.hpp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					using namespace mlir;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
struct ONNXUnsqueezeOpLowering : public ConversionPattern {
 | 
					struct ONNXUnsqueezeOpLowering : public ConversionPattern {
 | 
				
			||||||
  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
 | 
					  ONNXUnsqueezeOpLowering(MLIRContext *ctx)
 | 
				
			||||||
      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
 | 
					      : ConversionPattern(mlir::ONNXUnsqueezeOp::getOperationName(), 1, ctx) {}
 | 
				
			||||||
| 
						 | 
					@ -131,7 +131,7 @@ void KrnlIterateOperandPack::pushConstantBound(int64_t bound) {
 | 
				
			||||||
  boundMaps.emplace_back(AffineMapAttr::get(map));
 | 
					  boundMaps.emplace_back(AffineMapAttr::get(map));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void KrnlIterateOperandPack::pushOperandBound(mlir::Value operand) {
 | 
					void KrnlIterateOperandPack::pushOperandBound(Value operand) {
 | 
				
			||||||
  if (boundMaps.size() % 2 == 0)
 | 
					  if (boundMaps.size() % 2 == 0)
 | 
				
			||||||
    _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
 | 
					    _operands.emplace_back(inputLoops[boundMaps.size() / 2]);
 | 
				
			||||||
  AffineMap map = builder.getSymbolIdentityMap();
 | 
					  AffineMap map = builder.getSymbolIdentityMap();
 | 
				
			||||||
| 
						 | 
					@ -145,7 +145,7 @@ BuildKrnlLoop::BuildKrnlLoop(
 | 
				
			||||||
      pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
 | 
					      pushCount(0), createdDefineOp(false), createdOptimizeOp(false),
 | 
				
			||||||
      createdIterateOp(false) {
 | 
					      createdIterateOp(false) {
 | 
				
			||||||
  if (originalLoopNum <= 0)
 | 
					  if (originalLoopNum <= 0)
 | 
				
			||||||
    emitError(loc, "expected positive number of original loops");
 | 
					    emitError(loc, "Expected positive number of original loops.");
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BuildKrnlLoop::BuildKrnlLoop(
 | 
					BuildKrnlLoop::BuildKrnlLoop(
 | 
				
			||||||
| 
						 | 
					@ -154,25 +154,24 @@ BuildKrnlLoop::BuildKrnlLoop(
 | 
				
			||||||
          memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
 | 
					          memRefOperand.getType().cast<MemRefType>().getShape().size()) {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BuildKrnlLoop::~BuildKrnlLoop() {
 | 
					BuildKrnlLoop::~BuildKrnlLoop() {
 | 
				
			||||||
  if (!createdDefineOp)
 | 
					 | 
				
			||||||
    emitError(loc, "expected to create define op");
 | 
					 | 
				
			||||||
  if (!createdIterateOp)
 | 
					 | 
				
			||||||
    emitError(loc, "expected to create iteration op");
 | 
					 | 
				
			||||||
  if (pack)
 | 
					  if (pack)
 | 
				
			||||||
    free(pack);
 | 
					    free(pack);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
 | 
					void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
 | 
				
			||||||
  // insert define loop op
 | 
					  // Insert define loop operation.
 | 
				
			||||||
  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
 | 
					  auto loopsOp = rewriter.create<KrnlDefineLoopsOp>(loc, originalLoopNum);
 | 
				
			||||||
  originalLoops.reserve(originalLoopNum);
 | 
					  originalLoops.reserve(originalLoopNum);
 | 
				
			||||||
  for (auto result : loopsOp.getResults())
 | 
					  for (auto result : loopsOp.getResults())
 | 
				
			||||||
    originalLoops.push_back(result);
 | 
					    originalLoops.push_back(result);
 | 
				
			||||||
  // inserte optimize loop op.
 | 
					  createdDefineOp = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Insert optimize loop operation.
 | 
				
			||||||
  auto optimizedLoopsOp =
 | 
					  auto optimizedLoopsOp =
 | 
				
			||||||
      rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
 | 
					      rewriter.create<KrnlOptimizeLoopsOp>(loc, originalLoopNum);
 | 
				
			||||||
  optLoops.reserve(originalLoopNum);
 | 
					  optLoops.reserve(originalLoopNum);
 | 
				
			||||||
  // Emit empty optimizations
 | 
					
 | 
				
			||||||
 | 
					  // Emit empty optimizations if flag is set.
 | 
				
			||||||
  if (withEmptyOptimization) {
 | 
					  if (withEmptyOptimization) {
 | 
				
			||||||
    for (auto result : optimizedLoopsOp.getResults())
 | 
					    for (auto result : optimizedLoopsOp.getResults())
 | 
				
			||||||
      optLoops.push_back(result);
 | 
					      optLoops.push_back(result);
 | 
				
			||||||
| 
						 | 
					@ -182,12 +181,12 @@ void BuildKrnlLoop::createDefineAndOptimizeOp(bool withEmptyOptimization) {
 | 
				
			||||||
    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
					    rewriter.create<KrnlReturnLoopsOp>(loc, originalLoops);
 | 
				
			||||||
    rewriter.restoreInsertionPoint(ip);
 | 
					    rewriter.restoreInsertionPoint(ip);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					  createdOptimizeOp = true;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // prepare data structure to push bounds
 | 
					  // prepare data structure to push bounds
 | 
				
			||||||
  pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
 | 
					  pack = new KrnlIterateOperandPack(rewriter, originalLoops, optLoops);
 | 
				
			||||||
  createdOptimizeOp = true;
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// push bounds (lower and upper) and return index for loop info
 | 
					 | 
				
			||||||
int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
 | 
					int BuildKrnlLoop::pushBounds(int64_t lowerBound, int64_t upperBound) {
 | 
				
			||||||
  pack->pushConstantBound(lowerBound);
 | 
					  pack->pushConstantBound(lowerBound);
 | 
				
			||||||
  pack->pushConstantBound(upperBound);
 | 
					  pack->pushConstantBound(upperBound);
 | 
				
			||||||
| 
						 | 
					@ -203,17 +202,20 @@ int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBound) {
 | 
				
			||||||
int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
 | 
					int BuildKrnlLoop::pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
 | 
				
			||||||
    int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
 | 
					    int upperBoundMemRefIndex, bool upperBoundMustBeConstant) {
 | 
				
			||||||
  pack->pushConstantBound(lowerBound);
 | 
					  pack->pushConstantBound(lowerBound);
 | 
				
			||||||
  // process upperBound as a dimension of mem ref, possibly non-constant
 | 
					
 | 
				
			||||||
 | 
					  // Process upperBound as a dimension of the MemRef. Non-constant dimensions
 | 
				
			||||||
 | 
					  // are supported.
 | 
				
			||||||
  auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
 | 
					  auto shape = upperBoundMemRefOperand.getType().cast<MemRefType>().getShape();
 | 
				
			||||||
  if (shape[upperBoundMemRefIndex] < 0) {
 | 
					  if (shape[upperBoundMemRefIndex] < 0) {
 | 
				
			||||||
    if (upperBoundMustBeConstant)
 | 
					    if (upperBoundMustBeConstant)
 | 
				
			||||||
      emitError(loc, "bound expected to be constant");
 | 
					      emitError(loc, "Bound expected to be constant.");
 | 
				
			||||||
    pack->pushOperandBound(
 | 
					    pack->pushOperandBound(
 | 
				
			||||||
        rewriter
 | 
					        rewriter
 | 
				
			||||||
            .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
 | 
					            .create<DimOp>(loc, upperBoundMemRefOperand, upperBoundMemRefIndex)
 | 
				
			||||||
            .getResult());
 | 
					            .getResult());
 | 
				
			||||||
  } else
 | 
					  } else
 | 
				
			||||||
    pack->pushConstantBound(shape[upperBoundMemRefIndex]);
 | 
					    pack->pushConstantBound(shape[upperBoundMemRefIndex]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return pushCount++;
 | 
					  return pushCount++;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -223,19 +225,20 @@ int BuildKrnlLoop::pushBounds(Value lowerBound, Value upperBound) {
 | 
				
			||||||
  return pushCount++;
 | 
					  return pushCount++;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// create iter
 | 
					 | 
				
			||||||
void BuildKrnlLoop::createIterateOp() {
 | 
					void BuildKrnlLoop::createIterateOp() {
 | 
				
			||||||
 | 
					  // Loop definition operation is mandatory.
 | 
				
			||||||
  if (!createdDefineOp)
 | 
					  if (!createdDefineOp)
 | 
				
			||||||
    emitError(loc, "must create define op before iterate op");
 | 
					    emitError(loc, "Must create define op before iterate op.");
 | 
				
			||||||
  // Tight now, optimize (possibly empty) is mandatory. This may change
 | 
					
 | 
				
			||||||
 | 
					  // Loop optimization operation is mandatory (for now).
 | 
				
			||||||
  if (!createdOptimizeOp)
 | 
					  if (!createdOptimizeOp)
 | 
				
			||||||
    emitError(loc, "must create optimize op before iterate op");
 | 
					    emitError(loc, "Must create optimize op before iterate op.");
 | 
				
			||||||
  // have to have defined all bounds
 | 
					
 | 
				
			||||||
  if (pushCount != originalLoopNum) {
 | 
					  // Check if all bounds have been defined.
 | 
				
			||||||
    printf(" push count %d, original loop %d\n", pushCount, originalLoopNum);
 | 
					  if (pushCount != originalLoopNum)
 | 
				
			||||||
    emitError(loc, "must push bounds for all original loops");
 | 
					    emitError(loc, "Must push bounds for all original loops.");
 | 
				
			||||||
  }
 | 
					
 | 
				
			||||||
  // create iterate op
 | 
					  // Emit iteration operation.
 | 
				
			||||||
  auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
 | 
					  auto iterateOp = rewriter.create<KrnlIterateOp>(loc, *pack);
 | 
				
			||||||
  iterBlock = &iterateOp.bodyRegion().front();
 | 
					  iterBlock = &iterateOp.bodyRegion().front();
 | 
				
			||||||
  createdIterateOp = true;
 | 
					  createdIterateOp = true;
 | 
				
			||||||
| 
						 | 
					@ -243,19 +246,27 @@ void BuildKrnlLoop::createIterateOp() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
 | 
					void BuildKrnlLoop::createDefineOptimizeAndIterateOp(
 | 
				
			||||||
    Value memRefOperand, bool withEmptyOptimization) {
 | 
					    Value memRefOperand, bool withEmptyOptimization) {
 | 
				
			||||||
 | 
					  // Rank of the MemRef operand. We will emit a loop for each dimension.
 | 
				
			||||||
  int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
 | 
					  int loopNum = memRefOperand.getType().cast<MemRefType>().getShape().size();
 | 
				
			||||||
  if (originalLoopNum != loopNum)
 | 
					  if (originalLoopNum != loopNum)
 | 
				
			||||||
    emitError(loc, "mismatch in loop numbers from constructor and define");
 | 
					    emitError(loc, "Mismatch in loop numbers from constructor and define.");
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Emit the definition and the optimization operations for the loop nest.
 | 
				
			||||||
  createDefineAndOptimizeOp(withEmptyOptimization);
 | 
					  createDefineAndOptimizeOp(withEmptyOptimization);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Push a lower-upper bound pair for each dimension of the MemRef operand.
 | 
				
			||||||
 | 
					  // The lower bound in this case is always zero.
 | 
				
			||||||
  for (int i = 0; i < originalLoopNum; ++i)
 | 
					  for (int i = 0; i < originalLoopNum; ++i)
 | 
				
			||||||
    pushBounds(0, memRefOperand, i);
 | 
					    pushBounds(0, memRefOperand, i);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Emit the iteration operation over the current loop nest.
 | 
				
			||||||
  createIterateOp();
 | 
					  createIterateOp();
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// get induction variable to be use within iter
 | 
					 | 
				
			||||||
BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
 | 
					BlockArgument &BuildKrnlLoop::getInductionVar(int originalLoopIndex) {
 | 
				
			||||||
 | 
					  // Check if loop iteration variable is within bounds.
 | 
				
			||||||
  if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
 | 
					  if (originalLoopIndex < 0 || originalLoopIndex >= originalLoopNum)
 | 
				
			||||||
    emitError(loc, "original loop index is out of bound");
 | 
					    emitError(loc, "Original loop index is out of bounds.");
 | 
				
			||||||
  return iterBlock->getArguments()[originalLoopIndex];
 | 
					  return iterBlock->getArguments()[originalLoopIndex];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -106,19 +106,21 @@ private:
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// The sequence is as follow:
 | 
					// The sequence is as follow:
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//   1) Create a object giving the rewriter, location, and number of loop in the
 | 
					//   1) Create an object giving the rewriter, location, and number of loop in
 | 
				
			||||||
//   original (non optimized) loop.
 | 
					//   the original (non optimized) loop.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//   2) Create define & optimize ops (currently paired). Optimizations can then
 | 
					//   2) Create define & optimize ops (currently paired). Optimizations can then
 | 
				
			||||||
//   be added to the inner block of the optimize operation. Make sure to set the
 | 
					//   be added to the inner block of the optimize operation. Make sure to set
 | 
				
			||||||
//   insertion point to that block for optimizations to go in the right place.
 | 
					//   the insertion point to that block for optimizations to go in the right
 | 
				
			||||||
 | 
					//   place.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//   3) Push the bounds for each of the original loops. Bounds are pushed in
 | 
					//   3) Push the bounds for each of the original loops. Bounds are pushed in
 | 
				
			||||||
//   pairs (lower & upper bounds). THere are a few methods to do it depending on
 | 
					//   pairs (lower & upper bounds). There are a few methods to do it depending
 | 
				
			||||||
//   the type of the bounds. When pushing bounds, the method returns a number
 | 
					//   on the type of the bounds. When pushing bounds, the method returns a
 | 
				
			||||||
//   that represent the index associated with that iteration (induction variable
 | 
					//   number that represent the index associated with that iteration (induction
 | 
				
			||||||
//   and bounds). That index can be used later to extract the induction variable
 | 
					//   variable and bounds). That index can be used later to extract the
 | 
				
			||||||
//   for reference in computation and/or index calculations of mem refs.
 | 
					//   induction variable for reference in computation and/or index calculations
 | 
				
			||||||
 | 
					//   of mem refs.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
//   4) Once all the bounds are pushed, create the iterate operation. Once this
 | 
					//   4) Once all the bounds are pushed, create the iterate operation. Once this
 | 
				
			||||||
//   is done, we can add operations within the iterate blocks by setting the
 | 
					//   is done, we can add operations within the iterate blocks by setting the
 | 
				
			||||||
| 
						 | 
					@ -127,67 +129,90 @@ private:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BuildKrnlLoop {
 | 
					class BuildKrnlLoop {
 | 
				
			||||||
public:
 | 
					public:
 | 
				
			||||||
  // Create a build kernel loop for the given location and loop number.
 | 
					  // Create kernel loop builder for a loop nest of depth loopNum.
 | 
				
			||||||
  BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
 | 
					  BuildKrnlLoop(ConversionPatternRewriter &rewriter, Location loc, int loopNum);
 | 
				
			||||||
  // Do the same, but where the loop number corresponds to the dimensionality of
 | 
					
 | 
				
			||||||
  // the mem ref operand.
 | 
					  // Create kernel loop builder for a loop nest of depth equal to the
 | 
				
			||||||
 | 
					  // dimensionality of the operand. An operand of MemRef type is requied.
 | 
				
			||||||
  BuildKrnlLoop(
 | 
					  BuildKrnlLoop(
 | 
				
			||||||
      ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
 | 
					      ConversionPatternRewriter &rewriter, Location loc, Value memRefOperand);
 | 
				
			||||||
  ~BuildKrnlLoop();
 | 
					  ~BuildKrnlLoop();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Create define and optimize loop with loopNum original loops. If
 | 
					  // Create define and optimize loop with loopNum original loops. If
 | 
				
			||||||
  // withEmptyOptimization, the optimization is simply the identity function (no
 | 
					  // withEmptyOptimization is true, the optimization is simply the identity
 | 
				
			||||||
  // optimizations).
 | 
					  // function (no optimizations).
 | 
				
			||||||
  void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
 | 
					  void createDefineAndOptimizeOp(bool withEmptyOptimization = true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Push bounds (lower and upper) for each of the loops, in order. It returns
 | 
					  // Push bounds (lower and upper) for each of the loops (order matters).
 | 
				
			||||||
  // the index associated with the loop iteration. This index is in the range
 | 
					  // The function returns the order number associated with the loop iteration.
 | 
				
			||||||
  // from zero to original loop number -1, and is monotonally increasing from
 | 
					  // This index is used by the getInductionVar call. Non-constant operands
 | 
				
			||||||
  // call to call. This index is later used in the getInductionVar call.
 | 
					  // must be of MemRef type.
 | 
				
			||||||
  int pushBounds(int64_t lowerBound, int64_t upperBound);
 | 
					  int pushBounds(int64_t lowerBound, int64_t upperBound);
 | 
				
			||||||
  int pushBounds(int64_t lowerBound, Value upperBound);
 | 
					  int pushBounds(int64_t lowerBound, Value upperBound);
 | 
				
			||||||
  int pushBounds(Value lowerBound, Value upperBound);
 | 
					  int pushBounds(Value lowerBound, Value upperBound);
 | 
				
			||||||
  // same, where the lower bound is an integer, and the uppoer bound is given by
 | 
					 | 
				
			||||||
  // the size of the mem ref operand along the upperBoundMemRefIndex dimension.
 | 
					 | 
				
			||||||
  int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
 | 
					  int pushBounds(int64_t lowerBound, Value upperBoundMemRefOperand,
 | 
				
			||||||
      int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
 | 
					      int upperBoundMemRefIndex, bool upperBoundMustBeConstant = false);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Create an iterate op.
 | 
					  // Create the KrnlIterateOp assiciated with this loop nest. The loops
 | 
				
			||||||
 | 
					  // iteration will be created if the definition and the optimization
 | 
				
			||||||
 | 
					  // operations associated with this loop nest have been emitted already.
 | 
				
			||||||
  void createIterateOp();
 | 
					  void createIterateOp();
 | 
				
			||||||
  // Create an define, optimize and iterate op, with the same loop nummber as
 | 
					
 | 
				
			||||||
  // the rank of the memRefOperand. The lower bound of each loops is zero, and
 | 
					  // Create the loop nest definition, optimization and iteration operations
 | 
				
			||||||
  // the upper bound of each loops is the dimension given by the mem refs
 | 
					  // for a given operand of MemRef type. The loop nest has a depth equal to the
 | 
				
			||||||
 | 
					  // rank of the MemRef operand. The lower bound of each loop is zero. The
 | 
				
			||||||
 | 
					  // upper bound of each loop is given by the corresponding dimension of the
 | 
				
			||||||
 | 
					  // MemRef operand.
 | 
				
			||||||
  void createDefineOptimizeAndIterateOp(
 | 
					  void createDefineOptimizeAndIterateOp(
 | 
				
			||||||
      Value memRefOperand, bool withEmptyOptimization = true);
 | 
					      Value memRefOperand, bool withEmptyOptimization = true);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Get the (original loop) induction variable associated with the given index.
 | 
					  // Get the (original loop) induction variable associated with the given
 | 
				
			||||||
  // Use the index returned when pushing the bounds.
 | 
					  // index. Use the index returned when pushing the bounds.
 | 
				
			||||||
  BlockArgument &getInductionVar(int originalLoopIndex);
 | 
					  BlockArgument &getInductionVar(int originalLoopIndex);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Get blocks. This allow us to set the insertion point to the inner block of
 | 
					  // Get a reference to the code region of the optimization operation.
 | 
				
			||||||
  // the optimize and the iterate Operation
 | 
					  // This allows us to set the insertion point to the inner block of the
 | 
				
			||||||
 | 
					  // loop nest optimization operation.
 | 
				
			||||||
  Block *getOptimizationBlock() { return optBlock; }
 | 
					  Block *getOptimizationBlock() { return optBlock; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Get a reference to the code region of the iteration operation.
 | 
				
			||||||
 | 
					  // This allows us to set the insertion point to the inner block of the
 | 
				
			||||||
 | 
					  // loop nest iteration operation.
 | 
				
			||||||
  Block *getIterateBlock() { return iterBlock; }
 | 
					  Block *getIterateBlock() { return iterBlock; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // get original or optimized loops
 | 
					  // Get original loop nest.
 | 
				
			||||||
  std::vector<Value> &getOriginalLoops() { return originalLoops; }
 | 
					  std::vector<Value> &getOriginalLoops() { return originalLoops; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Get optimized loop nest.
 | 
				
			||||||
  std::vector<Value> &getOptimizedLoops() { return optLoops; }
 | 
					  std::vector<Value> &getOptimizedLoops() { return optLoops; }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
private:
 | 
					private:
 | 
				
			||||||
  // inputs
 | 
					  // Required for emitting operations.
 | 
				
			||||||
  ConversionPatternRewriter &rewriter;
 | 
					  ConversionPatternRewriter &rewriter;
 | 
				
			||||||
  Location loc;
 | 
					  Location loc;
 | 
				
			||||||
  int originalLoopNum;
 | 
					  int originalLoopNum;
 | 
				
			||||||
  // track loops and bounds
 | 
					
 | 
				
			||||||
 | 
					  // List of original, un-optimized loops.
 | 
				
			||||||
  std::vector<Value> originalLoops;
 | 
					  std::vector<Value> originalLoops;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // List of optimized loops.
 | 
				
			||||||
  std::vector<Value> optLoops;
 | 
					  std::vector<Value> optLoops;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // List of lower-upper bound pairs needed by the KrnlIterateOp.
 | 
				
			||||||
  KrnlIterateOperandPack *pack;
 | 
					  KrnlIterateOperandPack *pack;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Number of lower-upper bound pairs pushed.
 | 
				
			||||||
  int pushCount;
 | 
					  int pushCount;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Flags that keep track of emitted operations.
 | 
				
			||||||
  bool createdDefineOp;
 | 
					  bool createdDefineOp;
 | 
				
			||||||
  bool createdOptimizeOp;
 | 
					  bool createdOptimizeOp;
 | 
				
			||||||
  bool createdIterateOp;
 | 
					  bool createdIterateOp;
 | 
				
			||||||
  // insertion points (opt block, iterate)
 | 
					
 | 
				
			||||||
 | 
					  // Saved insertion point in the code region of the KrnlOptimizeLoopsOp.
 | 
				
			||||||
  Block *optBlock;
 | 
					  Block *optBlock;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Saved insertion point in the code region of the KrnlIterateOp.
 | 
				
			||||||
  Block *iterBlock;
 | 
					  Block *iterBlock;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -90,25 +90,6 @@ def ONNXEntryPointOp: ONNX_Op<"EntryPoint"> {
 | 
				
			||||||
// or outputs. This decision affects only ONNX operations with optional
 | 
					// or outputs. This decision affects only ONNX operations with optional
 | 
				
			||||||
// arguments not ONNX operations with variadic operands.
 | 
					// arguments not ONNX operations with variadic operands.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
 | 
					 | 
				
			||||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
					 | 
				
			||||||
  let summary = "ONNX general matrix multiply operation without bias.";
 | 
					 | 
				
			||||||
  let description = [{
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    The "onnx.Gemm" generic matrix multiplication without bias.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  }];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$A,
 | 
					 | 
				
			||||||
           AnyTypeOf<[AnyMemRef, AnyTensor]>:$B,
 | 
					 | 
				
			||||||
           DefaultValuedAttr<F32Attr, "1.0">:$alpha,
 | 
					 | 
				
			||||||
           DefaultValuedAttr<F32Attr, "1.0">:$beta,
 | 
					 | 
				
			||||||
           DefaultValuedAttr<I64Attr, "0">:$transA,
 | 
					 | 
				
			||||||
           DefaultValuedAttr<I64Attr, "0">:$transB);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$o_Y);
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
 | 
					def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
 | 
				
			||||||
    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
					    [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
 | 
				
			||||||
  let hasCanonicalizer = 1;
 | 
					  let hasCanonicalizer = 1;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -565,32 +565,6 @@ void ONNXGemmOp::inferShapes() {
 | 
				
			||||||
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 | 
					  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GemmNoBias
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void ONNXGemmNoBiasOp::inferShapes() {
 | 
					 | 
				
			||||||
  // Cannot infer shape if no shape exists.
 | 
					 | 
				
			||||||
  if (!getOperand(0).getType().isa<RankedTensorType>() ||
 | 
					 | 
				
			||||||
      !getOperand(1).getType().isa<RankedTensorType>())
 | 
					 | 
				
			||||||
    return;
 | 
					 | 
				
			||||||
  auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
 | 
					 | 
				
			||||||
  auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  int64_t M, N, K_A, K_B;
 | 
					 | 
				
			||||||
  M = (transA() == 0) ? lhsTy.getShape()[0] : lhsTy.getShape()[1];
 | 
					 | 
				
			||||||
  K_A = (transA() == 0) ? lhsTy.getShape()[1] : lhsTy.getShape()[0];
 | 
					 | 
				
			||||||
  N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
 | 
					 | 
				
			||||||
  K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  if ((K_A != -1) and (K_B != -1) and (K_A != K_B)) {
 | 
					 | 
				
			||||||
    emitError("Tensor shapes mismatched.");
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  SmallVector<int64_t, 2> dims;
 | 
					 | 
				
			||||||
  dims.emplace_back(M);
 | 
					 | 
				
			||||||
  dims.emplace_back(N);
 | 
					 | 
				
			||||||
  getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
/// BatchNormalizationTestMode
 | 
					/// BatchNormalizationTestMode
 | 
				
			||||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
 | 
					void ONNXBatchNormalizationTestModeOp::inferShapes() {
 | 
				
			||||||
  // Cannot infer shape if no shape exists.
 | 
					  // Cannot infer shape if no shape exists.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -118,7 +118,6 @@ public:
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.Identity" &&
 | 
					        op->getName().getStringRef() != "onnx.Identity" &&
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.MatMul" &&
 | 
					        op->getName().getStringRef() != "onnx.MatMul" &&
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.Gemm" &&
 | 
					        op->getName().getStringRef() != "onnx.Gemm" &&
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.GemmNoBias" &&
 | 
					 | 
				
			||||||
        op->getName().getStringRef() != "onnx.Reshape" &&
 | 
					        op->getName().getStringRef() != "onnx.Reshape" &&
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.Transpose" &&
 | 
					        op->getName().getStringRef() != "onnx.Transpose" &&
 | 
				
			||||||
        op->getName().getStringRef() != "onnx.ReduceMax" &&
 | 
					        op->getName().getStringRef() != "onnx.ReduceMax" &&
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -806,35 +806,6 @@ func @test_gemm(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>, %arg2: tenso
 | 
				
			||||||
  // CHECK: }
 | 
					  // CHECK: }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func @test_gemm_no_bias(%arg0 : tensor<5x10xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
 | 
					 | 
				
			||||||
  %0 ="onnx.GemmNoBias"(%arg0, %arg1) {alpha = 1.0 : f32, beta = 5.0 : f32, transA = 1, transB = 0} : (tensor<5x10xf32>, tensor<5x10xf32>) -> tensor<*xf32>
 | 
					 | 
				
			||||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // CHECK-LABEL: test_gemm_no_bias
 | 
					 | 
				
			||||||
  // CHECK: [[RES:%.+]] = alloc() : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[ALPHA:%.+]] = constant 1.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: [[BETA:%.+]] = constant 5.000000e+00 : f32
 | 
					 | 
				
			||||||
  // CHECK: [[DEF_LOOPS:%.+]]:3 = krnl.define_loops 3
 | 
					 | 
				
			||||||
  // CHECK: [[OPT_LOOPS:%.+]]:3 = krnl.optimize_loops  {
 | 
					 | 
				
			||||||
  // CHECK: krnl.return_loops [[DEF_LOOPS]]#0, [[DEF_LOOPS]]#1, [[DEF_LOOPS]]#2
 | 
					 | 
				
			||||||
  // CHECK: } : () -> (!krnl.loop, !krnl.loop, !krnl.loop)
 | 
					 | 
				
			||||||
  // CHECK: krnl.iterate([[OPT_LOOPS]]#0, [[OPT_LOOPS]]#1) with ([[DEF_LOOPS]]#0 -> %arg2 = 0 to 10, [[DEF_LOOPS]]#1 -> %arg3 = 0 to 10) {
 | 
					 | 
				
			||||||
  // CHECK: krnl.iterate([[OPT_LOOPS]]#2) with ([[DEF_LOOPS]]#2 -> %arg4 = 0 to 5) {
 | 
					 | 
				
			||||||
  // CHECK: [[A:%.+]] = load %arg0[%arg4, %arg2] : memref<5x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[B:%.+]] = load %arg1[%arg4, %arg3] : memref<5x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[AB:%.+]] = mulf [[A]], [[B]] : f32
 | 
					 | 
				
			||||||
  // CHECK: [[SUM:%.+]] = addf [[Y]], [[AB]] : f32
 | 
					 | 
				
			||||||
  // CHECK: store [[SUM]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: }
 | 
					 | 
				
			||||||
  // CHECK: [[LOAD_Y:%.+]] = load [[RES]][%arg2, %arg3] : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: [[ALPHA_AB:%.+]] = mulf [[ALPHA]], [[LOAD_Y]] : f32
 | 
					 | 
				
			||||||
  // CHECK: store [[ALPHA_AB]], [[RES]][%arg2, %arg3] : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: }
 | 
					 | 
				
			||||||
  // CHECK: return [[RES]] : memref<10x10xf32>
 | 
					 | 
				
			||||||
  // CHECK: }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
					func @test_sqrt(%arg0 : tensor<?x10xf32>) -> tensor<*xf32> {
 | 
				
			||||||
  %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
 | 
					  %0 = "onnx.Sqrt"(%arg0) : (tensor<?x10xf32>) -> tensor<*xf32>
 | 
				
			||||||
  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1 +1 @@
 | 
				
			||||||
Subproject commit 1439eab5542c625bb3da49860f0cd68c3eafdc18
 | 
					Subproject commit 553df22c67bee5f0fe6599cff60f1afc6748c635
 | 
				
			||||||
		Loading…
	
		Reference in New Issue