parent 216bc4643c
commit 6827898339

@@ -0,0 +1 @@
+from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
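The one-line file added above (its path is not shown in this excerpt, but the relative import suggests a package `__init__.py`) re-exports the patched classes. Under that assumption, downstream code would pick up the patched model from the package root:

# Hypothetical usage; `patched_gpt2` stands in for the real package name,
# which this excerpt does not show.
from patched_gpt2 import GPT2LMHeadModel, GPT2Model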
@@ -5,11 +5,13 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 import transformers
+from torch import nn
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.gpt2 import (
+from transformers.models.gpt2.modeling_gpt2 import (
     _CHECKPOINT_FOR_DOC,
     _CONFIG_FOR_DOC,
     GPT2_INPUTS_DOCSTRING,
+    logger,
 )
 from transformers.utils import (
     ModelOutput,
@@ -260,6 +262,18 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):


 class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = GPT2Model(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
     ):
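For reference, here is a minimal, self-contained sketch of the subclassing pattern the second hunk applies, written against stock transformers (so `GPT2Model` resolves to the upstream class rather than this repository's patched one). The class name `PatchedGPT2LMHeadModel`, the tiny config, and the smoke test are illustrative assumptions, not part of the commit:

import torch
import transformers
from torch import nn


class PatchedGPT2LMHeadModel(transformers.GPT2LMHeadModel):
    """Same shape as the diff's subclass, minus the repository's patched GPT2Model."""

    def __init__(self, config):
        super().__init__(config)
        # Rebuild the submodules exactly as the diff does, so the subclass owns
        # its own transformer stack and an LM head sized to the config.
        self.transformer = transformers.GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model-parallel bookkeeping, mirroring upstream GPT2LMHeadModel.__init__
        self.model_parallel = False
        self.device_map = None

        # Re-run weight initialization and weight tying over the fresh submodules
        self.post_init()


# Smoke test on a deliberately tiny config (hypothetical, for illustration only)
config = transformers.GPT2Config(n_layer=2, n_head=2, n_embd=64)
model = PatchedGPT2LMHeadModel(config)
logits = model(input_ids=torch.randint(0, config.vocab_size, (1, 8))).logits
print(logits.shape)  # torch.Size([1, 8, 50257])

Calling `self.post_init()` last matters: replacing `self.transformer` and `self.lm_head` after `super().__init__(config)` discards the parent's initialized modules, so the final `post_init()` is what re-initializes the new weights and re-applies input/output embedding tying.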