Yiqing-Zhou 2023-05-14 22:53:28 +08:00
parent 216bc4643c
commit 6827898339
2 changed files with 16 additions and 1 deletion


@@ -0,0 +1 @@
+from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model

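Note: judging by its relative import, the new one-line module is most likely the package's __init__.py; it re-exports the patched classes so callers import them from the package root instead of reaching into modeling_gpt2 directly. A minimal sketch of a call site, where the package name gpt2_patched is a placeholder, not taken from this commit:

    # "gpt2_patched" is a hypothetical package name standing in for this repo's package
    from gpt2_patched import GPT2LMHeadModel, GPT2Model

    # from_pretrained is inherited from transformers, so stock GPT-2 weights
    # load into the patched classes
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    assert isinstance(model.transformer, GPT2Model)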

@@ -5,11 +5,13 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 import transformers
+from torch import nn
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.gpt2 import (
+from transformers.models.gpt2.modeling_gpt2 import (
     _CHECKPOINT_FOR_DOC,
     _CONFIG_FOR_DOC,
     GPT2_INPUTS_DOCSTRING,
+    logger,
 )
 from transformers.utils import (
     ModelOutput,
@@ -260,6 +262,18 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
 class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = GPT2Model(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
     ):
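Note: the override exists because the stock transformers __init__ builds its own backbone; re-assigning self.transformer to the patched GPT2Model and calling post_init() again makes the subclass use the customized backbone and re-applies weight initialization and embedding/lm_head weight tying. A minimal sketch checking that behavior, assuming the classes from this diff are importable; the small config values are arbitrary, not from the commit:

    from transformers import GPT2Config

    config = GPT2Config(n_layer=2, n_head=2, n_embd=128)  # tiny config for a quick check
    model = GPT2LMHeadModel(config)  # the subclass defined in this diff

    assert isinstance(model.transformer, GPT2Model)  # patched backbone is in place
    # post_init() re-ties lm_head.weight to the input embedding matrix
    # (tie_word_embeddings defaults to True), so both names are the same tensor
    assert model.lm_head.weight is model.transformer.wte.weight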