Support dimension zero in reshape (#55)
Co-authored-by: Gheorghe-Teodor Bercea <gt.bercea@gmail.com>
This commit is contained in:
		
							parent
							
								
									f3047943a1
								
							
						
					
					
						commit
						5b44169aaa
					
				|  | @ -1118,6 +1118,7 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | |||
|       alloc = insertAllocAndDealloc(memRefType, loc, rewriter, insertDealloc); | ||||
|     } else { | ||||
|       auto memRefShape = memRefType.getShape(); | ||||
|       auto inputShape = operands[0].getType().cast<MemRefType>().getShape(); | ||||
|       SmallVector<Value, 4> allocOperands; | ||||
|       for (int i = 0; i < memRefShape.size(); ++i) { | ||||
|         // The shape array can always be used to construct shape information of
 | ||||
|  | @ -1126,6 +1127,24 @@ struct ONNXReshapeOpLowering : public ConversionPattern { | |||
|             loc, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); | ||||
|         // Load index from array of indices.
 | ||||
|         Value loadedVal = rewriter.create<LoadOp>(loc, operands[1], index); | ||||
|         // If a dimension is zero, the actual dimension value is taken from the
 | ||||
|         // input tensor.
 | ||||
|         if (i < inputShape.size()) { | ||||
|           Value dimVal; | ||||
|           auto dimTy = loadedVal.getType().cast<IntegerType>(); | ||||
|           if (inputShape[i] < 0) { | ||||
|             Value dim = rewriter.create<DimOp>(loc, operands[0], i); | ||||
|             dimVal = rewriter.create<IndexCastOp>(loc, dim, dimTy); | ||||
|           } else { | ||||
|             dimVal = rewriter.create<ConstantOp>( | ||||
|                 loc, rewriter.getIntegerAttr(dimTy, inputShape[i])); | ||||
|           } | ||||
|           auto zero = rewriter.create<ConstantOp>( | ||||
|               loc, rewriter.getIntegerAttr(dimTy, 0)); | ||||
|           auto isZero = | ||||
|               rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, loadedVal, zero); | ||||
|           loadedVal = rewriter.create<SelectOp>(loc, isZero, dimVal, loadedVal); | ||||
|         } | ||||
|         // Check if the loaded index is already the correct width of 64 bits.
 | ||||
|         // Convert the value to a 64 bit integer if needed.
 | ||||
|         Value int64LoadedVal = loadedVal; | ||||
|  |  | |||
|  | @ -160,7 +160,15 @@ test_to_enable = [ | |||
|     "test_softsign_example_cpu", | ||||
| 
 | ||||
|     # ReshapeOp: | ||||
|     "test_reshape_extended_dims_cpu", | ||||
|     #"test_reshape_negative_dim_cpu", <- handle nagative dim | ||||
|     #"test_reshape_negative_extended_dims_cpu", <- handle nagative dim | ||||
|     "test_reshape_one_dim_cpu", | ||||
|     "test_reshape_reduced_dims_cpu", | ||||
|     "test_reshape_reordered_all_dims_cpu", | ||||
|     "test_reshape_reordered_last_dims_cpu", | ||||
|     #"test_reshape_zero_and_negative_dim_cpu", <- handle nagative dim | ||||
|     "test_reshape_zero_dim_cpu", | ||||
| ] | ||||
| 
 | ||||
| # Extract name of all test cases. | ||||
|  |  | |||
|  | @ -305,14 +305,23 @@ func @test_reshape(%arg0 : tensor<?x10xf32>, %arg1 : tensor<4xi32>) -> tensor<*x | |||
|   // CHECK: [[TYPE_IN_BYTES:%.+]] = constant 4 : i64 | ||||
|   // CHECK: %[[INDEX_0:.+]] = constant 0 : index | ||||
|   // CHECK: [[LOAD_0:%.+]] = load %arg1[%[[INDEX_0]]] : memref<4xi32> | ||||
|   // CHECK: [[EXT_0:%.+]] = zexti [[LOAD_0]] : i32 to i64 | ||||
|   // CHECK: [[DIM_0:%.+]] = dim %arg0, 0 : memref<?x10xf32> | ||||
|   // CHECK: [[DIM_0_CAST:%.+]] = index_cast [[DIM_0]] : index to i32 | ||||
|   // CHECK: [[CONSTANT_0:%.+]] = constant 0 : i32 | ||||
|   // CHECK: [[CMP:%.+]] = cmpi "eq", [[LOAD_0]], [[CONSTANT_0]] : i32 | ||||
|   // CHECK: [[SELECT_0:%.+]] = select [[CMP]], [[DIM_0_CAST]], [[LOAD_0]] : i32 | ||||
|   // CHECK: [[EXT_0:%.+]] = zexti [[SELECT_0]] : i32 to i64 | ||||
|   // CHECK: [[MUL_0:%.+]] = muli [[TYPE_IN_BYTES]], [[EXT_0]] : i64 | ||||
|   // CHECK: [[CAST_0:%.+]] = index_cast [[LOAD_0]] : i32 to index | ||||
|   // CHECK: [[CAST_0:%.+]] = index_cast [[SELECT_0]] : i32 to index | ||||
|   // CHECK: %[[INDEX_1:.+]] = constant 1 : index | ||||
|   // CHECK: [[LOAD_1:%.+]] = load %arg1[%[[INDEX_1]]] : memref<4xi32> | ||||
|   // CHECK: [[EXT_1:%.+]] = zexti [[LOAD_1]] : i32 to i64 | ||||
|   // CHECK: [[CONSTANT_1:%.+]] = constant 10 : i32 | ||||
|   // CHECK: [[CONSTANT_2:%.+]] = constant 0 : i32 | ||||
|   // CHECK: [[CMP_1:%.+]] = cmpi "eq", [[LOAD_1]], [[CONSTANT_2]] : i32 | ||||
|   // CHECK: [[SELECT_1:%.+]] = select [[CMP_1]], [[CONSTANT_1]], [[LOAD_1]] : i32 | ||||
|   // CHECK: [[EXT_1:%.+]] = zexti [[SELECT_1]] : i32 to i64 | ||||
|   // CHECK: [[MUL_1:%.+]] = muli [[MUL_0]], [[EXT_1]] : i64 | ||||
|   // CHECK: [[CAST_1:%.+]] = index_cast [[LOAD_1]] : i32 to index | ||||
|   // CHECK: [[CAST_1:%.+]] = index_cast [[SELECT_1]] : i32 to index | ||||
|   // CHECK: %[[INDEX_2:.+]] = constant 2 : index | ||||
|   // CHECK: [[LOAD_2:%.+]] = load %arg1[%[[INDEX_2]]] : memref<4xi32> | ||||
|   // CHECK: [[EXT_2:%.+]] = zexti [[LOAD_2]] : i32 to i64 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue