# Witllm/wit/configuration.py
class ModelConfig:
def __init__(self):
self.vocab_size = 4096
self.hidden_size = 1024
self.num_hidden_layers = 24
self.num_attention_heads = 16
self.emb_dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.layer_norm_epsilon = 1e-6
self.initializer_range = 0.02
self.max_position_embeddings = 8192
self.scale_attn_weights = True
self.use_cache = True
self.bf16 = False
self.fp16 = False
self.fp32 = False
self.rotary_pct = 1.0
self.rotary_emb_base = 10000
self.use_dynamic_ntk = True
self.use_logn_attn = True
self.use_flash_attn = "auto"
self.intermediate_size = 5504 # 5504 11008
self.no_bias = True
self.tie_word_embeddings = False
self.use_cache_quantization = False
self.use_cache_kernel = False
self.softmax_in_fp32 = False
self.chat_format = "chatml"
self.max_window_size = 6144
self.max_new_tokens = 512
self.do_sample = True
self.top_k = 0
self.top_p = 0.8
self.repetition_penalty = 1.1
self.model_max_length = 8192
class MeaningDatasetConfig:
def __init__(self):
self.level_ratio = 5
self.level = 5
self.dataset_level = 3
self.min_subitem = 2
self.mask_level = None
self.mask_idx = None
class DatasetConfig:
def __init__(self):
self.name = "meaning"
self.meaning = MeaningDatasetConfig()
class TrainConfig:
def __init__(self):
self.name = "bigger" # current train process name
self.pretrain_model_name = None # "qwen/Qwen-1_8B-Chat"
self.learning_rate = 0.0001
self.use_tril_attention_mask = None
self.precision = "16-mixed" # "precision:bf16-mixed,16-mixed,32-true"
self.train_batch_size = 4
self.val_batch_size = 4
self.num_proc = 8
self.max_epochs = 1000
self.strategy = "auto"
self.resume_from_ckpt_path = None
self.seed = 42
self.dataloader_works = 2
self.model_config = ModelConfig()
self.dataset = DatasetConfig()
def class_to_dict(obj):
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
elif isinstance(obj, dict):
return {k: class_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return {str(index): value for index, value in enumerate(obj)}
elif hasattr(obj, "__dict__"):
return {k: class_to_dict(v) for k, v in obj.__dict__.items()}
else:
return obj
# train_config = TrainConfig()
# train_config_dict = class_to_dict(train_config)
# import pprint
# pprint.pprint(train_config_dict)