Get MemRefType for result types (#69)

* Get memreftype for result types

* Revise

* Replace convertToMemRefType

* Use convertToMemRefType in ONNXConvNoBiasOpLowering

* Merge with the master branch

* Reverse an unintentional change

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-20 22:44:02 +09:00 committed by GitHub
parent b28c6906b4
commit a3f042220e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 2410 additions and 33 deletions

View File

@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
return true;
}
/// Convert the given TensorType into the corresponding MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return MemRefType::get(type.getShape(), type.getElementType());
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertToMemRefType(Type type) {
MemRefType memRefType;
auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
}
/// Insert an allocation and deallocation for the given MemRefType.
@ -430,8 +438,8 @@ struct TensorTypeConverter : public TypeConverter {
}
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto tensor_type = t.dyn_cast<TensorType>()) {
results.push_back(convertTensorToMemRef(tensor_type));
if (auto type = convertToMemRefType(t)) {
results.push_back(type);
return success();
}

View File

@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
// If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the
@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
auto numArgs = op->getNumOperands();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);

View File

@ -15,7 +15,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
Value A, B, C;
@ -23,9 +22,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
B = operands[1];
C = operands[2];
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -33,9 +34,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
// Insert an allocation and deallocation for the result of this operation.
Value alloc;
bool insertDealloc = checkInsertDealloc(op);

View File

@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
Value A = operands[0];
@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// - Both arguments are 1-D
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto elementType = memRefType.getElementType();
auto memRefShape = memRefType.getShape();

View File

@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
auto loc = op->getLoc();
auto memRefInType = operands[0].getType().cast<MemRefType>();
auto memRefInShape = memRefInType.getShape();
auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
auto memRefOutType = convertToMemRefType(*op->result_type_begin());
int64_t inRank = memRefInType.getRank();
int64_t outRank = tensorOutType.getRank();
int64_t outRank = memRefOutType.getRank();
// Get attributes
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information
auto memRefOutType = convertTensorToMemRef(tensorOutType);
auto memRefOutShape = memRefOutType.getShape();
auto elementOutType = memRefOutType.getElementType();
std::map<int64_t, int64_t> outInDimMap =

View File

@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in
// exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
int64_t rank = tensorType.getRank();
auto memRefType = convertToMemRefType(*op->result_type_begin());
int64_t rank = memRefType.getRank();
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1);
@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType();
Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);

View File

@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
auto memRefShape = memRefType.getShape();
Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc;
bool insertDealloc = checkInsertDealloc(op);

View File

@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc();
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
int outRank = tensorType.getRank();
auto memRefType = convertToMemRefType(*op->result_type_begin());
int outRank = memRefType.getRank();
// Assume that `axes` has been validated by shape inference.
// So, here we just get it.
@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
}
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value alloc;
// Compute size in bytes.

File diff suppressed because it is too large Load Diff