diff --git a/custom_models/gpt2/__init__.py b/custom_models/gpt2/__init__.py
index e69de29..3a0208c 100644
--- a/custom_models/gpt2/__init__.py
+++ b/custom_models/gpt2/__init__.py
@@ -0,0 +1 @@
+from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
diff --git a/custom_models/gpt2/modeling_gpt2.py b/custom_models/gpt2/modeling_gpt2.py
index f3d7d00..98a5d4e 100644
--- a/custom_models/gpt2/modeling_gpt2.py
+++ b/custom_models/gpt2/modeling_gpt2.py
@@ -5,11 +5,13 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 import transformers
+from torch import nn
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.gpt2 import (
+from transformers.models.gpt2.modeling_gpt2 import (
     _CHECKPOINT_FOR_DOC,
     _CONFIG_FOR_DOC,
     GPT2_INPUTS_DOCSTRING,
+    logger,
 )
 from transformers.utils import (
     ModelOutput,
@@ -260,6 +262,18 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
 
 
 class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = GPT2Model(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
    ):
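
Not part of the patch: a minimal usage sketch of the overridden constructor, assuming the `custom_models` package is importable and `transformers` is installed. Since the `__init__` override replaces both `self.transformer` and `self.lm_head` before calling `post_init()`, the sketch checks that weight tying between the LM head and the token embeddings still takes effect.

```python
# Minimal sketch (not in the patch): construct the subclass from a fresh config
# and verify that post_init() re-ties lm_head to the token embedding matrix.
from transformers import GPT2Config

from custom_models.gpt2 import GPT2LMHeadModel

config = GPT2Config()            # default GPT-2 small hyperparameters
model = GPT2LMHeadModel(config)  # uses the custom GPT2Model internally

# With config.tie_word_embeddings=True (the default), post_init() makes
# lm_head.weight and transformer.wte.weight share the same storage.
assert model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr()
```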