Reduce memory usage when building the meaning dataset.
This commit is contained in:
		
							parent
							
								
									c907210fc1
								
							
						
					
					
						commit
						ef08359a94
					
				| 
						 | 
					@ -1,14 +1,11 @@
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import datasets
 | 
					import torch, datasets
 | 
				
			||||||
import torch
 | 
					import math, gc, time, random, copy
 | 
				
			||||||
import math
 | 
					 | 
				
			||||||
import random
 | 
					 | 
				
			||||||
from itertools import chain
 | 
					from itertools import chain
 | 
				
			||||||
from typing import Dict, Tuple
 | 
					from typing import Dict, Tuple
 | 
				
			||||||
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
					from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
from torch.utils.data import BatchSampler
 | 
					from torch.utils.data import BatchSampler
 | 
				
			||||||
import copy
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MeaningMap:
 | 
					class MeaningMap:
 | 
				
			||||||
| 
						 | 
					@ -25,6 +22,7 @@ class MeaningMap:
 | 
				
			||||||
        file += "_" + str(max_subitem) + "_" + str(min_subitem)
 | 
					        file += "_" + str(max_subitem) + "_" + str(min_subitem)
 | 
				
			||||||
        file = path + file + ".npz"
 | 
					        file = path + file + ".npz"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        start_time = time.time()
 | 
				
			||||||
        if not os.path.exists(path):
 | 
					        if not os.path.exists(path):
 | 
				
			||||||
            os.mkdir(path)
 | 
					            os.mkdir(path)
 | 
				
			||||||
        if os.path.exists(file) and use_cache:
 | 
					        if os.path.exists(file) and use_cache:
 | 
				
			||||||
| 
						 | 
					@ -41,7 +39,7 @@ class MeaningMap:
 | 
				
			||||||
            self.ms_rank_all = dlra[:, 3].astype(np.uint32)
 | 
					            self.ms_rank_all = dlra[:, 3].astype(np.uint32)
 | 
				
			||||||
            self.ms_height = slhwm[:, 2]
 | 
					            self.ms_height = slhwm[:, 2]
 | 
				
			||||||
            self.ms_weight = slhwm[:, 3]
 | 
					            self.ms_weight = slhwm[:, 3]
 | 
				
			||||||
            print("Load end")
 | 
					            print("Load end, elapsed:" + str(time.time() - start_time) + "s")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            print("Disk cache miss, build new one.")
 | 
					            print("Disk cache miss, build new one.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -75,10 +73,10 @@ class MeaningMap:
 | 
				
			||||||
            ms_weight = []  # meaning tree weight
 | 
					            ms_weight = []  # meaning tree weight
 | 
				
			||||||
            index = 0
 | 
					            index = 0
 | 
				
			||||||
            for i in range(self.vocab_size):
 | 
					            for i in range(self.vocab_size):
 | 
				
			||||||
                ms_data.append(np.array([i]))
 | 
					                ms_data.append(np.array([i], dtype=np.int32))
 | 
				
			||||||
                ms_level.append(np.array([0]))
 | 
					                ms_level.append(np.array([0], dtype=np.int32))
 | 
				
			||||||
                ms_rank_idx.append(np.array([0]))
 | 
					                ms_rank_idx.append(np.array([0], dtype=np.uint32))
 | 
				
			||||||
                ms_rank_all.append(np.array([0]))
 | 
					                ms_rank_all.append(np.array([0], dtype=np.uint32))
 | 
				
			||||||
                ms_start.append(index)
 | 
					                ms_start.append(index)
 | 
				
			||||||
                ms_len.append(1)
 | 
					                ms_len.append(1)
 | 
				
			||||||
                ms_height.append(0)
 | 
					                ms_height.append(0)
 | 
				
			||||||
| 
						 | 
					@ -100,13 +98,13 @@ class MeaningMap:
 | 
				
			||||||
                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
 | 
					                        ([0xFFFFFFF0 + i] if newm < self.vocab_size else ms_rank_idx[newm] * 16 + i)
 | 
				
			||||||
                        for i, newm in enumerate(m_list)
 | 
					                        for i, newm in enumerate(m_list)
 | 
				
			||||||
                    ]
 | 
					                    ]
 | 
				
			||||||
                )
 | 
					                ).astype(np.uint32)
 | 
				
			||||||
                mrl = np.concatenate(
 | 
					                mrl = np.concatenate(
 | 
				
			||||||
                    [
 | 
					                    [
 | 
				
			||||||
                        ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
 | 
					                        ([0xFFFFFFF0 + m_len] if newm < self.vocab_size else ms_rank_all[newm] * 16 + m_len)
 | 
				
			||||||
                        for i, newm in enumerate(m_list)
 | 
					                        for i, newm in enumerate(m_list)
 | 
				
			||||||
                    ]
 | 
					                    ]
 | 
				
			||||||
                )
 | 
					                ).astype(np.uint32)
 | 
				
			||||||
                ms_data.append(ma)
 | 
					                ms_data.append(ma)
 | 
				
			||||||
                ms_level.append(ml)
 | 
					                ms_level.append(ml)
 | 
				
			||||||
                ms_rank_idx.append(mr)
 | 
					                ms_rank_idx.append(mr)
 | 
				
			||||||
| 
						 | 
					@ -117,6 +115,7 @@ class MeaningMap:
 | 
				
			||||||
                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
 | 
					                ms_weight.append(sum(ms_weight[sub_m] for sub_m in m_list))
 | 
				
			||||||
                index = index + len(ma)
 | 
					                index = index + len(ma)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            print("Mapping end, elapsed:" + str(time.time() - start_time) + "s")
 | 
				
			||||||
            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
 | 
					            ms_data = np.array(list(chain(*ms_data))).astype(np.int32)
 | 
				
			||||||
            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
 | 
					            ms_level = np.array(list(chain(*ms_level))).astype(np.int32)
 | 
				
			||||||
            ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
 | 
					            ms_rank_idx = np.array(list(chain(*ms_rank_idx))).astype(np.uint32)
 | 
				
			||||||
| 
						 | 
					@ -175,7 +174,7 @@ class MeaningMap:
 | 
				
			||||||
            self.ms_len = ms_len
 | 
					            self.ms_len = ms_len
 | 
				
			||||||
            self.ms_height = ms_height
 | 
					            self.ms_height = ms_height
 | 
				
			||||||
            self.ms_weight = ms_weight
 | 
					            self.ms_weight = ms_weight
 | 
				
			||||||
            print("Disk cache build end.")
 | 
					            print("Disk cache build end, elapsed:" + str(time.time() - start_time) + "s")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_sequence(self, meaning):  # return sequence[meaning]
 | 
					    def get_sequence(self, meaning):  # return sequence[meaning]
 | 
				
			||||||
        start = self.ms_start[meaning]
 | 
					        start = self.ms_start[meaning]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue