Get MemRefType for result types (#69)
* Get memreftype for result types * Revise * Replace convertToMemRefType * Use convertToMemRefType in ONNXConvNoBiasOpLowering * Merge with the master branch * Reverse an unintentional change Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
b28c6906b4
commit
a3f042220e
|
@ -37,10 +37,18 @@ static bool hasAllConstantDimensions(MemRefType type) {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Convert the given TensorType into the corresponding MemRefType.
|
||||
static MemRefType convertTensorToMemRef(TensorType type) {
|
||||
assert(type.hasRank() && "expected only ranked shapes");
|
||||
return MemRefType::get(type.getShape(), type.getElementType());
|
||||
/// Get the corresponding MemRefType of a given TensorType/MemRefType.
|
||||
static MemRefType convertToMemRefType(Type type) {
|
||||
MemRefType memRefType;
|
||||
auto tensorType = type.dyn_cast<TensorType>();
|
||||
if (tensorType) {
|
||||
assert(tensorType.hasRank() && "expected only ranked shapes");
|
||||
memRefType =
|
||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||
} else {
|
||||
memRefType = type.dyn_cast<MemRefType>();
|
||||
}
|
||||
return memRefType;
|
||||
}
|
||||
|
||||
/// Insert an allocation and deallocation for the given MemRefType.
|
||||
|
@ -430,8 +438,8 @@ struct TensorTypeConverter : public TypeConverter {
|
|||
}
|
||||
|
||||
static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) {
|
||||
if (auto tensor_type = t.dyn_cast<TensorType>()) {
|
||||
results.push_back(convertTensorToMemRef(tensor_type));
|
||||
if (auto type = convertToMemRefType(t)) {
|
||||
results.push_back(type);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -476,11 +476,10 @@ struct ONNXElementwiseUnaryOpLowering : public ConversionPattern {
|
|||
// TODO: Check that the types are valid.
|
||||
// An element-wise unary operation must have all operands and the result of
|
||||
// the same type. This should have been verified by the verifier.
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
// If the output has a dynamic dimension, pass the operands required for
|
||||
// each dynamic dimension to the AllocOp. The first operand of the
|
||||
|
@ -545,12 +544,11 @@ struct ONNXElementwiseVariadicOpLowering : public ConversionPattern {
|
|||
// TODO: Check that the types are valid.
|
||||
// An element-wise variadic operation must have all operands and the result
|
||||
// of the same type. This should have been verified by the verifier.
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
auto numArgs = op->getNumOperands();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
|
|
@ -15,7 +15,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
Value A, B, C;
|
||||
|
@ -23,9 +22,11 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
B = operands[1];
|
||||
C = operands[2];
|
||||
|
||||
auto alphaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
|
||||
auto alphaAttr = FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).alpha().convertToFloat());
|
||||
auto betaAttr = FloatAttr::get(tensorType.getElementType(),
|
||||
auto betaAttr = FloatAttr::get(memRefType.getElementType(),
|
||||
llvm::dyn_cast<ONNXGemmOp>(op).beta().convertToFloat());
|
||||
auto alpha = rewriter.create<ConstantOp>(loc, alphaAttr);
|
||||
auto beta = rewriter.create<ConstantOp>(loc, betaAttr);
|
||||
|
@ -33,9 +34,6 @@ struct ONNXGemmOpLowering : public ConversionPattern {
|
|||
bool isTransA = (llvm::dyn_cast<ONNXGemmOp>(op).transA() != 0);
|
||||
bool isTransB = (llvm::dyn_cast<ONNXGemmOp>(op).transB() != 0);
|
||||
|
||||
// Result type
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
|
|
@ -15,7 +15,6 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
Value A = operands[0];
|
||||
|
@ -29,7 +28,7 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
// - Both arguments are 1-D
|
||||
|
||||
// Result type
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto elementType = memRefType.getElementType();
|
||||
auto memRefShape = memRefType.getShape();
|
||||
|
||||
|
|
|
@ -145,9 +145,9 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
|||
auto loc = op->getLoc();
|
||||
auto memRefInType = operands[0].getType().cast<MemRefType>();
|
||||
auto memRefInShape = memRefInType.getShape();
|
||||
auto tensorOutType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto memRefOutType = convertToMemRefType(*op->result_type_begin());
|
||||
int64_t inRank = memRefInType.getRank();
|
||||
int64_t outRank = tensorOutType.getRank();
|
||||
int64_t outRank = memRefOutType.getRank();
|
||||
|
||||
// Get attributes
|
||||
ArrayAttr axisAttrs = llvm::dyn_cast<ONNXReductionOp>(op).axesAttr();
|
||||
|
@ -171,7 +171,6 @@ struct ONNXReductionOpLowering : public ConversionPattern {
|
|||
bool isKeepdims = (keepdims == 1) ? true : false;
|
||||
|
||||
// Get type information
|
||||
auto memRefOutType = convertTensorToMemRef(tensorOutType);
|
||||
auto memRefOutShape = memRefOutType.getShape();
|
||||
auto elementOutType = memRefOutType.getElementType();
|
||||
std::map<int64_t, int64_t> outInDimMap =
|
||||
|
|
|
@ -18,8 +18,8 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
// let exp_x = exp(x - max_x) in
|
||||
// let sum = sum(exp_x) in
|
||||
// exp_x / sum
|
||||
auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
|
||||
int64_t rank = tensorType.getRank();
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
int64_t rank = memRefType.getRank();
|
||||
int64_t axis = llvm::dyn_cast<ONNXSoftmaxOp>(op).axis().getSExtValue();
|
||||
axis = axis >= 0 ? axis : rank + axis;
|
||||
assert(axis >= -rank && axis <= rank - 1);
|
||||
|
@ -27,7 +27,6 @@ struct ONNXSoftmaxOpLowering : public ConversionPattern {
|
|||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto elementType = memRefType.getElementType();
|
||||
|
||||
Value alloc;
|
||||
|
|
|
@ -15,10 +15,9 @@ struct ONNXConvNoBiasOpLowering : public ConversionPattern {
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
ONNXConvNoBiasOp convOp = llvm::dyn_cast<ONNXConvNoBiasOp>(op);
|
||||
|
|
|
@ -15,12 +15,11 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
||||
auto loc = op->getLoc();
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
auto memRefShape = memRefType.getShape();
|
||||
Value alloc;
|
||||
|
||||
|
|
|
@ -15,10 +15,9 @@ struct ONNXTransposeOpLowering : public ConversionPattern {
|
|||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
auto loc = op->getLoc();
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
Value alloc;
|
||||
bool insertDealloc = checkInsertDealloc(op);
|
||||
|
||||
|
|
|
@ -16,8 +16,8 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto loc = op->getLoc();
|
||||
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
|
||||
int outRank = tensorType.getRank();
|
||||
auto memRefType = convertToMemRefType(*op->result_type_begin());
|
||||
int outRank = memRefType.getRank();
|
||||
|
||||
// Assume that `axes` has been validated by shape inference.
|
||||
// So, here we just get it.
|
||||
|
@ -30,7 +30,6 @@ struct ONNXUnsqueezeOpLowering : public ConversionPattern {
|
|||
}
|
||||
|
||||
// Insert an allocation and deallocation for the result of this operation.
|
||||
auto memRefType = convertTensorToMemRef(tensorType);
|
||||
Value alloc;
|
||||
|
||||
// Compute size in bytes.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue