# File listing metadata: 92 lines, 2.7 KiB, Python
class ModelConfig:
    """Default hyperparameters and generation settings for the model.

    Every setting is a plain instance attribute assigned in ``__init__``, so
    ``obj.__dict__`` lists them in the order written below (this is what
    ``class_to_dict`` walks when it serializes a config).

    NOTE(review): the attribute names look like they follow the Qwen model
    configuration schema -- confirm against the code that consumes them.
    """

    def __init__(self) -> None:
        # --- network dimensions ---
        self.vocab_size = 4096
        self.hidden_size = 1024
        self.num_hidden_layers = 24
        self.num_attention_heads = 16

        # --- regularization / normalization / initialization ---
        self.emb_dropout_prob = 0.0
        self.attn_dropout_prob = 0.0
        self.layer_norm_epsilon = 1e-6
        self.initializer_range = 0.02

        # --- sequence length, attention scaling, caching ---
        self.max_position_embeddings = 8192
        self.scale_attn_weights = True
        self.use_cache = True

        # --- numeric precision flags (all disabled by default) ---
        self.bf16 = False
        self.fp16 = False
        self.fp32 = False

        # --- rotary embedding / attention variants ---
        self.rotary_pct = 1.0
        self.rotary_emb_base = 10000
        self.use_dynamic_ntk = True
        self.use_logn_attn = True
        self.use_flash_attn = "auto"

        # Alternative values noted by the original author: 5504, 11008.
        self.intermediate_size = 5504
        self.no_bias = True
        self.tie_word_embeddings = False
        self.use_cache_quantization = False
        self.use_cache_kernel = False
        self.softmax_in_fp32 = False

        # --- chat / text-generation defaults ---
        self.chat_format = "chatml"
        self.max_window_size = 6144
        self.max_new_tokens = 512
        self.do_sample = True
        self.top_k = 0
        self.top_p = 0.8
        self.repetition_penalty = 1.1
        self.model_max_length = 8192
|
|
|
|
|
|
class MeaningDatasetConfig:
    """Default settings for the "meaning" dataset.

    All values are plain instance attributes assigned in ``__init__``.
    The exact interpretation of each field is defined by the dataset code
    that reads this object, which is not visible in this file.
    """

    def __init__(self) -> None:
        self.level_ratio = 5
        self.level = 5
        self.dataset_level = 3
        self.min_subitem = 2
        # Masking is left unset by default.
        self.mask_level = None
        self.mask_idx = None
|
|
|
|
|
|
class DatasetConfig:
    """Dataset selection plus the per-dataset settings.

    ``name`` holds the selected dataset's name (default ``"meaning"``);
    ``meaning`` holds a fresh ``MeaningDatasetConfig`` with its defaults.
    """

    def __init__(self) -> None:
        self.name = "meaning"
        # A new settings object per instance, so configs never share state.
        self.meaning = MeaningDatasetConfig()
|
|
|
|
|
|
class TrainConfig:
    """Top-level training configuration.

    Holds the run-level settings and owns one ``ModelConfig`` and one
    ``DatasetConfig`` (each created fresh per instance). All values are plain
    instance attributes, so the whole tree can be walked via ``__dict__``.
    """

    def __init__(self) -> None:
        self.name = "bigger"  # name of the current training run
        self.pretrain_model_name = None  # e.g. "qwen/Qwen-1_8B-Chat"
        self.learning_rate = 0.0001
        self.use_tril_attention_mask = None
        # One of "bf16-mixed", "16-mixed", "32-true".
        # NOTE(review): these look like PyTorch Lightning precision strings --
        # confirm against the trainer that consumes this value.
        self.precision = "16-mixed"
        self.train_batch_size = 4
        self.val_batch_size = 4
        self.num_proc = 8
        self.max_epochs = 1000
        self.strategy = "auto"
        self.resume_from_ckpt_path = None
        self.seed = 42
        # Presumably the number of data-loader workers; the attribute name is
        # kept as-is because callers reference it -- verify against usage.
        self.dataloader_works = 2

        # Nested configuration objects.
        self.model_config = ModelConfig()
        self.dataset = DatasetConfig()
|
|
|
|
|
|
def class_to_dict(obj):
    """Recursively convert a config object into plain nested dictionaries.

    Conversion rules, checked in this order:

    * ``int``, ``float``, ``str``, ``bool`` and ``None`` are returned as-is.
    * A ``dict`` keeps its keys; each value is converted recursively.
    * A ``list`` becomes a ``dict`` keyed by the stringified index
      (``"0"``, ``"1"``, ...); each element is converted recursively.
    * Any object exposing ``__dict__`` becomes a ``dict`` of its attributes,
      each converted recursively.
    * Anything else (e.g. a tuple) is returned unchanged.

    Args:
        obj: The value or object to convert.

    Returns:
        The converted value, built only from dicts and the leaf values above.
    """
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    if isinstance(obj, dict):
        return {key: class_to_dict(value) for key, value in obj.items()}
    if isinstance(obj, list):
        # Recurse into each element: without this, objects, dicts, or lists
        # stored inside a list would be left unconverted in the output.
        return {str(index): class_to_dict(value) for index, value in enumerate(obj)}
    if hasattr(obj, "__dict__"):
        return {name: class_to_dict(value) for name, value in obj.__dict__.items()}
    return obj
|
|
|
|
|
|
# Example usage:
#     train_config = TrainConfig()
#     train_config_dict = class_to_dict(train_config)
#     import pprint
#     pprint.pprint(train_config_dict)
|