Support dimension zero in reshape (#55)

Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
Tung D. Le 2020-01-30 00:41:09 +09:00 committed by GitHub
parent f3047943a1
commit 5b44169aaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 4 deletions

View File

@ -1118,6 +1118,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
} else { } else {
auto memRefShape = memRefType.getShape(); auto memRefShape = memRefType.getShape();
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
SmallVector<Value, 4> allocOperands; SmallVector<Value, 4> allocOperands;
for (int i = 0; i < memRefShape.size(); ++i) { for (int i = 0; i < memRefShape.size(); ++i) {
// The shape array can always be used to construct shape information of // The shape array can always be used to construct shape information of
@ -1126,6 +1127,24 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
// Load index from array of indices. // Load index from array of indices.
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
// If a dimension is zero, the actual dimension value is taken from the
// input tensor.
if (i < inputShape.size()) {
Value dimVal;
auto dimTy = loadedVal.getType().cast<IntegerType>();
if (inputShape[i] < 0) {
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy);
} else {
dimVal = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(dimTy, inputShape[i]));
}
auto zero = rewriter.create<ConstantOp>(
loc, rewriter.getIntegerAttr(dimTy, 0));
auto isZero =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
}
// Check if the loaded index is already the correct width of 64 bits. // Check if the loaded index is already the correct width of 64 bits.
// Convert the value to a 64 bit integer if needed. // Convert the value to a 64 bit integer if needed.
Value int64LoadedVal = loadedVal; Value int64LoadedVal = loadedVal;

View File

@ -160,7 +160,15 @@ test_to_enable = [
"test_softsign_example_cpu", "test_softsign_example_cpu",
# ReshapeOp: # ReshapeOp:
"test_reshape_extended_dims_cpu",
#"test_reshape_negative_dim_cpu", <- handle nagative dim
#"test_reshape_negative_extended_dims_cpu", <- handle nagative dim
"test_reshape_one_dim_cpu",
"test_reshape_reduced_dims_cpu",
"test_reshape_reordered_all_dims_cpu", "test_reshape_reordered_all_dims_cpu",
"test_reshape_reordered_last_dims_cpu",
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
"test_reshape_zero_dim_cpu",
] ]
# Extract name of all test cases. # Extract name of all test cases.

View File

@ -305,14 +305,23 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
// CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64 // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
// CHECK: %[[INDEX_0:.+]] = constant 0 : index // CHECK: %[[INDEX_0:.+]] = constant 0 : index
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32> // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
// CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64 // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
// CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32
// CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32
// CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32
// CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32
// CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64 // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
// CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index
// CHECK: %[[INDEX_1:.+]] = constant 1 : index // CHECK: %[[INDEX_1:.+]] = constant 1 : index
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32> // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
// CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64 // CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32
// CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
// CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64 // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
// CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index
// CHECK: %[[INDEX_2:.+]] = constant 2 : index // CHECK: %[[INDEX_2:.+]] = constant 2 : index
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32> // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
// CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64 // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64