2020-03-19 16:48:09 +08:00
|
|
|
//====------ ONNXToKrnlCommon.hpp - ONNX dialects to Krnl lowering --------===//
|
2020-02-25 23:38:08 +08:00
|
|
|
//
|
|
|
|
// Copyright 2019 The IBM Research Authors.
|
|
|
|
//
|
|
|
|
// =============================================================================
|
|
|
|
//
|
|
|
|
// This file contains common code shared by the functions performing the
|
|
|
|
// lowering to the KRNL dialect.
|
|
|
|
//
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <map>
|
|
|
|
|
2020-04-02 00:38:34 +08:00
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
2020-04-13 23:40:39 +08:00
|
|
|
#include "mlir/IR/PatternMatch.h"
|
2020-02-25 23:38:08 +08:00
|
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "llvm/ADT/ArrayRef.h"
|
|
|
|
#include "llvm/ADT/Sequence.h"
|
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
#include "src/Dialect/Krnl/KrnlHelper.hpp"
|
|
|
|
#include "src/Dialect/Krnl/KrnlOps.hpp"
|
|
|
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
2020-05-14 17:31:33 +08:00
|
|
|
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
|
2020-03-19 16:48:09 +08:00
|
|
|
#include "src/Pass/Passes.hpp"
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Common functions used when lowering the ONNX frontend dialect to KRNL.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
/// Check is all dimensions are known at compile time.
|
|
|
|
bool hasAllConstantDimensions(MemRefType type);
|
|
|
|
|
2020-05-14 13:00:15 +08:00
|
|
|
/// Check is all operands are scalar values at compile time.
|
|
|
|
bool hasAllScalarValues(ArrayRef<Value> values);
|
|
|
|
|
2020-02-25 23:38:08 +08:00
|
|
|
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
|
|
|
|
MemRefType convertToMemRefType(Type type);
|
|
|
|
|
|
|
|
/// Insert an allocation and deallocation for the given MemRefType.
|
|
|
|
Value insertAllocAndDealloc(MemRefType type, Location loc,
|
2020-04-13 23:40:39 +08:00
|
|
|
PatternRewriter &rewriter, bool insertDealloc,
|
2020-07-02 15:21:01 +08:00
|
|
|
ArrayRef<Value> operands = {}, int64_t alignment = -1);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
// Determine if current function returns the result value of the
|
|
|
|
// current op being lowered. If it does then dealloc should not be
|
|
|
|
// inserted.
|
2020-05-13 21:08:06 +08:00
|
|
|
bool checkInsertDealloc(Operation *currentOp, int resultIndex = 0);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
// Create a mapping from result type's dimensions to input type's dimensions,
|
|
|
|
// given that the result type is the result of a reduction op over the input
|
|
|
|
// type.
|
2020-04-13 23:40:39 +08:00
|
|
|
std::map<int64_t, int64_t> getReductionMapping(
|
|
|
|
MemRefType inputTy, ArrayRef<int64_t> axes, bool keepdims);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
// Add bounds associated with the op operand to the KRNL iteration pack.
|
|
|
|
// Dynamic dimenions are supported.
|
2020-04-13 23:40:39 +08:00
|
|
|
void addDimensionToPack(ConversionPatternRewriter &rewriter, Location loc,
|
|
|
|
KrnlIterateOperandPack &pack, Value operand, int index);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
2020-07-08 12:49:15 +08:00
|
|
|
// Function that emits the define_loop operation to define `numLoops`
|
|
|
|
// number of krnl loops, and fill `loop` with the newly defined loops.
|
|
|
|
void defineLoops(ConversionPatternRewriter &rewriter, Location loc,
|
|
|
|
std::vector<Value> &loops, int64_t numLoops);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
unsigned getMemRefEltSizeInBytes(MemRefType memRefType);
|
|
|
|
|
|
|
|
// Get run-time dimension information for unknown dimensions used for
|
|
|
|
// broadcasting.
|
2020-04-13 23:40:39 +08:00
|
|
|
std::map<int, std::map<int, Value>> getBroadcastedDimInfo(Location loc,
|
|
|
|
ConversionPatternRewriter &rewriter, MemRefType memRefType,
|
|
|
|
ArrayRef<Value> operands);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
// Extract induction variables that are used for broadcasting values of a
|
|
|
|
// given operand.
|
2020-04-13 23:40:39 +08:00
|
|
|
std::vector<Value> getLoopIVsForBroadcasting(Location loc,
|
|
|
|
ConversionPatternRewriter &rewriter, ArrayRef<Value> loopIVs, Value operand,
|
|
|
|
std::map<int, Value> broadcastedDims);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
2020-03-06 03:21:00 +08:00
|
|
|
// Emit a constant of a specific type.
|
|
|
|
// Use this function for small values only to avoid unexpected loss in type
|
|
|
|
// casting.
|
|
|
|
Value emitConstantOp(
|
|
|
|
ConversionPatternRewriter &rewriter, Location loc, Type type, double value);
|
|
|
|
|
|
|
|
// Emit a positive infinity constant of a specific type.
|
|
|
|
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
|
|
|
// In case of Integer, emit the maximum value.
|
|
|
|
Value emitPositiveInfinityConstantOp(
|
|
|
|
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
|
|
|
|
|
|
|
// Emit a negative infinity constant of a specific type.
|
|
|
|
// Supported types: F16, F32, F64, Int8, Int16, Int32, Int64.
|
|
|
|
// In case of Float, emit the negative of the positive infinity.
|
|
|
|
// In case of Integer, emit the minimum value.
|
|
|
|
Value emitNegativeInfinityConstantOp(
|
|
|
|
ConversionPatternRewriter &rewriter, Location loc, Type type);
|
|
|
|
|
2020-04-02 01:51:06 +08:00
|
|
|
int64_t ArrayAttrIntVal(ArrayAttr a, int i);
|
|
|
|
|
2020-02-25 23:38:08 +08:00
|
|
|
//===----------------------------------------------------------------------===//
// Mapping from an operation to the scalar operation(s) that compute it for a
// given element type.
//===----------------------------------------------------------------------===//

// Primary template. By default an operation `Op` has no scalar counterpart,
// so both the floating-point (`FOp`) and integer (`IOp`) mappings are `void`.
// Operations with a scalar lowering are expected to provide a specialization.
template <typename Op>
struct ScalarOp {
  using FOp = void; // scalar op used for floating-point element types
  using IOp = void; // scalar op used for integer element types
};

// Shorthand for the floating-point scalar op mapped to `Op`.
template <typename Op>
using ScalarFOp = typename ScalarOp<Op>::FOp;

// Shorthand for the integer scalar op mapped to `Op`.
template <typename Op>
using ScalarIOp = typename ScalarOp<Op>::IOp;
|
|
|
|
|
2020-03-06 03:21:00 +08:00
|
|
|
// Get the identity element of an operation.
|
2020-02-25 23:38:08 +08:00
|
|
|
// Return NULL if the function does not have identity.
|
2020-03-06 03:21:00 +08:00
|
|
|
// Specialize this for a new Op.
|
|
|
|
template <typename Op>
|
|
|
|
Value getIdentityValue(
|
|
|
|
ConversionPatternRewriter &rewriter, Location loc, Type type) {
|
|
|
|
return nullptr;
|
2020-02-25 23:38:08 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// This is used in the innermost loop of a KrnlIterateOp to insert computation
|
|
|
|
// composed of one or many scalar ops.
|
|
|
|
// Use template specialization for each of different ONNX operations.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename Op>
|
2020-04-09 16:06:56 +08:00
|
|
|
Value emitScalarOpFor(ConversionPatternRewriter &rewriter, Location loc,
|
|
|
|
Operation *op, Type elementType, ArrayRef<Value> scalarOperands) {
|
|
|
|
if (elementType.isa<IntegerType>()) {
|
|
|
|
return rewriter.create<ScalarIOp<Op>>(
|
|
|
|
loc, elementType, scalarOperands, mlir::None);
|
|
|
|
} else if (elementType.isa<FloatType>()) {
|
|
|
|
return rewriter.create<ScalarFOp<Op>>(
|
|
|
|
loc, elementType, scalarOperands, mlir::None);
|
2020-02-25 23:38:08 +08:00
|
|
|
} else {
|
2020-06-02 01:55:19 +08:00
|
|
|
llvm_unreachable("unsupported element type");
|
2020-02-25 23:38:08 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Conversion from Tensor type to the Standard dialect MemRef type.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
|
|
|
struct TensorTypeConverter : public TypeConverter {
|
|
|
|
using TypeConverter::TypeConverter;
|
|
|
|
|
2020-04-13 23:40:39 +08:00
|
|
|
TensorTypeConverter() { addConversion(convertType); }
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
|
|
|
if (auto type = convertToMemRefType(t)) {
|
|
|
|
results.push_back(type);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
results.push_back(t);
|
|
|
|
return success();
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Return true if the inputs and outputs of the given function type are
|
|
|
|
/// legal. [Taken from MLIR and adapted to only check the legality of the
|
|
|
|
/// inputs. Once unranked results can be handled gracefully this
|
|
|
|
/// override needs to be removed in favour of the original MLIR one.]
|
|
|
|
bool isSignatureLegal(FunctionType funcType) {
|
2020-04-13 23:40:39 +08:00
|
|
|
return llvm::all_of(
|
2020-04-27 23:01:51 +08:00
|
|
|
llvm::concat<const Type>(funcType.getInputs(), funcType.getResults()),
|
|
|
|
[this](Type type) { return isLegal(type); });
|
2020-02-25 23:38:08 +08:00
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Functions to add lowering patterns for frontend operations.
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
// `Math` directory methods:
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXElementwiseOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-04-13 23:40:39 +08:00
|
|
|
void populateLoweringONNXGemmOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXMatMulOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
|
|
|
void populateLoweringONNXReductionOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
|
|
|
void populateLoweringONNXSoftmaxOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
// `NN` directory methods:
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXConvOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
|
|
|
void populateLoweringONNXNormalizationOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-03-05 03:27:21 +08:00
|
|
|
void populateLoweringONNXPoolingOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-05-13 21:08:06 +08:00
|
|
|
// `RNN` directory methods:
|
|
|
|
void populateLoweringONNXLSTMOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-03-19 16:48:09 +08:00
|
|
|
// `Tensor` directory methods:
|
2020-02-25 23:38:08 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXUnsqueezeOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
|
|
|
void populateLoweringONNXTransposeOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-03-12 04:54:07 +08:00
|
|
|
void populateLoweringONNXPadConstantValuePadOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-05-15 13:19:28 +08:00
|
|
|
void populateLoweringONNXPadOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-02-25 23:38:08 +08:00
|
|
|
void populateLoweringONNXReshapeOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
|
|
|
void populateLoweringONNXIdentityOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
2020-03-12 22:58:42 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXConstantOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
2020-04-13 23:40:39 +08:00
|
|
|
|
|
|
|
void populateLoweringONNXConcatOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
2020-06-11 10:57:20 +08:00
|
|
|
|
2020-07-03 16:26:41 +08:00
|
|
|
void populateLoweringONNXSqueezeOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
|
|
|
|
2020-06-11 10:57:20 +08:00
|
|
|
void populateLoweringONNXSplitOpPattern(
|
|
|
|
OwningRewritePatternList &patterns, MLIRContext *ctx);
|
2020-07-09 00:35:31 +08:00
|
|
|
|
|
|
|
bool checkOpResultIsUsedByGetRef(AllocOp *allocOp);
|
|
|
|
|
|
|
|
int64_t getMemRefSizeInBytes(Value val);
|