Support dimension zero in reshape (#55)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
parent
f3047943a1
commit
5b44169aaa
|
@ -1118,6 +1118,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc);
|
||||||
} else {
|
} else {
|
||||||
auto memRefShape = memRefType.getShape();
|
auto memRefShape = memRefType.getShape();
|
||||||
|
auto inputShape = operands[0].getType().cast<MemRefType>().getShape();
|
||||||
SmallVector<Value, 4> allocOperands;
|
SmallVector<Value, 4> allocOperands;
|
||||||
for (int i = 0; i < memRefShape.size(); ++i) {
|
for (int i = 0; i < memRefShape.size(); ++i) {
|
||||||
// The shape array can always be used to construct shape information of
|
// The shape array can always be used to construct shape information of
|
||||||
|
@ -1126,6 +1127,24 @@ struct ONNXReshapeOpLowering : public ConversionPattern {
|
||||||
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||||
// Load index from array of indices.
|
// Load index from array of indices.
|
||||||
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index);
|
||||||
|
// If a dimension is zero, the actual dimension value is taken from the
|
||||||
|
// input tensor.
|
||||||
|
if (i < inputShape.size()) {
|
||||||
|
Value dimVal;
|
||||||
|
auto dimTy = loadedVal.getType().cast<IntegerType>();
|
||||||
|
if (inputShape[i] < 0) {
|
||||||
|
Value dim = rewriter.create<DimOp>(loc, operands[0], i);
|
||||||
|
dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy);
|
||||||
|
} else {
|
||||||
|
dimVal = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(dimTy, inputShape[i]));
|
||||||
|
}
|
||||||
|
auto zero = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getIntegerAttr(dimTy, 0));
|
||||||
|
auto isZero =
|
||||||
|
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero);
|
||||||
|
loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal);
|
||||||
|
}
|
||||||
// Check if the loaded index is already the correct width of 64 bits.
|
// Check if the loaded index is already the correct width of 64 bits.
|
||||||
// Convert the value to a 64 bit integer if needed.
|
// Convert the value to a 64 bit integer if needed.
|
||||||
Value int64LoadedVal = loadedVal;
|
Value int64LoadedVal = loadedVal;
|
||||||
|
|
|
@ -160,7 +160,15 @@ test_to_enable = [
|
||||||
"test_softsign_example_cpu",
|
"test_softsign_example_cpu",
|
||||||
|
|
||||||
# ReshapeOp:
|
# ReshapeOp:
|
||||||
|
"test_reshape_extended_dims_cpu",
|
||||||
|
#"test_reshape_negative_dim_cpu", <- handle nagative dim
|
||||||
|
#"test_reshape_negative_extended_dims_cpu", <- handle nagative dim
|
||||||
|
"test_reshape_one_dim_cpu",
|
||||||
|
"test_reshape_reduced_dims_cpu",
|
||||||
"test_reshape_reordered_all_dims_cpu",
|
"test_reshape_reordered_all_dims_cpu",
|
||||||
|
"test_reshape_reordered_last_dims_cpu",
|
||||||
|
#"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim
|
||||||
|
"test_reshape_zero_dim_cpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Extract name of all test cases.
|
# Extract name of all test cases.
|
||||||
|
|
|
@ -305,14 +305,23 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x
|
||||||
// CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
|
// CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64
|
||||||
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
|
// CHECK: %[[INDEX_0:.+]] = constant 0 : index
|
||||||
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
|
// CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32>
|
||||||
// CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64
|
// CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32>
|
||||||
|
// CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32
|
||||||
|
// CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32
|
||||||
|
// CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32
|
||||||
|
// CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32
|
||||||
|
// CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64
|
||||||
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
|
// CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64
|
||||||
// CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index
|
// CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index
|
||||||
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
|
// CHECK: %[[INDEX_1:.+]] = constant 1 : index
|
||||||
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
|
// CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32>
|
||||||
// CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64
|
// CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32
|
||||||
|
// CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32
|
||||||
|
// CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32
|
||||||
|
// CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32
|
||||||
|
// CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64
|
||||||
// CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
|
// CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64
|
||||||
// CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index
|
// CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index
|
||||||
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
|
// CHECK: %[[INDEX_2:.+]] = constant 2 : index
|
||||||
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
|
// CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32>
|
||||||
// CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
|
// CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64
|
||||||
|
|
Loading…
Reference in New Issue