parent 216bc4643c
commit 6827898339
@@ -0,0 +1 @@
+from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
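This one-line file reads like a package __init__.py that re-exports the patched classes; the filename is not shown in this view, so that is an assumption. If it holds, callers can import the overridden models from the package root:

# Hypothetical usage; "patched_gpt2" stands in for the real package
# name, which the diff does not show.
from patched_gpt2 import GPT2LMHeadModel, GPT2Model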
@@ -5,11 +5,13 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 import transformers
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.gpt2 import (
+from transformers.models.gpt2.modeling_gpt2 import (
     _CHECKPOINT_FOR_DOC,
     _CONFIG_FOR_DOC,
+    GPT2_INPUTS_DOCSTRING,
+    logger,
 )
 from transformers.utils import (
     ModelOutput,
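The import now targets the modeling_gpt2 submodule instead of the transformers.models.gpt2 package. That is likely the point of the change: private module-level names such as _CHECKPOINT_FOR_DOC, _CONFIG_FOR_DOC, GPT2_INPUTS_DOCSTRING, and logger are defined in modeling_gpt2.py and are not re-exported by the package __init__, so only the submodule path resolves them. A quick sanity check (a sketch; the exact set of constants varies across transformers releases):

import transformers.models.gpt2 as gpt2_pkg
import transformers.models.gpt2.modeling_gpt2 as gpt2_mod

# Public classes are reachable from both; the private doc constants
# normally live only on the module.
print(hasattr(gpt2_pkg, "GPT2Model"))            # True
print(hasattr(gpt2_mod, "_CHECKPOINT_FOR_DOC"))  # True on matching versions
print(hasattr(gpt2_pkg, "_CHECKPOINT_FOR_DOC"))  # typically False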
@@ -260,6 +262,18 @@ class GPT2Model(transformers.models.gpt2.GPT2Model):
 
 
+class GPT2LMHeadModel(transformers.models.gpt2.GPT2LMHeadModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = GPT2Model(config)
+        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+        # Initialize weights and apply final processing
+        self.post_init()
 
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
     ):
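A minimal smoke test for the new subclass, assuming the custom GPT2Model keeps the standard Hugging Face forward contract (the config values below are illustrative, chosen small for speed):

import torch
from transformers import GPT2Config

config = GPT2Config(vocab_size=1000, n_embd=64, n_head=2, n_layer=2)
model = GPT2LMHeadModel(config)  # the subclass added in this commit
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # expected: torch.Size([1, 8, 1000])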