Get MemRefType for result types (#69)

* Get memreftype for result types

* Revise

* Replace convertToMemRefType

* Use convertToMemRefType in ONNXConvNoBiasOpLowering

* Merge with the master branch

* Reverse an unintentional change

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-02-20 22:44:02 +09:00 committed by GitHub
parent b28c6906b4
commit a3f042220e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 2410 additions and 33 deletions

View File

@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
return true; return true;
} }
/// Convert the given TensorType into the corresponding MemRefType. /// Get the corresponding MemRefType of a given TensorType/MemRefType.
static MemRefType convertTensorToMemRef(TensorType type) { static MemRefType convertToMemRefType(Type type) {
assert(type.hasRank() && "expected only ranked shapes"); MemRefType memRefType;
return MemRefType::get(type.getShape(), type.getElementType()); auto tensorType = type.dyn_cast<TensorType>();
if (tensorType) {
assert(tensorType.hasRank() && "expected only ranked shapes");
memRefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
} else {
memRefType = type.dyn_cast<MemRefType>();
}
return memRefType;
} }
/// Insert an allocation and deallocation for the given MemRefType. /// Insert an allocation and deallocation for the given MemRefType.
@ -430,8 +438,8 @@ struct TensorTypeConverter : public TypeConverter {
} }
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
if (auto tensor_type = t.dyn_cast<TensorType>()) { if (auto type = convertToMemRefType(t)) {
results.push_back(convertTensorToMemRef(tensor_type)); results.push_back(type);
return success(); return success();
} }

View File

@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise unary operation must have all operands and the result of // An element-wise unary operation must have all operands and the result of
// the same type. This should have been verified by the verifier. // the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
// If the output has a dynamic dimension, pass the operands required for // If the output has a dynamic dimension, pass the operands required for
// each dynamic dimension to the AllocOp. The first operand of the // each dynamic dimension to the AllocOp. The first operand of the
@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
// TODO: Check that the types are valid. // TODO: Check that the types are valid.
// An element-wise variadic operation must have all operands and the result // An element-wise variadic operation must have all operands and the result
// of the same type. This should have been verified by the verifier. // of the same type. This should have been verified by the verifier.
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
auto numArgs = op->getNumOperands(); auto numArgs = op->getNumOperands();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);

View File

@ -15,7 +15,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
Value A, B, C; Value A, B, C;
@ -23,9 +22,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
B = operands[1]; B = operands[1];
C = operands[2]; C = operands[2];
auto alphaAttr = FloatAttr::get(tensorType.getElementType(), auto memRefType = convertToMemRefType(*op->result_type_begin());
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat()); llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
auto betaAttr = FloatAttr::get(tensorType.getElementType(), auto betaAttr = FloatAttr::get(memRefType.getElementType(),
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat()); llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr); auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
auto beta = rewriter.create<ConstantOp>(loc, betaAttr); auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
@ -33,9 +34,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0); bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0); bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
// Result type
auto memRefType = convertTensorToMemRef(tensorType);
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);

View File

@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
Value A = operands[0]; Value A = operands[0];
@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// - Both arguments are 1-D // - Both arguments are 1-D
// Result type // Result type
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto elementType = memRefType.getElementType(); auto elementType = memRefType.getElementType();
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();

View File

@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto memRefInType = operands[0].getType().cast<MemRefType>(); auto memRefInType = operands[0].getType().cast<MemRefType>();
auto memRefInShape = memRefInType.getShape(); auto memRefInShape = memRefInType.getShape();
auto tensorOutType = (*op->result_type_begin()).cast<TensorType>(); auto memRefOutType = convertToMemRefType(*op->result_type_begin());
int64_t inRank = memRefInType.getRank(); int64_t inRank = memRefInType.getRank();
int64_t outRank = tensorOutType.getRank(); int64_t outRank = memRefOutType.getRank();
// Get attributes // Get attributes
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr(); ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
bool isKeepdims = (keepdims == 1) ? true : false; bool isKeepdims = (keepdims == 1) ? true : false;
// Get type information // Get type information
auto memRefOutType = convertTensorToMemRef(tensorOutType);
auto memRefOutShape = memRefOutType.getShape(); auto memRefOutShape = memRefOutType.getShape();
auto elementOutType = memRefOutType.getElementType(); auto elementOutType = memRefOutType.getElementType();
std::map<int64_t, int64_t> outInDimMap = std::map<int64_t, int64_t> outInDimMap =

View File

@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
// let exp_x = exp(x - max_x) in // let exp_x = exp(x - max_x) in
// let sum = sum(exp_x) in // let sum = sum(exp_x) in
// exp_x / sum // exp_x / sum
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>(); auto memRefType = convertToMemRefType(*op->result_type_begin());
int64_t rank = tensorType.getRank(); int64_t rank = memRefType.getRank();
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue(); int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
axis = axis >= 0 ? axis : rank + axis; axis = axis >= 0 ? axis : rank + axis;
assert(axis >= -rank && axis <= rank - 1); assert(axis >= -rank && axis <= rank - 1);
@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto elementType = memRefType.getElementType(); auto elementType = memRefType.getElementType();
Value alloc; Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op); ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);

View File

@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape(); auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
Value alloc; Value alloc;

View File

@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
PatternMatchResult PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto loc = op->getLoc(); auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType); auto memRefType = convertToMemRefType(*op->result_type_begin());
Value alloc; Value alloc;
bool insertDealloc = checkInsertDealloc(op); bool insertDealloc = checkInsertDealloc(op);

View File

@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands, matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final { ConversionPatternRewriter &rewriter) const final {
auto loc = op->getLoc(); auto loc = op->getLoc();
auto tensorType = (*op->result_type_begin()).cast<TensorType>(); auto memRefType = convertToMemRefType(*op->result_type_begin());
int outRank = tensorType.getRank(); int outRank = memRefType.getRank();
// Assume that `axes` has been validated by shape inference. // Assume that `axes` has been validated by shape inference.
// So, here we just get it. // So, here we just get it.
@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
} }
// Insert an allocation and deallocation for the result of this operation. // Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
Value alloc; Value alloc;
// Compute size in bytes. // Compute size in bytes.

File diff suppressed because it is too large Load Diff