Add support for negative dimensions (#66)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
Tung D. Le 2020-02-12 00:37:47 +09:00 committed by GitHub
parent 181803ebf4
commit adad9e24bd
3 changed files with 131 additions and 44 deletions

View File

@@ -1279,44 +1279,73 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto tensorType = (*op->result_type_begin()).cast<TensorType>();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
auto memRefType = convertTensorToMemRef(tensorType);
auto memRefShape = memRefType.getShape();
Value alloc;
// Compute size in bytes using the input tensor.
Value tensorSize = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
for (int i = 0; i < inputShape.size(); ++i) {
Value dimVal;
if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
dimVal =
rewriter.create<IndexCastOp>(loc, dim, rewriter.getIntegerType(64));
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
inputShape[i]));
}
tensorSize = rewriter.create<MulIOp>(loc, tensorSize, dimVal);
}
bool insertDealloc = checkInsertDealloc(op);
if (hasAllConstantDimensions(memRefType)) {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else {
// If a dimension is zero, the actual dimension value is taken from the
// input tensor.
//
// If the shape array has a negative dimension (-1), its actual value must be
// computed from the other dimensions. But we don't have enough information
// about the other dimensions at this point. So, we first scan the shape
// array and compute the product of all of its dimensions. If the product is
// negative, the shape array contains a negative dimension. Otherwise, the
// product is the same as the size computed from the input tensor.
Value tensorSizeFromShape = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64),
getMemRefEltSizeInBytes(memRefType)));
SmallVector<Value, 4> DimInfo;
for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of
// the result.
Value index = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
// Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the
// input tensor.
//
// If a dimension is negative, it is computed from the other dimensions.
// But we don't have enough information about the other dimensions at
// this point. So, we leave it as is (-1) and compute it later.
if (i < inputShape.size()) {
Value dimVal;
auto loadedValType = loadedVal.getType().cast<IntegerType>();
if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
dimVal = rewriter.create<IndexCastOp>(loc, dim, loadedValType);
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(loadedValType, inputShape[i]));
}
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(loadedValType, 0));
auto isZero =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
@@ -1327,7 +1356,34 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
if (loadedVal.getType().cast<IntegerType>().getWidth() < 64)
int64LoadedVal = rewriter.create<ZeroExtendIOp>(
loc, loadedVal, rewriter.getIntegerType(64));
tensorSizeFromShape =
rewriter.create<MulIOp>(loc, tensorSizeFromShape, int64LoadedVal);
// Store intermediate results to use later.
DimInfo.emplace_back(int64LoadedVal);
}
// Negate tensorSizeFromShape since it is negative when the shape array has
// a negative dimension. This is safe since we only use it to compute the
// actual value of the negative dimension.
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
tensorSizeFromShape =
rewriter.create<SubIOp>(loc, zero, tensorSizeFromShape);
// Obtain operands for AllocOp.
SmallVector<Value, 4> allocOperands;
auto negOne = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(rewriter.getIntegerType(64), -1));
for (int i = 0; i < memRefShape.size(); ++i) {
auto dimVal = DimInfo[i];
auto isNegOne =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, dimVal, negOne);
// If dimension is negative, compute its value from the other
// dimensions.
auto actualDimVal =
rewriter.create<SignedDivIOp>(loc, tensorSize, tensorSizeFromShape);
auto loadedVal =
rewriter.create<SelectOp>(loc, isNegOne, actualDimVal, dimVal);
allocOperands.push_back(rewriter.create<IndexCastOp>(
loc, loadedVal, rewriter.getIndexType()));
}
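For readers following the arithmetic in this hunk, here is a minimal scalar sketch (not the lowering itself) of what the emitted IR computes at runtime. It assumes 64-bit shape values, at most one -1 entry, and an exact division; the function name and signature are illustrative only.

// Scalar model of the Reshape size computation above; all names are
// illustrative. The real lowering emits MLIR ops instead of C++ loops.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> inferReshapeDims(const std::vector<int64_t> &inputDims,
                                      std::vector<int64_t> shape,
                                      int64_t eltSizeInBytes) {
  // Size in bytes computed from the input tensor (tensorSize above).
  int64_t tensorSize = eltSizeInBytes;
  for (int64_t d : inputDims)
    tensorSize *= d;

  // Size in bytes computed from the shape array (tensorSizeFromShape above).
  // A zero entry copies the corresponding input dimension; a -1 entry is
  // kept as is, which makes the product negative.
  int64_t tensorSizeFromShape = eltSizeInBytes;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0 && i < inputDims.size())
      shape[i] = inputDims[i];
    tensorSizeFromShape *= shape[i];
  }

  // Negate the product so that, when a -1 is present,
  // tensorSize / tensorSizeFromShape is exactly the missing dimension.
  tensorSizeFromShape = -tensorSizeFromShape;
  for (int64_t &d : shape)
    if (d == -1)
      d = tensorSize / tensorSizeFromShape;
  return shape;
}

int main() {
  // A 6x10 float tensor reshaped with shape [0, 2, -1, 5]: the 0 copies the
  // input's first dimension (6) and the -1 resolves to 60 / (6 * 2 * 5) = 1.
  assert((inferReshapeDims({6, 10}, {0, 2, -1, 5}, 4) ==
          std::vector<int64_t>{6, 2, 1, 5}));
  return 0;
}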

View File

@@ -224,13 +224,13 @@ test_to_enable = [
# ReshapeOp:
"test_reshape_extended_dims_cpu",
#"test_reshape_negative_dim_cpu", <- handle nagative dim
#"test_reshape_negative_extended_dims_cpu", <- handle nagative dim
"test_reshape_negative_dim_cpu",
"test_reshape_negative_extended_dims_cpu",
"test_reshape_one_dim_cpu",
"test_reshape_reduced_dims_cpu",
"test_reshape_reordered_all_dims_cpu",
"test_reshape_reordered_last_dims_cpu",
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
"test_reshape_zero_and_negative_dim_cpu",
"test_reshape_zero_dim_cpu",
# Transpose

View File

@@ -302,38 +302,69 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
"std.return"(%0) : (tensor<*xf32>) -> ()
// CHECK-LABEL: test_reshape
// CHECK: [[TYPE_IN_BYTES_0:%.+]] = constant 4 : i64
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i64
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES_0]], [[DIM_0_CAST]] : i64
// CHECK: [[CONSTANT_0:%.+]] = constant 10 : i64
// CHECK: [[TENSOR_SIZE:%.+]] = muli [[MUL_0]], [[CONSTANT_0]] : i64
// CHECK: [[TYPE_IN_BYTES_1:%.+]] = constant 4 : i64
// CHECK: %[[CONSTANT_1:.+]] = constant 0 : index
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[CONSTANT_1]]] : memref<4xi32>
// CHECK: [[DIM_1:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[DIM_1_CAST:%.+]] = index_cast [[DIM_1]] : index to i32
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
// CHECK: [[CMP_0:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_2]] : i32
// CHECK: [[SELECT_0:%.+]] = select [[CMP_0]], [[DIM_1_CAST]], [[LOAD_0]] : i32
// CHECK: [[ZEXTI_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
// CHECK: [[MUL_1:%.+]] = muli [[TYPE_IN_BYTES_1]], [[ZEXTI_0]] : i64
// CHECK: %[[CONSTANT_3:.+]] = constant 1 : index
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[CONSTANT_3]]] : memref<4xi32>
// CHECK: [[CONSTANT_3:%.+]] = constant 10 : i32
// CHECK: [[CONSTANT_4:%.+]] = constant 0 : i32
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_4]] : i32
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_3]], [[LOAD_1]] : i32
// CHECK: [[ZEXTI_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
// CHECK: [[MUL_2:%.+]] = muli [[MUL_1]], [[ZEXTI_1]] : i64
// CHECK: %[[CONSTANT_5:.+]] = constant 2 : index
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[CONSTANT_5]]] : memref<4xi32>
// CHECK: [[ZEXTI_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
// CHECK: [[MUL_3:%.+]] = muli [[MUL_2]], [[ZEXTI_2]] : i64
// CHECK: %[[CONSTANT_6:.+]] = constant 3 : index
// CHECK: [[LOAD_3:%.+]] = load %arg1[%[[CONSTANT_6]]] : memref<4xi32>
// CHECK: [[ZEXTI_3:%.+]] = zexti [[LOAD_3]] : i32 to i64
// CHECK: [[MUL_4:%.+]] = muli [[MUL_3]], [[ZEXTI_3]] : i64
// CHECK: [[CONSTANT_7:%.+]] = constant 0 : i64
// CHECK: [[SUB_0:%.+]] = subi [[CONSTANT_7]], [[MUL_4]] : i64
// CHECK: [[CONSTANT_8:%.+]] = constant -1 : i64
// CHECK: [[CMP_2:%.+]] = cmpi "eq", [[ZEXTI_0]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_0:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_2:%.+]] = select [[CMP_2]], [[DIVISIGNED_0]], [[ZEXTI_0]] : i64
// CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_2]] : i64 to index
// CHECK: [[CMP_3:%.+]] = cmpi "eq", [[ZEXTI_1]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_1:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_3:%.+]] = select [[CMP_3]], [[DIVISIGNED_1]], [[ZEXTI_1]] : i64
// CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_3]] : i64 to index
// CHECK: [[CMP_4:%.+]] = cmpi "eq", [[ZEXTI_2]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_2:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_4:%.+]] = select [[CMP_4]], [[DIVISIGNED_2]], [[ZEXTI_2]] : i64
// CHECK: [[CAST_2:%.+]] = index_cast [[SELECT_4]] : i64 to index
// CHECK: [[CMP_5:%.+]] = cmpi "eq", [[ZEXTI_3]], [[CONSTANT_8]] : i64
// CHECK: [[DIVISIGNED_3:%.+]] = divi_signed [[TENSOR_SIZE]], [[SUB_0]] : i64
// CHECK: [[SELECT_5:%.+]] = select [[CMP_5]], [[DIVISIGNED_3]], [[ZEXTI_3]] : i64
// CHECK: [[CAST_3:%.+]] = index_cast [[SELECT_5]] : i64 to index
// CHECK: [[ALLOC:%.+]] = alloc([[CAST_0]], [[CAST_1]], [[CAST_2]], [[CAST_3]]) : memref<?x?x?x?xf32>
// CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[MUL_3]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
// CHECK: "krnl.memcpy"([[ALLOC]], %arg0, [[TENSOR_SIZE]]) : (memref<?x?x?x?xf32>, memref<?x10xf32>, i64) -> ()
// CHECK: return [[ALLOC]] : memref<?x?x?x?xf32>
}
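A worked illustration of the checked computation (conceptually, and ignoring the zero-extension of the loaded i32 shape values): if %arg0 has runtime shape 6x10 and %arg1 holds [6, 2, -1, 5], then [[TENSOR_SIZE]] = 4 * 6 * 10 = 240 bytes, [[MUL_4]] = 4 * 6 * 2 * (-1) * 5 = -240, [[SUB_0]] = 0 - (-240) = 240, and the divi_signed for the -1 entry gives 240 / 240 = 1, so the alloc receives the dimensions 6, 2, 1, and 5.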