Get MemRefType for result types (#69)
* Get memreftype for result types
* Revise
* Replace convertToMemRefType
* Use convertToMemRefType in ONNXConvNoBiasOpLowering
* Merge with the master branch
* Reverse an unintentional change

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent b28c6906b4
commit a3f042220e
@@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
   return true;
 }
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
-  return MemRefType::get(type.getShape(), type.getElementType());
+/// Get the corresponding MemRefType of a given TensorType/MemRefType.
+static MemRefType convertToMemRefType(Type type) {
+  MemRefType memRefType;
+  auto tensorType = type.dyn_cast<TensorType>();
+  if (tensorType) {
+    assert(tensorType.hasRank() && "expected only ranked shapes");
+    memRefType =
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+  } else {
+    memRefType = type.dyn_cast<MemRefType>();
+  }
+  return memRefType;
 }
 
 /// Insert an allocation and deallocation for the given MemRefType.
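For orientation, a minimal sketch of what the new helper is expected to return for the two kinds of result type it handles. The standalone context and the concrete shapes are hypothetical and not part of the commit; convertToMemRefType is file-static in the real code, so this only illustrates its intended behavior.

  // Hedged sketch with made-up types; not part of the diff above.
  MLIRContext ctx;
  auto f32 = FloatType::getF32(&ctx);
  auto tensorTy = RankedTensorType::get({2, 3}, f32); // ranked tensor result
  auto memrefTy = MemRefType::get({2, 3}, f32);       // already-lowered result
  assert(convertToMemRefType(tensorTy) == memrefTy);  // tensor -> equivalent memref
  assert(convertToMemRefType(memrefTy) == memrefTy);  // memref passes through unchanged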
@@ -430,8 +438,8 @@ struct TensorTypeConverter : public TypeConverter {
   }
 
   static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
-    if (auto tensor_type = t.dyn_cast<TensorType>()) {
-      results.push_back(convertTensorToMemRef(tensor_type));
+    if (auto type = convertToMemRefType(t)) {
+      results.push_back(type);
       return success();
     }
 
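The truthiness test in convertType above works because the helper returns a default (null) MemRefType when its argument is neither a tensor nor a memref, and a null mlir::Type evaluates to false. A small sketch of that idiom, using a hypothetical scalar type:

  // Sketch only: why `if (auto type = convertToMemRefType(t))` can serve as a
  // "convertible?" test. The i32 type here is an arbitrary example.
  MLIRContext ctx;
  auto i32 = IntegerType::get(32, &ctx);           // neither TensorType nor MemRefType
  MemRefType converted = convertToMemRefType(i32); // both dyn_casts fail -> null result
  assert(!converted && "no memref equivalent for a scalar type");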
@@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
     // TODO: Check that the types are valid.
     // An element-wise unary operation must have all operands and the result of
     // the same type. This should have been verified by the verifier.
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
 
     // If the output has a dynamic dimension, pass the operands required for
     // each dynamic dimension to the AllocOp. The first operand of the
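The same edit repeats in each lowering pattern below: the explicit cast of the result type to TensorType disappears, and the result type is handed straight to the shared helper. Condensed, the recurring change is:

  // Before: every lowering insisted on a ranked tensor result type.
  //   auto tensorType = (*op->result_type_begin()).cast<TensorType>();
  //   auto memRefType = convertTensorToMemRef(tensorType);
  // After: a result that is already a MemRefType (presumably one converted
  // earlier in the pipeline) is accepted as well.
  auto memRefType = convertToMemRefType(*op->result_type_begin());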
@@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
     // TODO: Check that the types are valid.
     // An element-wise variadic operation must have all operands and the result
     // of the same type. This should have been verified by the verifier.
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
     auto numArgs = op->getNumOperands();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
 
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
@@ -15,7 +15,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     Value A, B, C;
@@ -23,9 +22,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     B = operands[1];
     C = operands[2];
 
-    auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+
+    auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
-    auto betaAttr = FloatAttr::get(tensorType.getElementType(),
+    auto betaAttr = FloatAttr::get(memRefType.getElementType(),
         llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
     auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
     auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
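In the Gemm lowering the scalar alpha and beta constants now take their float type from the converted memref's element type, so the ConstantOps created above match the element type of the result buffer. A sketch with assumed literal values standing in for the op's actual attributes:

  // Sketch: element type read from the lowered result type; 1.0f is a
  // placeholder for the ONNXGemmOp's real alpha/beta attributes.
  auto memRefType = convertToMemRefType(*op->result_type_begin());
  auto alphaAttr = FloatAttr::get(memRefType.getElementType(), 1.0f);
  auto betaAttr = FloatAttr::get(memRefType.getElementType(), 1.0f);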
@@ -33,9 +34,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
     bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
     bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
 
-    // Result type
-    auto memRefType = convertTensorToMemRef(tensorType);
-
     // Insert an allocation and deallocation for the result of this operation.
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
@@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
 
     Value A = operands[0];
@@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
     // - Both arguments are 1-D
 
     // Result type
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto elementType = memRefType.getElementType();
     auto memRefShape = memRefType.getShape();
 
@@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
     auto loc = op->getLoc();
     auto memRefInType = operands[0].getType().cast<MemRefType>();
     auto memRefInShape = memRefInType.getShape();
-    auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
+    auto memRefOutType = convertToMemRefType(*op->result_type_begin());
     int64_t inRank = memRefInType.getRank();
-    int64_t outRank = tensorOutType.getRank();
+    int64_t outRank = memRefOutType.getRank();
 
     // Get attributes
     ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
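In the reduction lowering both the input and the output rank now come from MemRefTypes. A condensed restatement of the type bookkeeping in this hunk; the shapes mentioned in the comments are illustrative only:

  auto memRefInType = operands[0].getType().cast<MemRefType>();
  auto memRefOutType = convertToMemRefType(*op->result_type_begin());
  int64_t inRank = memRefInType.getRank();   // e.g. 3 for a 2x3x4 input
  int64_t outRank = memRefOutType.getRank(); // e.g. 2 when one axis is reduced with keepdims == 0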
@@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
     bool isKeepdims = (keepdims == 1) ? true : false;
 
     // Get type information
-    auto memRefOutType = convertTensorToMemRef(tensorOutType);
     auto memRefOutShape = memRefOutType.getShape();
     auto elementOutType = memRefOutType.getElementType();
     std::map<int64_t, int64_t> outInDimMap =
@@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
     // let exp_x = exp(x - max_x) in
     // let sum = sum(exp_x) in
     // exp_x / sum
-    auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
-    int64_t rank = tensorType.getRank();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    int64_t rank = memRefType.getRank();
     int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
     axis = axis >= 0 ? axis : rank + axis;
     assert(axis >= -rank && axis <= rank - 1);
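The Softmax lowering now reads the rank from the converted memref before normalizing a possibly negative ONNX axis. A tiny worked example of that normalization; the rank and axis values are hypothetical:

  int64_t rank = 3;                       // e.g. a 2x3x4 result
  int64_t axis = -1;                      // ONNX allows negative axes
  axis = axis >= 0 ? axis : rank + axis;  // -1 + 3 == 2, i.e. the last dimension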
@@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
     auto elementType = memRefType.getElementType();
 
     Value alloc;
@@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
     ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
@@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
     auto loc = op->getLoc();
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     auto memRefShape = memRefType.getShape();
     Value alloc;
 
@@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
     auto loc = op->getLoc();
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
     Value alloc;
     bool insertDealloc = checkInsertDealloc(op);
 
@@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    auto tensorType = (*op->result_type_begin()).cast<TensorType>();
-    int outRank = tensorType.getRank();
+    auto memRefType = convertToMemRefType(*op->result_type_begin());
+    int outRank = memRefType.getRank();
 
     // Assume that `axes` has been validated by shape inference.
     // So, here we just get it.
@@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
     }
 
     // Insert an allocation and deallocation for the result of this operation.
-    auto memRefType = convertTensorToMemRef(tensorType);
     Value alloc;
 
     // Compute size in bytes.
File diff suppressed because it is too large.