Add support of negative dimensions (#66)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
parent 181803ebf4
commit adad9e24bd
@@ -1279,44 +1279,73 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
       ConversionPatternRewriter &rewriter) const final {
     auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+    auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
     auto loc = op->getLoc();

     // Insert an allocation and deallocation for the result of this operation.
     auto memRefType = convertTensorToMemRef(tensorType);
+    auto memRefShape = memRefType.getShape();
     Value alloc;

-    // Compute size in bytes.
+    // Compute size in bytes using the input tensor.
     Value tensorSize = rewriter.create<ConstantOp>(
         loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
                  getMemRefEltSizeInBytes(memRefType)));
+    for (int i = 0; i < inputShape.size(); ++i) {
+      Value dimVal;
+      if (inputShape[i] < 0) {
+        Value dim = rewriter.create<DimOp>(loc, operands[0], i);
+        dimVal =
+            rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
+      } else {
+        dimVal = rewriter.create<ConstantOp>(
+            loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                     inputShape[i]));
+      }
+      tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
+    }
+
     bool insertDealloc = checkInsertDealloc(op);
     if (hasAllConstantDimensions(memRefType)) {
       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
     } else {
-      auto memRefShape = memRefType.getShape();
-      auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
-      SmallVector<Value, 4> allocOperands;
+      // If a dimension is zero, the actual dimension value is taken from the
+      // input tensor.
+      //
+      // If the shape array has a negative dimension (-1), we compute its actual
+      // dimension value from the other dimensions. But we don't have enough
+      // information about the other dimensions at this point. So, we need to
+      // scan the shape first to calculate reduction of all of the dimensions.
+      // If the reduction is negative, then the shape array contains a negative
+      // dimension. Otherwise, the reduction is the same as the one computed
+      // from the input tensor.
+      Value tensorSizeFromShape = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                   getMemRefEltSizeInBytes(memRefType)));
+      SmallVector<Value, 4> DimInfo;
       for (int i = 0; i < memRefShape.size(); ++i) {
-        // The shape array can always be used to construct shape information of
-        // the result.
         Value index = rewriter.create<ConstantOp>(
             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
         // Load index from array of indices.
         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
         // If a dimension is zero, the actual dimension value is taken from the
         // input tensor.
+        //
+        // If a dimension is negative, it is computed from the other dimensions.
+        // But we don't have enough information about the other dimensions at
+        // this point. So, we let it as it is (-1), and compute it later.
         if (i < inputShape.size()) {
           Value dimVal;
-          auto dimTy = loadedVal.getType().cast<IntegerType>();
+          auto loadedValType = loadedVal.getType().cast<IntegerType>();
           if (inputShape[i] < 0) {
             Value dim = rewriter.create<DimOp>(loc, operands[0], i);
-            dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy);
+            dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
           } else {
             dimVal = rewriter.create<ConstantOp>(
-                loc, rewriter.getIntegerAttr(dimTy, inputShape[i]));
+                loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
           }
           auto zero = rewriter.create<ConstantOp>(
-              loc, rewriter.getIntegerAttr(dimTy, 0));
+              loc, rewriter.getIntegerAttr(loadedValType, 0));
           auto isZero =
               rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
           loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
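The first added loop simply materializes eltSize * d0 * ... * d(n-1) for the input, reading any dynamic extent through a DimOp. A minimal standalone sketch of that arithmetic, in plain C++ rather than onnx-mlir's rewriter API (tensorSizeInBytes and runtimeDims are illustrative names, not code from this repository):

#include <cstdint>
#include <vector>

// Mirrors the first added loop: eltSize * d0 * ... * d(n-1) of the *input*,
// where a negative static extent means "dynamic" and its value is read at
// runtime (runtimeDims stands in for the values the emitted DimOps would load).
int64_t tensorSizeInBytes(const std::vector<int64_t> &staticShape,
                          const std::vector<int64_t> &runtimeDims,
                          int64_t eltSizeInBytes) {
  int64_t size = eltSizeInBytes;
  for (size_t i = 0; i < staticShape.size(); ++i)
    size *= staticShape[i] < 0 ? runtimeDims[i] : staticShape[i];
  return size;
}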
@@ -1327,9 +1356,36 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
         if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
           int64LoadedVal = rewriter.create<ZeroExtendIOp>(
               loc, loadedVal, rewriter.getIntegerType(64));
-        tensorSize = rewriter.create<MulIOp>(loc, tensorSize, int64LoadedVal);
+        tensorSizeFromShape =
+            rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
+        // Store intermediate results to use later.
+        DimInfo.emplace_back(int64LoadedVal);
+      }
+      // Reverse tensorSizeFromShape since it is negative if the shape array has
+      // a negative dimension. This is safe since we only use it to compute the
+      // actual value for the negative dimension.
+      auto zero = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+      tensorSizeFromShape =
+          rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
+
+      // Obtain operands for AllocOp.
+      SmallVector<Value, 4> allocOperands;
+      auto negOne = rewriter.create<ConstantOp>(
+          loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
+
+      for (int i = 0; i < memRefShape.size(); ++i) {
+        auto dimVal = DimInfo[i];
+        auto isNegOne =
+            rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
+        // If dimension is negative, compute its value from the other
+        // dimensions.
+        auto actualDimVal =
+            rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
+        auto loadedVal =
+            rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
         allocOperands.push_back(rewriter.create<IndexCastOp>(
             loc, loadedVal, rewriter.getIndexType()));
       }
       AllocOp allocateMemref =
           rewriter.create<AllocOp>(loc, memRefType, allocOperands);
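The trick in this hunk is the sign flip: tensorSizeFromShape multiplies the zero-substituted shape entries, so it is negative exactly when one entry is -1; after the SubIOp it equals eltSize times the product of the known output dimensions, and dividing tensorSize by it recovers the missing extent. A hedged standalone sketch of that arithmetic with a concrete example (plain C++; resolveReshapeDims is an illustrative name, not onnx-mlir code):

#include <cstdint>
#include <cstdio>
#include <vector>

// Resolve a Reshape target shape that may contain 0 ("copy the input dim")
// and a single -1 ("infer from the remaining dims"), the way the lowered
// code does it at runtime.
std::vector<int64_t> resolveReshapeDims(const std::vector<int64_t> &inDims,
                                        std::vector<int64_t> shape,
                                        int64_t eltSize) {
  // tensorSize: eltSize * product of input dims (the first added loop).
  int64_t tensorSize = eltSize;
  for (int64_t d : inDims)
    tensorSize *= d;

  // tensorSizeFromShape: same product over the shape array, with 0 replaced
  // by the matching input dim; it is negative iff the shape contains a -1.
  int64_t tensorSizeFromShape = eltSize;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (i < inDims.size() && shape[i] == 0)
      shape[i] = inDims[i];
    tensorSizeFromShape *= shape[i];
  }
  // The SubIOp: flip the sign, so a present -1 makes this the (positive)
  // eltSize * product of the known output dims.
  tensorSizeFromShape = 0 - tensorSizeFromShape;

  // The SignedDivIOp + SelectOp pair: only a -1 entry takes the quotient.
  for (int64_t &d : shape)
    if (d == -1)
      d = tensorSize / tensorSizeFromShape;
  return shape;
}

int main() {
  // Input 2x10 (80 bytes of f32), target shape [0, 2, -1, 1]:
  // tensorSize = 80, tensorSizeFromShape = -(4*2*2*(-1)*1) = 16, so -1 -> 5.
  for (int64_t d : resolveReshapeDims({2, 10}, {0, 2, -1, 1}, 4))
    std::printf("%lld ", (long long)d); // prints: 2 2 5 1
  std::printf("\n");
  return 0;
}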
@@ -224,13 +224,13 @@ test_to_enable = [

     # ReshapeOp:
     "test_reshape_extended_dims_cpu",
-    #"test_reshape_negative_dim_cpu", <- handle nagative dim
-    #"test_reshape_negative_extended_dims_cpu", <- handle nagative dim
+    "test_reshape_negative_dim_cpu",
+    "test_reshape_negative_extended_dims_cpu",
     "test_reshape_one_dim_cpu",
     "test_reshape_reduced_dims_cpu",
     "test_reshape_reordered_all_dims_cpu",
     "test_reshape_reordered_last_dims_cpu",
-    #"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
+    "test_reshape_zero_and_negative_dim_cpu",
     "test_reshape_zero_dim_cpu",

     # Transpose
@@ -302,38 +302,69 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*xf32> {
   "std.return"(%0) : (tensor<*xf32>) -> ()

   // CHECK-LABEL: test_reshape
-  // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
-  // CHECK: %[[INDEX_0:.+]] = constant 0 : index
-  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
+  // CHECK: [[TYPE_IN_BYTES_0:%.+]] = constant 4 : i64
   // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
-  // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32
-  // CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32
-  // CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32
-  // CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32
-  // CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
-  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
-  // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index
-  // CHECK: %[[INDEX_1:.+]] = constant 1 : index
-  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
-  // CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32
+  // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i64
+  // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES_0]], [[DIM_0_CAST]] : i64
+  // CHECK: [[CONSTANT_0:%.+]] = constant 10 : i64
+  // CHECK: [[TENSOR_SIZE:%.+]] = muli [[MUL_0]], [[CONSTANT_0]] : i64
+
+  // CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
+  // CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
+  // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32>
+  // CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
+  // CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32
   // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
-  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32
-  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32
-  // CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
-  // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
-  // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index
-  // CHECK: %[[INDEX_2:.+]] = constant 2 : index
-  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
-  // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
-  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[EXT_2]] : i64
-  // CHECK: [[CAST_2:%.+]] = index_cast [[LOAD_2]] : i32 to index
-  // CHECK: %[[INDEX_3:.+]] = constant 3 : index
-  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[INDEX_3]]] : memref<4xi32>
-  // CHECK: [[EXT_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
-  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[EXT_3]] : i64
-  // CHECK: [[CAST_3:%.+]] = index_cast [[LOAD_3]] : i32 to index
+  // CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32
+  // CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32
+  // CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
+  // CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
+
+  // CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
+  // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32>
+  // CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32
+  // CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32
+  // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32
+  // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32
+  // CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
+  // CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
+
+  // CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
+  // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32>
+  // CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
+  // CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
+
+  // CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
+  // CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32>
+  // CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
+  // CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
+
+  // CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
+  // CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
+
+  // CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
+  // CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64
+  // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
+
+  // CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64
+  // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
+
+  // CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64
+  // CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
+
+  // CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64
+  // CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
+  // CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64
+  // CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
+
   // CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
-  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
+  // CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[TENSOR_SIZE]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
   // CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
 }
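Read end to end, the new CHECK block encodes the same arithmetic as the sketch above: [[TENSOR_SIZE]] is 4 * dim(%arg0, 0) * 10, the input's size in bytes; [[MUL_4]] is 4 times the product of the four loaded shape entries, with a zero entry replaced by the matching input dimension; and [[SUB_0]] is its negation. Each of [[CAST_0]] through [[CAST_3]] then index_casts a select that keeps the extended shape entry unless it equals -1 : i64, in which case [[TENSOR_SIZE]] divi_signed [[SUB_0]] supplies the inferred extent, and krnl.memcpy now copies [[TENSOR_SIZE]] bytes rather than the old shape-derived [[MUL_3]].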