If you are wondering where this site's data comes from, please visit https://api.github.com/users/stefan-it/events. GitMemory does not store any data itself; it only uses NGINX to cache data for a period of time. The idea behind GitMemory is simply to give users a better reading experience.
Stefan Schweter (stefan-it), Munich, Germany, https://schweter.ml
Developer at @dbmdz, M.Sc. Computational Linguistics, Researcher, former student at The Center for Information and Language Processing (CIS), LMU Munich

huggingface/transformers 46052

🤗Transformers: State-of-the-art Natural Language Processing for PyTorch and TensorFlow 2.0.

flairNLP/flair 10348

A very simple framework for state-of-the-art Natural Language Processing (NLP)

dbmdz/berts 87

DBMDZ BERT, DistilBERT, ELECTRA, GPT-2 and ConvBERT models

stefan-it/capsnet-nlp 66

CapsNet for NLP

dbmdz/deep-eos 56

General-Purpose Neural Networks for Sentence Boundary Detection

flairNLP/flair-lms 47

Language Models for Zalando's flair library

stefan-it/nmt-en-vi 47

Neural Machine Translation system for English to Vietnamese (IWSLT'15 English-Vietnamese data)

stefan-it/fine-tuned-berts-seq 36

Fine-tuned, Transformers-compatible BERT models for Sequence Tagging

stefan-it/flair-experiments 33

Experiments with Zalando's flair library

stefan-it/german-gpt2 24

German GPT-2 model
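
Several of the repositories listed above (dbmdz/berts, stefan-it/german-gpt2, stefan-it/fine-tuned-berts-seq) describe models that are published on the Hugging Face Hub. As a minimal sketch of how such a model is typically loaded with the transformers library (the model id dbmdz/bert-base-german-cased is an assumption chosen for illustration, not taken from the listing above):

# Minimal sketch: loading a dbmdz model from the Hugging Face Hub with transformers.
# The model id "dbmdz/bert-base-german-cased" is assumed here for illustration.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
model = AutoModel.from_pretrained("dbmdz/bert-base-german-cased")

# German input sentence for a German BERT model.
inputs = tokenizer("München ist die Hauptstadt von Bayern.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)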

Pull request review comment on huggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+from copy import deepcopy+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutput,+    BaseModelOutputWithPooling,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)+        self.position_embeddings = 
nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True+            )+            self.visual_position_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.position_embeddings.weight.data), requires_grad=True+            )++        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    def forward(+        self,+        input_ids=None,+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+    ):+        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        if position_ids is None:+            position_ids = self.position_ids[:, :seq_length]++        # TO-CHECK: FROM ORIGINAL CODE+        # if input_ids is not None:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+        #     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        # else:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+        #     position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.input_embeds.device)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.ones(+                    visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device+                )++            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a 
way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number.+                # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)++        # TO-CHECK: Config doesn't have this, is this needed? 
Is it in PreTrainedConfig?+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        key_layer = self.transpose_for_scores(self.key(hidden_states))+        value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = 
context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        
self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+        )+        attention_output = self_attention_outputs[0]++        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None++        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    all_hidden_states,+                    all_self_attentions,+                ]+                if v is not None+            )+        return BaseModelOutput(+            last_hidden_state=hidden_states,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+        )+++class VisualBertPooler(nn.Module):+    def 
__init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+            # 
Slightly different from the TF version which uses truncated_normal for initialization+            # cf https://github.com/pytorch/pytorch/pull/5617+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)++        elif isinstance(module, nn.LayerNorm):+            module.bias.data.zero_()+            module.weight.data.fill_(1.0)+        if isinstance(module, nn.Linear) and module.bias is not None:+            module.bias.data.zero_()+++@dataclass+class VisualBertForPreTrainingOutput(ModelOutput):+    """+    Output type of :class:`~transformers.VisualBertForPreTraining`.++    Args:+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction+            (classification) loss.+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation+            before SoftMax).+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.++            Hidden-states of the model at the output of each layer plus the initial embedding outputs.+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,+            sequence_length, sequence_length)`.++            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention+            heads.+    """++    loss: Optional[torch.FloatTensor] = None+    prediction_logits: torch.FloatTensor = None+    seq_relationship_logits: torch.FloatTensor = None+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None+    attentions: Optional[Tuple[torch.FloatTensor]] = None+++VISUAL_BERT_START_DOCSTRING = r"""+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,+    pruning heads etc.)++    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to+    general usage and behavior.++    Parameters:+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.+            Initializing with a config file does not load the weights associated with the model, only the+            configuration. 
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model+            weights.+"""++VISUAL_BERT_INPUTS_DOCSTRING = r"""+    Args:+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):+            Indices of input sequence tokens in the vocabulary.++            Indices can be obtained using :class:`~transformers.BertTokenizer`. See+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for+            details.++            `What are input IDs? <../glossary.html#input-ids>`__+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? <../glossary.html#token-type-ids>`_+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,+            config.max_position_embeddings - 1]``.++            `What are position IDs? <../glossary.html#position-ids>`_+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:++            - 1 indicates the head is **not masked**,+            - 0 indicates the head is **masked**.++        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated+            vectors than the model's internal embedding lookup matrix.++        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):+            The embedded representation of the visual inputs, generally derived using using an object detector.++        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Segment token indices to indicate different portions of the visual embeds.++            `What are token type IDs? 
<../glossary.html#token-type-ids>`_ The authors of VisualBERT set the+            `visual_token_type_ids` to `1` for all tokens.++        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):+            Image-Text alignment uses to decide the position IDs of the visual embeddings.++        output_attentions (:obj:`bool`, `optional`):+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned+            tensors for more detail.+        output_hidden_states (:obj:`bool`, `optional`):+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for+            more detail.+        return_dict (:obj:`bool`, `optional`):+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.+"""++VISUAL_BERT_VQA_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor([[0.0,1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_NLVR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""+++VISUAL_BERT_VQA_ADVANCED_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> 
import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_REGION_TO_PHRASE_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)+        >>> region_to_phrase_position = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2]))++        >>> inputs.update({{+            "region_to_phrase_position": region_to_phrase_position,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2], visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_PRE_TRAINING_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])["input_ids"]+        >>> 
sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size+++        >>> outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels)+        >>> loss = outputs.loss+        >>> prediction_logits = outputs.prediction_logits+        >>> seq_relationship_logits = outputs.seq_relationship_logits+"""++VISUAL_BERT_MODEL_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> outputs = model(**inputs)++        >>> last_hidden_states = outputs.last_hidden_state+"""++VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""+    Example::++        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."+        >>> choice0 = "It is eaten with a fork and a knife."+        >>> choice1 = "It is eaten while held in the hand."++        >>> visual_embeds = get_visual_embeddings(image)+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1++        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)+        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, visual_embeds=visual_embeds, visual_attention_mask=visual_attention_mask, visual_token_type_ids=visual_token_type_ids, labels=labels)  # batch size is 1++        >>> loss = outputs.loss+        >>> logits = outputs.logits+"""+++@add_start_docstrings(+    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertModel(VisualBertPreTrainedModel):+    """++    The model can behave as an encoder (with only self-attention) following the architecture described in `Attention is+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,+    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.+    """++    def __init__(self, config, add_pooling_layer=True):+        super().__init__(config)+        self.config = config++        self.embeddings = VisualBertEmbeddings(config)+        self.encoder = VisualBertEncoder(config)++        self.pooler = (+            VisualBertPooler(config) if add_pooling_layer else None+        )  # TO-DO: Check if pooler is needed necessarily or optionally.++        self.bypass_transformer = config.bypass_transformer++        if self.bypass_transformer:+            self.additional_layer = VisualBertLayer(config)++        self.init_weights()++    def get_input_embeddings(self):+        return self.embeddings.word_embeddings++    def set_input_embeddings(self, value):+        self.embeddings.word_embeddings = value++    # TO-CHECK+    def _prune_heads(self, heads_to_prune):+        """+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base+        class PreTrainedModel+        """+        for layer, heads in heads_to_prune.items():+            self.encoder.layer[layer].attention.prune_heads(heads)++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        output_type=BaseModelOutputWithPooling,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_MODEL_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+    ):++        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions+        output_hidden_states = (+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states+        )+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        if input_ids is not None and inputs_embeds is not None:+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")+        elif input_ids is not None:+            input_shape = input_ids.size()+            batch_size, seq_length = input_shape+        elif inputs_embeds is not None:+            input_shape = inputs_embeds.size()[:-1]+            batch_size, seq_length = input_shape+        else:+            raise ValueError("You have to specify either input_ids or inputs_embeds")++        assert (+            visual_embeds is not None+        ), f"`visual_embeds` can not be of type {type(visual_embeds)} when using a VisualBert Model."++        device = input_ids.device if input_ids is not None else inputs_embeds.device++        visual_input_shape = visual_embeds.size()[:-1]++        if attention_mask is None:+            attention_mask = torch.ones(input_shape, device=device)++        if visual_attention_mask is None:+            visual_attention_mask = torch.ones(visual_input_shape, 
device=device)++        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]+        # ourselves in which case we just need to make it broadcastable to all heads.++        combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+            combined_attention_mask, [batch_size, input_shape + visual_input_shape], device+        )++        # If a 2D or 3D attention mask is provided for the cross-attention+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]+        # if self.config.is_decoder and encoder_hidden_states is not None:+        #     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()+        #     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)+        #     if encoder_attention_mask is None:+        #         encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)+        #     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)+        # else:+        #     encoder_extended_attention_mask = None++        # Prepare head mask if needed+        # 1.0 in head_mask indicate we keep the head+        # attention_probs has shape bsz x n_heads x N x N+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)++        embedding_output = self.embeddings(+            input_ids=input_ids,+            position_ids=position_ids,+            token_type_ids=token_type_ids,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+        )++        if self.bypass_transformer and visual_embeds is not None:+            text_length = input_ids.size(1)+            text_embedding_output = embedding_output[:, :text_length, :]+            visual_embedding_output = embedding_output[:, text_length:, :]++            text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]++            encoded_outputs = self.encoder(+                text_embedding_output,+                attention_mask=text_extended_attention_mask,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoded_outputs[0]+            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)+            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        else:+            encoder_outputs = self.encoder(+                embedding_output,+                attention_mask=extended_attention_mask,+                head_mask=head_mask,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoder_outputs[0]++            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        if not return_dict:+            return 
(sequence_output, pooled_output) + encoder_outputs[1:]++        return BaseModelOutputWithPooling(+            last_hidden_state=sequence_output,+            pooler_output=pooled_output,+            hidden_states=encoder_outputs.hidden_states,+            attentions=encoder_outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a+    `sentence-image prediction (classification)` head.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForPreTraining(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.cls = VisualBertPreTrainingHeads(config)++        self.init_weights()++    def get_output_embeddings(self):+        return self.cls.predictions.decoder++    def set_output_embeddings(self, new_embeddings):+        self.cls.predictions.decoder = new_embeddings++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        mask="[MASK]",+        output_type=VisualBertForPreTrainingOutput,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_PRE_TRAINING_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+        sentence_image_labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``+        sentence_image_labels (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):+            Labels for computing the sentence-image prediction (classification) loss. 
Input should be a sequence pair+            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:++            - 0 indicates sequence B is a matching pair of sequence A for the given image,+            - 1 indicates sequence B is a random sequence w.r.t A for the given image.++        Returns:+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output, pooled_output = outputs[:2]+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)++        total_loss = None+        if labels is not None and sentence_image_labels is not None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."++            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))+            total_loss = masked_lm_loss + sentence_image_loss++        if labels is not None and sentence_image_labels is None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."+            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            total_loss = masked_lm_loss++        if not return_dict:+            output = (prediction_scores, seq_relationship_score) + outputs[2:]+            return ((total_loss,) + output) if total_loss is not None else output++        return VisualBertForPreTrainingOutput(+            loss=total_loss,+            prediction_logits=prediction_scores,+            seq_relationship_logits=seq_relationship_score,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++class VisualBertClassificationHead(nn.Module):+    """Head for sentence-level classification tasks."""++    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)++        self.config = config++    def forward(self, features, **kwargs):+        x = features[:, 0, :]  # take <s> token (equiv. 
to [CLS])+        x = self.dropout(x)+        x = self.dense(x)+        x = ACT2FN[self.config.hidden_act](x)+        x = self.dropout(x)+        x = self.out_proj(x)+        return x+++@add_start_docstrings(+    """+    VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and+    a softmax) e.g. for VCR tasks.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForMultipleChoice(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        # TO-CHECK+        # self.sequence_summary = SequenceSummary(config)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.cls = nn.Linear(config.hidden_size, 1)++        self.init_weights()++    @add_start_docstrings_to_model_forward(+        VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")+    )+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vcr",+        output_type=MultipleChoiceModelOutput,+        config_class="gchhablani/visualbert-vcr",+        code_sample=VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):+            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,+            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. 
(See+            :obj:`input_ids` above)+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict+        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]++        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None+        inputs_embeds = (+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))+            if inputs_embeds is not None+            else None+        )++        visual_embeds = (+            visual_embeds.view(-1, visual_embeds.size(-2), visual_embeds.size(-1))+            if visual_embeds is not None+            else None+        )+        visual_attention_mask = (+            visual_attention_mask.view(-1, visual_attention_mask.size(-1))+            if visual_attention_mask is not None+            else None+        )+        visual_token_type_ids = (+            visual_token_type_ids.view(-1, visual_token_type_ids.size(-1))+            if visual_token_type_ids is not None+            else None+        )++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        _, pooled_output = outputs[0], outputs[1]++        # pooled_output = self.sequence_summary(sequence_output)+        pooled_output = self.dropout(pooled_output)+        logits = self.cls(pooled_output)+        reshaped_logits = logits.view(-1, num_choices)++        loss = None+        if labels is not None:+            loss_fct = CrossEntropyLoss()+            loss = loss_fct(reshaped_logits, labels)++        if not return_dict:+            output = (reshaped_logits,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return MultipleChoiceModelOutput(+            loss=loss,+            logits=reshaped_logits,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled+    output) for VQA.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForQuestionAnswering(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)+        self.num_labels = config.num_labels++        self.visual_bert = VisualBertModel(config)+        # TO-CHECK: Can this be done with a `SequenceSummary` layer?+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.cls = nn.Linear(config.hidden_size, config.num_labels)++        self.init_weights()++    
@add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa",+        output_type=SequenceClassifierOutput,+        config_class="gchhablani/visualbert-vqa",+        code_sample=VISUAL_BERT_VQA_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, total_sequence_length)`, `optional`):+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,+            config.num_labels - 1]`. A KLDLoss is computed between the labels and the returned logits.++        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        # TO-CHECK+        # Get the index of the last text token+        index_to_gather = attention_mask.sum(1) - 2  # as in original code++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output = outputs[0]++        # TO-CHECK: From the original code+        pooled_output = torch.gather(+            sequence_output,+            1,+            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, sequence_output.size(-1)),+        )++        pooled_output = self.dropout(pooled_output)+        logits = self.cls(pooled_output)+        reshaped_logits = logits.view(-1, self.num_labels)++        loss = None+        if labels is not None:+            loss_fct = torch.nn.KLDivLoss(reduction="batchmean")+            log_softmax = torch.nn.LogSoftmax(dim=-1)+            reshaped_logits = log_softmax(reshaped_logits)+            loss = loss_fct(reshaped_logits, labels.contiguous())+        if not return_dict:+            output = (reshaped_logits,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return SequenceClassifierOutput(+            loss=loss,+            logits=reshaped_logits,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with a MLM head on top for VQA tasks.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForQuestionAnsweringAdvanced(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.cls = 
VisualBertPreTrainingHeads(config)++        self.init_weights()++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-pre",+        output_type=MaskedLMOutput,+        config_class="gchhablani/visualbert-vqa-pre",+        code_sample=VISUAL_BERT_VQA_ADVANCED_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``++        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output = outputs[0]+        pooled_output = outputs[1]++        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)+        loss = None++        if labels is not None:+            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            loss = masked_lm_loss++        if not return_dict:+            output = (prediction_scores,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return MaskedLMOutput(+            loss=loss,+            logits=prediction_scores,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with a sequence classification head on top (a dropout and a linear layer on top of the pooled+    output) for Visual Reasoning e.g. 
for NLVR task.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForVisualReasoning(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)+        self.num_labels = config.num_labels++        self.visual_bert = VisualBertModel(config)++        # TO-CHECK: Can this be done with a `SequenceSummary` layer?+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.cls = nn.Linear(config.hidden_size, config.num_labels)  # 2++        self.init_weights()++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-nlvr2",+        output_type=SequenceClassifierOutput,+        config_class="gchhablani/visualbert-nlvr2",+        code_sample=VISUAL_BERT_NLVR_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,+            config.num_labels - 1]`. A classification loss is computed (Cross-Entropy) against these labels.+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        # sequence_output = outputs[0]+        pooled_output = outputs[1]+        pooled_output = self.dropout(pooled_output)+        logits = self.cls(pooled_output)+        reshaped_logits = logits.contiguous()++        loss = None+        if labels is not None:+            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(reshaped_logits, labels.view(-1))+            loss = masked_lm_loss++        if not return_dict:+            output = (logits,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return SequenceClassifierOutput(+            loss=loss,+            logits=reshaped_logits,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++class RegionToPhraseAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0:+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, 
config.num_attention_heads)+            )+        self.num_attention_heads = 1  # config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(self, query, key, attention_mask):+        attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)+        attention_mask = (1.0 - attention_mask) * -10000.0++        mixed_query_layer = self.query(query)+        mixed_key_layer = self.key(key)+        # We don't need value layers+        # mixed_value_layer = self.value(hidden_states)++        query_layer = self.transpose_for_scores(mixed_query_layer)+        key_layer = self.transpose_for_scores(mixed_key_layer)+        # value_layer = self.transpose_for_scores(mixed_value_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)++        attention_scores = attention_scores + attention_mask++        attention_scores = attention_scores.squeeze(1)+        return attention_scores+++@add_start_docstrings(+    """+    VisualBert Model with a Masked Language Modeling head and an attention layer on top for Region-to-Phrase Alignment+    e.g. 
for Flickr30 Entities task.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.cls = VisualBertPreTrainingHeads(config)+        self.attention = RegionToPhraseAttention(config)++        self.init_weights()++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        output_type=SequenceClassifierOutput,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_REGION_TO_PHRASE_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        region_to_phrase_position=None,+        labels=None,+    ):+        r"""+        region_to_phrase_position (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            The positions depicting the position of the image embedding corresponding to the textual tokens.++        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length, visual_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. 
KLDLoss is computed against these labels and the+            outputs from the attention layer.++        """+        assert (+            region_to_phrase_position is not None+        ), "`region_to_phrase_position` should not be None when using Flickr Model."++        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output = outputs[0]++        region_to_phrase_position_mask = (region_to_phrase_position != -1).long()++        # Make the -1 become 0+        region_to_phrase_position = region_to_phrase_position * region_to_phrase_position_mask++        # Selected_positions = batch x selected position x dim+        expanded_region_to_phrase_positions = region_to_phrase_position.unsqueeze(2).expand(+            region_to_phrase_position.size(0), region_to_phrase_position.size(1), sequence_output.size(2)+        )+        selected_positions = sequence_output.gather(1, expanded_region_to_phrase_positions)++        # Visual Features = batch x visual_feature_length x dim+        visual_features = sequence_output[+            :, attention_mask.size(1) :, :+        ]  # This will need separate image and visual masks.+        assert visual_features.size(1) == visual_attention_mask.size(1)

Fixed

gchhablani

comment created 5 minutes ago
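
The reviewed VisualBertForQuestionAnswering head quoted above pools the hidden state of the last text token (index_to_gather = attention_mask.sum(1) - 2) and scores the logits against soft answer targets with a KL-divergence loss. The snippet below is a minimal, self-contained sketch of just that pooling-and-loss step, using random tensors in place of real model outputs; the shapes and the small linear classifier are illustrative stand-ins, not values from the PR's config.

import torch
import torch.nn as nn

# Dummy stand-ins for what the forward pass would produce: a
# [batch, text_len + visual_len, hidden] sequence output and a text attention mask.
batch_size, text_len, visual_len, hidden, num_labels = 2, 6, 4, 8, 5
sequence_output = torch.randn(batch_size, text_len + visual_len, hidden)
attention_mask = torch.ones(batch_size, text_len, dtype=torch.long)

# As in the quoted forward: gather the hidden state of the last text token.
index_to_gather = attention_mask.sum(1) - 2
pooled = torch.gather(
    sequence_output,
    1,
    index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(batch_size, 1, hidden),
)

# Hypothetical classifier standing in for `self.cls` (the real head also applies dropout).
classifier = nn.Linear(hidden, num_labels)
logits = classifier(pooled).view(-1, num_labels)

# Soft VQA targets (a distribution over answers), scored with KLDivLoss against
# log-probabilities, matching the loss in the quoted code.
labels = torch.softmax(torch.randn(batch_size, num_labels), dim=-1)
loss = nn.KLDivLoss(reduction="batchmean")(torch.log_softmax(logits, dim=-1), labels)
print(loss.item())

Gathering at attention_mask.sum(1) - 2 appears to select the text token just before the final special token, which matches the "as in original code" note in the quoted diff.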

Pull request review comment
huggingface/transformers

Add VisualBERT

+class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)+        self.position_embeddings = 
nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True+            )+            self.visual_position_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.position_embeddings.weight.data), requires_grad=True+            )++        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    def forward(+        self,+        input_ids=None,+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+    ):+        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        if position_ids is None:+            position_ids = self.position_ids[:, :seq_length]++        # TO-CHECK: FROM ORIGINAL CODE+        # if input_ids is not None:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+        #     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        # else:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+        #     position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.input_embeds.device)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.ones(+                    visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device+                )++            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a 
way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number.+                # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)++        # TO-CHECK: Config doesn't have this, is this needed? 
Is it in PreTrainedConfig?+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        key_layer = self.transpose_for_scores(self.key(hidden_states))+        value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = 
context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        
self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+        )+        attention_output = self_attention_outputs[0]++        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None++        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    output_attentions,+                )

Fixed

gchhablani

comment created 5 minutes ago
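
The VisualBertEmbeddings module quoted in this comment projects detector features to the hidden size, adds visual token-type and position embeddings, and concatenates the result with the text embeddings along the sequence dimension. Below is a minimal sketch of that concatenation; all dimensions and embedding-table sizes here are placeholders rather than the real values from VisualBertConfig (e.g. visual_embedding_dim).

import torch
import torch.nn as nn

# Placeholder sizes; the real ones come from the config.
batch, text_len, visual_len, hidden, visual_dim = 2, 5, 3, 16, 32

text_embeds = torch.randn(batch, text_len, hidden)          # word + position + token_type, already summed
visual_feats = torch.randn(batch, visual_len, visual_dim)   # e.g. region features from a detector

# As in VisualBertEmbeddings: project visual features to the hidden size,
# add visual token-type / position embeddings, then concatenate with the text part.
visual_projection = nn.Linear(visual_dim, hidden)
visual_token_type = nn.Embedding(2, hidden)
visual_position = nn.Embedding(512, hidden)

visual_token_type_ids = torch.ones(batch, visual_len, dtype=torch.long)
visual_position_ids = torch.zeros(batch, visual_len, dtype=torch.long)

visual_embeds = (
    visual_projection(visual_feats)
    + visual_token_type(visual_token_type_ids)
    + visual_position(visual_position_ids)
)

embeddings = torch.cat((text_embeds, visual_embeds), dim=1)  # [batch, text_len + visual_len, hidden]
print(embeddings.shape)

Because text and visual positions share one sequence after this concatenation, the attention mask downstream must cover both parts, which is why the pre-training loss earlier in the file checks the labels against attention_mask.size(-1) + visual_attention_mask.size(-1).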

Pull request review commenthuggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+from copy import deepcopy+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutput,+    BaseModelOutputWithPooling,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)+        self.position_embeddings = 
nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True+            )+            self.visual_position_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.position_embeddings.weight.data), requires_grad=True+            )++        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    def forward(+        self,+        input_ids=None,+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+    ):+        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        if position_ids is None:+            position_ids = self.position_ids[:, :seq_length]++        # TO-CHECK: FROM ORIGINAL CODE+        # if input_ids is not None:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+        #     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        # else:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+        #     position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.input_embeds.device)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.ones(+                    visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device+                )++            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a 
way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number.+                # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)

Fixed

gchhablani

comment created time in 5 minutes

Pull request review commenthuggingface/transformers

Add VisualBERT

[Review context: flattened diff hunk from the PyTorch VisualBERT model file. It covers the license header and imports, the list of proposed gchhablani/visualbert-* checkpoints, a mish activation helper, a local add_code_sample_docstrings decorator, and the VisualBertEmbeddings module: text word/position/token-type embeddings plus visual token-type and position embeddings (optionally copied from the text embeddings when config.special_visual_initialize is set), a visual_projection from visual_embedding_dim to hidden_size, and a forward() that, when image_text_alignment is provided, builds visual position embeddings by averaging the position embeddings of the aligned words, masking out the -1 padding and guarding against division by zero.]
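
For readers following this thread, here is a minimal runnable sketch of the alignment-averaging step the hunk implements. Shapes follow the comments in the diff (image_text_alignment is batch x image_length x alignment_number, padded with -1); the function and variable names are illustrative, not the merged API.

import torch
import torch.nn as nn

def align_visual_positions(image_text_alignment, position_embeddings):
    # Average the text position embeddings of the words aligned to each visual feature.
    # Sketch of the logic in the hunk above, not the merged implementation.
    mask = (image_text_alignment != -1).long()
    alignment = mask * image_text_alignment                  # turn the -1 padding into index 0

    # (batch, image_length, alignment_number, hidden), zeroed where padded
    emb = position_embeddings(alignment) * mask.unsqueeze(-1).float()
    emb = emb.sum(dim=2)

    counts = mask.float().sum(dim=2)
    counts[counts == 0] = 1.0                                # avoid division by zero
    return emb / counts.unsqueeze(-1)

# usage: two visual features, the first aligned to words 0 and 2, the second to word 5
pos_emb = nn.Embedding(512, 768)
alignment = torch.tensor([[[0, 2, -1], [5, -1, -1]]])
visual_position_embeddings = align_visual_positions(alignment, pos_emb)  # shape (1, 2, 768)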

Fixed

gchhablani

comment created time in 5 minutes

Pull request review comment huggingface/transformers

Add VisualBERT


Done

gchhablani

comment created time in 5 minutes

issue opened jeniyat/StackOverflowNER

Cannot download utils_fine_tune.tar.gz from Google Drive link

Hi @jeniyat, thanks for your generous sharing. In #3 you gave us a Google Drive link to download the resources, but when I open it I cannot download utils_fine_tune.tar.gz because of the error "The download file will exceed the limit, so it cannot be downloaded at this time." Could you please share another link? My email address is site.rvli@nuaa.edu.cn. Thank you very much.

created time in 10 minutes

issue closed guolinke/TUPE

Discrepancy between the paper and the implementation?

Hi,

First of all, thank you for your amazing work; the paper was quite an interesting read.

I just noticed an inconsistency between the paper and the implementation (specifically, in the reset part), and I was hoping you could clear it up.

[screenshot: the paper's reset rule for the [CLS] positional correlation]

According to the paper, the positional attention (or positional correlation, in your terms) is "reset" to θ1 or θ2 if the token attending or being attended to is the [CLS] token.

[screenshot: the appendix definition of the θ values]

In the appendix, you provide the definition of the θ values; you essentially have pθ1 and pθ2, which are treated like absolute positional embeddings (as in, they are multiplied by the same query/key matrices as absolute positional embeddings) attending to themselves.

However, in your implementation, you seem to be "resetting" the absolute and the relative correlations separately:

https://github.com/guolinke/TUPE/blob/10ecb61675cd5866f47d2f8e5cedf944c66de5be/fairseq/modules/transformer_sentence_encoder.py#L181-L184
https://github.com/guolinke/TUPE/blob/10ecb61675cd5866f47d2f8e5cedf944c66de5be/fairseq/modules/transformer_sentence_encoder.py#L239-L240

which would mean that there is another bias term added to the equation above.

Could you explain the reasoning behind this change, and how significant (or insignificant) it is for the overall performance?
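
To make the question concrete, here is an illustrative PyTorch sketch, not the TUPE code itself, of the two readings: resetting one combined positional correlation to θ1/θ2 for the [CLS] row and column, versus resetting the absolute and relative parts separately so the θ terms add up. All names, shapes and the shared θ values are assumptions for illustration only.

import torch

num_heads, seq_len = 12, 16
abs_corr = torch.randn(num_heads, seq_len, seq_len)   # absolute positional correlation
rel_corr = torch.randn(num_heads, seq_len, seq_len)   # relative positional correlation
theta1 = torch.zeros(num_heads, 1)                    # learned: [CLS] attending to others
theta2 = torch.zeros(num_heads, 1)                    # learned: others attending to [CLS]

def reset_joint(abs_corr, rel_corr, theta1, theta2):
    # Reading of the paper: one combined positional term, reset for [CLS].
    corr = abs_corr + rel_corr
    corr[:, 0, :] = theta1
    corr[:, :, 0] = theta2
    return corr

def reset_separately(abs_corr, rel_corr, theta1, theta2):
    # Reading of the implementation: the absolute and relative parts are each
    # reset, so the [CLS] rows/columns carry an extra bias term after the sum.
    a, r = abs_corr.clone(), rel_corr.clone()
    a[:, 0, :], a[:, :, 0] = theta1, theta2
    r[:, 0, :], r[:, :, 0] = theta1, theta2
    return a + r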

closed time in 10 minutes

tonyswoo

issue comment guolinke/TUPE

Discrepancy between the paper and the implementation?

That makes sense.

Thank you very much for your time.

tonyswoo

comment created time in 10 minutes

Pull request review comment huggingface/transformers

Add VisualBERT

[Review context: an earlier revision of the same file. Besides the header, imports, and checkpoint list, this hunk carries a load_tf_weights_in_visual_bert helper that maps TensorFlow checkpoint variables onto the PyTorch modules, a TF-style BertLayerNorm (epsilon inside the square root), and a VisualBertEmbeddings variant whose commented-out special_intialize() would copy the token-type and position embedding weights into their visual counterparts.]
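
One concrete detail of this revision worth keeping in view is the TF-style layer norm it still defines, with the epsilon inside the square root rather than using nn.LayerNorm as the later hunks do. Restated from the hunk as a standalone module:

import torch
import torch.nn as nn

class BertLayerNorm(nn.Module):
    # LayerNorm in the TF style: epsilon inside the square root.
    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias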

Added an option for special_visual_initialize
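
For reference, a minimal sketch of what the special_visual_initialize option does in the newer hunks: when the config flag is set, the visual token-type and position embeddings start as copies of the text ones. The config attributes mirror those in the diff; the class below is a stripped-down illustration, not the merged module.

import torch.nn as nn
from copy import deepcopy
from types import SimpleNamespace

class VisualEmbeddingsSketch(nn.Module):
    # Visual half of VisualBertEmbeddings, reduced to the initialization logic.
    def __init__(self, config):
        super().__init__()
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        if config.special_visual_initialize:
            # copy the text embedding weights into the visual embeddings
            self.visual_token_type_embeddings.weight = nn.Parameter(
                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True
            )
            self.visual_position_embeddings.weight = nn.Parameter(
                deepcopy(self.position_embeddings.weight.data), requires_grad=True
            )

        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)

cfg = SimpleNamespace(type_vocab_size=2, max_position_embeddings=512, hidden_size=768,
                      visual_embedding_dim=2048, special_visual_initialize=True)
embeddings = VisualEmbeddingsSketch(cfg)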

gchhablani

comment created time in 19 minutes

Pull request review comment huggingface/transformers

Add VisualBERT


Done.

gchhablani

comment created time in 19 minutes

Pull request review comment huggingface/transformers

Add VisualBERT

[Review context: the full flattened diff of the PyTorch VisualBERT model. Beyond the VisualBertEmbeddings module summarized above, it adds VisualBertSelfAttention (with optional relative position embeddings), VisualBertSelfOutput, VisualBertAttention (with head pruning), VisualBertIntermediate, VisualBertOutput, VisualBertLayer (chunked feed-forward), VisualBertEncoder (with optional gradient checkpointing), VisualBertPooler, the MLM and sentence-image prediction heads, VisualBertPreTrainedModel with its weight initialization, a VisualBertForPreTrainingOutput dataclass, the model docstrings describing input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, visual_embeds, visual_attention_mask, visual_token_type_ids and image_text_alignment, and example docstrings for the VQA, NLVR2, region-to-phrase alignment and pre-training heads.]
sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size+++        >>> outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels)+        >>> loss = outputs.loss+        >>> prediction_logits = outputs.prediction_logits+        >>> seq_relationship_logits = outputs.seq_relationship_logits+"""++VISUAL_BERT_MODEL_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> outputs = model(**inputs)++        >>> last_hidden_states = outputs.last_hidden_state+"""++VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""+    Example::++        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."+        >>> choice0 = "It is eaten with a fork and a knife."+        >>> choice1 = "It is eaten while held in the hand."++        >>> visual_embeds = get_visual_embeddings(image)+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1++        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)+        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, visual_embeds=visual_embeds, visual_attention_mask=visual_attention_mask, visual_token_type_ids=visual_token_type_ids, labels=labels)  # batch size is 1++        >>> loss = outputs.loss+        >>> logits = outputs.logits+"""+++@add_start_docstrings(+    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertModel(VisualBertPreTrainedModel):+    """++    The model can behave as an encoder (with only self-attention) following the architecture described in `Attention is+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,+    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.+    """++    def __init__(self, config, add_pooling_layer=True):+        super().__init__(config)+        self.config = config++        self.embeddings = VisualBertEmbeddings(config)+        self.encoder = VisualBertEncoder(config)++        self.pooler = (+            VisualBertPooler(config) if add_pooling_layer else None+        )  # TO-DO: Check if pooler is needed necessarily or optionally.++        self.bypass_transformer = config.bypass_transformer++        if self.bypass_transformer:+            self.additional_layer = VisualBertLayer(config)++        self.init_weights()++    def get_input_embeddings(self):+        return self.embeddings.word_embeddings++    def set_input_embeddings(self, value):+        self.embeddings.word_embeddings = value++    # TO-CHECK+    def _prune_heads(self, heads_to_prune):+        """+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base+        class PreTrainedModel+        """+        for layer, heads in heads_to_prune.items():+            self.encoder.layer[layer].attention.prune_heads(heads)++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))

Done.

gchhablani

comment created 28 minutes ago
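The usage examples quoted in the entry above all follow the same pattern: tokenize the text, attach visual_embeds, visual_token_type_ids and visual_attention_mask, and call the model. A minimal smoke test of the bare VisualBertModel along those lines is sketched below; it is not part of the PR. The checkpoint name is taken from the PR's pretrained-archive list (the final Hub location may differ), and the ten random feature vectors stand in for real detector output, so only the shapes are meaningful here.

import torch
from transformers import BertTokenizer, VisualBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertModel.from_pretrained("gchhablani/visualbert-vqa-coco-pre")  # name from the PR's archive list

inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")

# Fake detector output: 10 regions with the feature size this checkpoint expects.
visual_embeds = torch.randn(1, 10, model.config.visual_embedding_dim)
inputs.update({
    "visual_embeds": visual_embeds,
    "visual_token_type_ids": torch.ones(visual_embeds.shape[:-1], dtype=torch.long),
    "visual_attention_mask": torch.ones(visual_embeds.shape[:-1], dtype=torch.float),
})

outputs = model(**inputs)
# Text and visual tokens are concatenated along the sequence dimension.
print(outputs.last_hidden_state.shape)  # (1, text_length + 10, hidden_size)

As in the quoted examples, the visual token type ids are all set to 1 and the visual attention mask to all ones, so every region is attended to.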

Pull request review comment: huggingface/transformers

Add VisualBERT

[Quoted diff context: the same file header, imports, docstring helpers and VisualBertEmbeddings constructor quoted above, followed by VisualBertEmbeddings.forward (including the image_text_alignment averaging for visual position embeddings), VisualBertSelfAttention, VisualBertSelfOutput, VisualBertAttention with head pruning, VisualBertIntermediate, VisualBertOutput, VisualBertLayer, VisualBertEncoder with optional gradient checkpointing, VisualBertPooler, the prediction heads (VisualBertPredictionHeadTransform, VisualBertLMPredictionHead, VisualBertOnlyMLMHead, VisualBertOnlySIPHead, VisualBertPreTrainingHeads), VisualBertPreTrainedModel, and then the docstring and VisualBertModel sections already quoted in the previous entry. The alignment-averaging step is illustrated after this entry.]

Removed the comment.

gchhablani

comment created 31 minutes ago
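The VisualBertEmbeddings.forward code quoted in this hunk averages the text position embeddings over each visual region's aligned word positions, using -1 as the padding value in image_text_alignment. The standalone snippet below reproduces just that computation with made-up toy shapes; it is an illustration, not the PR's code, and clamp(min=1) plays the role of the original's explicit replacement of zero counts by one.

import torch

batch, image_len, align_num, dim = 1, 2, 3, 4
position_embeddings = torch.nn.Embedding(16, dim)

# Each visual region is aligned to up to align_num word positions; -1 marks padding.
image_text_alignment = torch.tensor([[[1, 2, -1], [5, -1, -1]]])   # (batch, image_len, align_num)

mask = (image_text_alignment != -1).long()
safe_alignment = image_text_alignment * mask                        # turn the -1 padding into index 0
embeds = position_embeddings(safe_alignment) * mask.unsqueeze(-1)   # zero out the padded slots
counts = mask.sum(2).clamp(min=1)                                   # avoid dividing by zero
visual_position_embeddings = embeds.sum(2) / counts.unsqueeze(-1)   # (batch, image_len, dim)
print(visual_position_embeddings.shape)                             # torch.Size([1, 2, 4])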

Pull request review comment: huggingface/transformers

Add VisualBERT

[Quoted diff context: repeats the file header, VisualBertEmbeddings and VisualBertSelfAttention sections already quoted above. A sketch of the detector features these embeddings expect follows.]
Is it in PreTrainedConfig?+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        key_layer = self.transpose_for_scores(self.key(hidden_states))+        value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = 
context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        
self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+        )+        attention_output = self_attention_outputs[0]++        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None++        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    all_hidden_states,+                    all_self_attentions,+                ]+                if v is not None+            )+        return BaseModelOutput(+            last_hidden_state=hidden_states,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+        )+++class VisualBertPooler(nn.Module):+    def 
__init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+            # 
Slightly different from the TF version which uses truncated_normal for initialization+            # cf https://github.com/pytorch/pytorch/pull/5617+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)++        elif isinstance(module, nn.LayerNorm):+            module.bias.data.zero_()+            module.weight.data.fill_(1.0)+        if isinstance(module, nn.Linear) and module.bias is not None:+            module.bias.data.zero_()+++@dataclass+class VisualBertForPreTrainingOutput(ModelOutput):+    """+    Output type of :class:`~transformers.VisualBertForPreTraining`.++    Args:+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction+            (classification) loss.+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation+            before SoftMax).+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.++            Hidden-states of the model at the output of each layer plus the initial embedding outputs.+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,+            sequence_length, sequence_length)`.++            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention+            heads.+    """++    loss: Optional[torch.FloatTensor] = None+    prediction_logits: torch.FloatTensor = None+    seq_relationship_logits: torch.FloatTensor = None+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None+    attentions: Optional[Tuple[torch.FloatTensor]] = None+++VISUAL_BERT_START_DOCSTRING = r"""+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,+    pruning heads etc.)++    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to+    general usage and behavior.++    Parameters:+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.+            Initializing with a config file does not load the weights associated with the model, only the+            configuration. 
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model+            weights.+"""++VISUAL_BERT_INPUTS_DOCSTRING = r"""+    Args:+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):+            Indices of input sequence tokens in the vocabulary.++            Indices can be obtained using :class:`~transformers.BertTokenizer`. See+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for+            details.++            `What are input IDs? <../glossary.html#input-ids>`__+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? <../glossary.html#token-type-ids>`_+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,+            config.max_position_embeddings - 1]``.++            `What are position IDs? <../glossary.html#position-ids>`_+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:++            - 1 indicates the head is **not masked**,+            - 0 indicates the head is **masked**.++        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated+            vectors than the model's internal embedding lookup matrix.++        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):+            The embedded representation of the visual inputs, generally derived using using an object detector.++        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Segment token indices to indicate different portions of the visual embeds.++            `What are token type IDs? 
<../glossary.html#token-type-ids>`_ The authors of VisualBERT set the+            `visual_token_type_ids` to `1` for all tokens.++        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):+            Image-text alignment used to decide the position IDs of the visual embeddings.++        output_attentions (:obj:`bool`, `optional`):+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned+            tensors for more detail.+        output_hidden_states (:obj:`bool`, `optional`):+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for+            more detail.+        return_dict (:obj:`bool`, `optional`):+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.+"""++VISUAL_BERT_VQA_SAMPLE = r"""
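For readers skimming the flattened hunk above: the core idea in VisualBertEmbeddings is that projected visual features get their own token-type and position embeddings and are then concatenated to the text embeddings along the sequence dimension before the shared LayerNorm and dropout. Below is a minimal standalone sketch of that flow, not the PR's module; the class name, default sizes, and random inputs are mine, and it reuses the text embedding tables where the PR keeps separate (optionally copied) visual tables under special_visual_initialize.

import torch
import torch.nn as nn

class ToyTextVisualEmbeddings(nn.Module):
    """Minimal sketch: combine text embeddings with projected visual features."""

    def __init__(self, vocab_size=30522, hidden_size=768, visual_dim=2048,
                 max_positions=512, type_vocab_size=2):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_positions, hidden_size)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
        # Visual features are projected from the detector dimension to hidden_size.
        self.visual_projection = nn.Linear(visual_dim, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids, visual_embeds):
        seq_len = input_ids.size(1)
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        text = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.token_type_embeddings(torch.zeros_like(input_ids))  # token type 0 for text
        )
        # Visual tokens: token type id 1 and position id 0 by default.
        vis_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long,
                                  device=visual_embeds.device)
        vis_pos_ids = torch.zeros_like(vis_type_ids)
        visual = (
            self.visual_projection(visual_embeds)
            + self.token_type_embeddings(vis_type_ids)
            + self.position_embeddings(vis_pos_ids)
        )
        # Concatenate along the sequence dimension: [text tokens | visual tokens].
        return self.layer_norm(torch.cat([text, visual], dim=1))

emb = ToyTextVisualEmbeddings()
out = emb(torch.randint(0, 30522, (1, 6)), torch.randn(1, 4, 2048))
print(out.shape)  # torch.Size([1, 10, 768])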
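The image_text_alignment branch in the same class averages, per visual region, the position embeddings of the words that region is aligned to, masking out the -1 padding and guarding against division by zero. A self-contained sketch of just that masked average on plain tensors (function name and toy shapes are mine):

import torch

def masked_average_positions(position_table, image_text_alignment):
    """Average word-position embeddings per visual region.

    position_table:        (max_positions, dim) embedding weight matrix
    image_text_alignment:  (batch, regions, alignment_number), -1 = padding
    returns:               (batch, regions, dim)
    """
    mask = (image_text_alignment != -1).long()            # 1 where an alignment exists
    safe_ids = image_text_alignment * mask                # replace -1 padding by index 0
    gathered = position_table[safe_ids] * mask.unsqueeze(-1).float()
    summed = gathered.sum(dim=2)                          # sum over alignment_number
    counts = mask.sum(dim=2).clamp(min=1).unsqueeze(-1)   # avoid divide-by-zero
    return summed / counts

table = torch.randn(512, 768)
align = torch.tensor([[[3, 7, -1], [-1, -1, -1]]])        # 1 image, 2 regions
print(masked_average_positions(table, align).shape)       # torch.Size([1, 2, 768])

Regions with no alignment at all (second region above) come out as zeros, which matches the quoted code where the visual position embedding for position id 0 is then added on top.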
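The self-attention block also carries over BERT's optional relative position scoring (relative_key / relative_key_query), which is easy to miss in the flattened diff. Here is a small sketch of just the einsum bookkeeping with toy shapes I picked for illustration; it is not the PR's module, only the score computation it performs:

import torch
import torch.nn as nn

# Toy shapes: batch=1, heads=2, seq_len=5, head_dim=8
b, h, L, d = 1, 2, 5, 8
max_pos = 16
query = torch.randn(b, h, L, d)
key = torch.randn(b, h, L, d)
distance_embedding = nn.Embedding(2 * max_pos - 1, d)

# Signed distance between every query position l and key position r, shifted to be >= 0.
pos = torch.arange(L)
distance = pos.view(-1, 1) - pos.view(1, -1)               # (L, L), values in [-(L-1), L-1]
rel = distance_embedding(distance + max_pos - 1)           # (L, L, d)

content_scores = torch.matmul(query, key.transpose(-1, -2))   # (b, h, L, L)
rel_scores_q = torch.einsum("bhld,lrd->bhlr", query, rel)      # query-vs-position term
rel_scores_k = torch.einsum("bhrd,lrd->bhlr", key, rel)        # key-vs-position term

# "relative_key":       content + rel_scores_q
# "relative_key_query": content + rel_scores_q + rel_scores_k
scores = (content_scores + rel_scores_q + rel_scores_k) / d ** 0.5
print(scores.shape)  # torch.Size([1, 2, 5, 5])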
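The hunk is cut off right where VISUAL_BERT_VQA_SAMPLE begins, so for orientation here is a hedged usage sketch that matches the inputs documented above (visual_embeds, visual_token_type_ids, visual_attention_mask). It assumes this PR is installed, picks one checkpoint from the archive list in the diff, and feeds random tensors in place of real detector features; the 36 regions and the 2048 feature dimension are assumptions tied to the VQA-style checkpoints, not values stated in the hunk.

import torch
from transformers import BertTokenizer, VisualBertModel  # assumes the VisualBERT PR is available

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertModel.from_pretrained("gchhablani/visualbert-vqa-coco-pre")

inputs = tokenizer("Who is eating the apple?", return_tensors="pt")

# Stand-in for detector output; real pipelines use region features of shape
# (num_regions, visual_embedding_dim) extracted by an object detector.
visual_embeds = torch.randn(1, 36, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

outputs = model(
    **inputs,
    visual_embeds=visual_embeds,
    visual_token_type_ids=visual_token_type_ids,
    visual_attention_mask=visual_attention_mask,
)
# Text and visual tokens share one sequence axis, so the hidden states cover both.
print(outputs.last_hidden_state.shape)  # (1, text_len + 36, 768)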

Done.

gchhablani

comment created time in 32 minutes

Pull request review comment huggingface/transformers

Add VisualBERT


Changed

gchhablani

comment created time in 32 minutes

Pull request review comment huggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+from copy import deepcopy+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutput,+    BaseModelOutputWithPooling,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)+        self.position_embeddings = 
nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True+            )+            self.visual_position_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.position_embeddings.weight.data), requires_grad=True+            )++        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    def forward(+        self,+        input_ids=None,+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+    ):+        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        if position_ids is None:+            position_ids = self.position_ids[:, :seq_length]++        # TO-CHECK: FROM ORIGINAL CODE+        # if input_ids is not None:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+        #     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        # else:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+        #     position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.input_embeds.device)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.ones(+                    visual_embeds.size()[:-1], dtype=torch.long, device=self.position_ids.device+                )++            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a 
way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number.+                # Each element denotes the position of the word corresponding to the image feature. -1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to average along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid divide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )++                # When fine-tuning the detector, the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)++        # TO-CHECK: Config doesn't have this, is this needed? 
Is it in PreTrainedConfig?+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        key_layer = self.transpose_for_scores(self.key(hidden_states))+        value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = 
context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        
self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        output_attentions=False,+    ):

Changed

gchhablani

comment created time in 33 minutes
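
The diff above wires projected detector features into the same embedding stream as the text tokens. Below is a minimal usage sketch of that interface, assuming this PR's branch of transformers is installed, that the gchhablani/visualbert-vqa-coco-pre checkpoint from the archive list is published under that name, and that its visual_embedding_dim is 2048 (both assumptions); the random visual_embeds merely stand in for real detector output.

import torch
from transformers import BertTokenizer, VisualBertModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = VisualBertModel.from_pretrained("gchhablani/visualbert-vqa-coco-pre")

inputs = tokenizer("A person is riding a horse", return_tensors="pt")
# Dummy detector features: (batch, number of regions, visual_embedding_dim).
visual_embeds = torch.randn(1, 36, 2048)
visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)

outputs = model(
    **inputs,
    visual_embeds=visual_embeds,
    visual_token_type_ids=visual_token_type_ids,
    visual_attention_mask=visual_attention_mask,
)
# Text tokens and visual regions are concatenated along the sequence dimension.
print(outputs.last_hidden_state.shape)  # (1, text_length + 36, hidden_size)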

Pull request review commenthuggingface/transformers

Add VisualBERT

+        if position_ids is None:+            position_ids = self.position_ids[:, :seq_length]++        # TO-CHECK: FROM ORIGINAL CODE+        # if input_ids is not None:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+        #     position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        # else:+        #     position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+        #     position_ids = position_ids.unsqueeze(0).expand(input_shape)

Removed.

gchhablani

comment created time in an hour
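
The "Removed." above drops the commented-out arange-based position_ids in favour of the position_ids buffer registered in __init__. A small standalone check (toy shapes, not taken from the PR) of why the two constructions are interchangeable:

import torch

max_position_embeddings, seq_length, batch_size = 512, 7, 2

# Buffer registered in __init__: shape (1, max_position_embeddings), sliced per forward call.
position_ids_buffer = torch.arange(max_position_embeddings).expand((1, -1))
from_buffer = position_ids_buffer[:, :seq_length]

# The removed original code: arange + expand to the full input shape.
from_arange = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, seq_length)

# Broadcasting over the batch dimension makes the two equivalent in the embedding lookup.
assert torch.equal(from_buffer.expand(batch_size, -1), from_arange)
print(from_buffer.shape, from_arange.shape)  # torch.Size([1, 7]) torch.Size([2, 7])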

Pull request review commenthuggingface/transformers

Add VisualBERT

+        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True+            )+            self.visual_position_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.position_embeddings.weight.data), requires_grad=True

Done

gchhablani

comment created time in an hour
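
The resolved thread above concerns config.special_visual_initialize, which seeds the visual token-type and position tables from their text counterparts. A standalone sketch of that initialization using plain nn.Embedding layers (the sizes are made up, not the model's):

from copy import deepcopy

import torch
from torch import nn

hidden_size, type_vocab_size = 32, 2

token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
visual_token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)

# Same pattern as the diff: the visual table starts as a copy of the text table
# but remains an ordinary trainable Parameter afterwards.
visual_token_type_embeddings.weight = torch.nn.Parameter(
    deepcopy(token_type_embeddings.weight.data), requires_grad=True
)

assert torch.equal(visual_token_type_embeddings.weight, token_type_embeddings.weight)
assert visual_token_type_embeddings.weight.requires_grad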

Pull request review commenthuggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" Testing suite for the PyTorch VisualBERT model. """+++import copy+import unittest++from tests.test_modeling_common import floats_tensor+from transformers import is_torch_available++# from transformers.models.auto import get_values+from transformers.testing_utils import require_torch, slow, torch_device++from .test_configuration_common import ConfigTester+from .test_modeling_common import ModelTesterMixin, ids_tensor  # , random_attention_mask+++if is_torch_available():+    import torch++    from transformers import (+        VisualBertConfig,+        VisualBertForMultipleChoice,+        VisualBertForPreTraining,+        VisualBertForQuestionAnswering,+        VisualBertForQuestionAnsweringAdvanced,+        VisualBertForRegionToPhraseAlignment,+        VisualBertForVisualReasoning,+        VisualBertModel,+    )+    from transformers.models.visual_bert.modeling_visual_bert import VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST+++class VisualBertModelTester:+    def __init__(+        self,+        parent,+        batch_size=13,+        seq_length=7,+        visual_seq_length=5,+        is_training=True,+        use_attention_mask=True,+        use_visual_attention_mask=True,+        use_token_type_ids=True,+        use_visual_token_type_ids=True,+        use_labels=True,+        vocab_size=99,+        hidden_size=32,+        num_hidden_layers=5,+        num_attention_heads=4,+        intermediate_size=37,+        hidden_act="gelu",+        hidden_dropout_prob=0.1,+        attention_probs_dropout_prob=0.1,+        max_position_embeddings=512,+        visual_embedding_dim=20,+        type_vocab_size=16,+        type_sequence_label_size=2,+        initializer_range=0.02,+        num_labels=3,+        num_choices=4,+        scope=None,+    ):+        self.parent = parent+        self.batch_size = batch_size+        self.seq_length = seq_length+        self.visual_seq_length = visual_seq_length+        self.is_training = is_training+        self.use_attention_mask = use_attention_mask+        self.use_visual_attention_mask = use_visual_attention_mask+        self.use_token_type_ids = use_token_type_ids+        self.use_visual_token_type_ids = use_visual_token_type_ids+        self.use_labels = use_labels+        self.vocab_size = vocab_size+        self.hidden_size = hidden_size+        self.num_hidden_layers = num_hidden_layers+        self.num_attention_heads = num_attention_heads+        self.intermediate_size = intermediate_size+        self.hidden_act = hidden_act+        self.hidden_dropout_prob = hidden_dropout_prob+        self.attention_probs_dropout_prob = attention_probs_dropout_prob+        self.max_position_embeddings = max_position_embeddings+        self.visual_embedding_dim = visual_embedding_dim+        self.type_vocab_size = type_vocab_size+        self.type_sequence_label_size = type_sequence_label_size+        self.initializer_range = initializer_range+        
self.num_labels = num_labels+        self.num_choices = num_choices+        self.scope = scope++    def prepare_config(self):+        return VisualBertConfig(+            vocab_size=self.vocab_size,+            hidden_size=self.hidden_size,+            num_hidden_layers=self.num_hidden_layers,+            num_attention_heads=self.num_attention_heads,+            intermediate_size=self.intermediate_size,+            hidden_act=self.hidden_act,+            hidden_dropout_prob=self.hidden_dropout_prob,+            attention_probs_dropout_prob=self.attention_probs_dropout_prob,+            max_position_embeddings=self.max_position_embeddings,+            type_vocab_size=self.type_vocab_size,+            visual_embedding_dim=self.visual_embedding_dim,+            num_labels=self.num_labels,+            is_decoder=False,+            initializer_range=self.initializer_range,+        )++    def prepare_config_and_inputs_for_common(self):+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)+        visual_embeds = floats_tensor([self.batch_size, self.visual_seq_length, self.visual_embedding_dim])++        attention_mask = None+        if self.use_attention_mask:+            attention_mask = torch.ones((self.batch_size, self.seq_length), dtype=torch.long, device=torch_device)++        visual_attention_mask = None+        if self.use_visual_attention_mask:+            visual_attention_mask = torch.ones(+                (self.batch_size, self.visual_seq_length), dtype=torch.long, device=torch_device+            )++        token_type_ids = None+        if self.use_token_type_ids:+            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)++        visual_token_type_ids = None+        if self.use_visual_token_type_ids:+            visual_token_type_ids = ids_tensor([self.batch_size, self.visual_seq_length], self.type_vocab_size)++        config = self.prepare_config()+        return config, {+            "input_ids": input_ids,+            "token_type_ids": token_type_ids,+            "attention_mask": attention_mask,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask,+        }++    def prepare_config_and_inputs_for_pretraining(self):+        masked_lm_labels = None+        sentence_image_labels = None++        if self.use_labels:+            masked_lm_labels = ids_tensor([self.batch_size, self.seq_length + self.visual_seq_length], self.vocab_size)+            sentence_image_labels = ids_tensor(+                [+                    self.batch_size,+                ],+                self.type_sequence_label_size,+            )++        config, input_dict = self.prepare_config_and_inputs_for_common()++        input_dict.update({"labels": masked_lm_labels, "sentence_image_labels": sentence_image_labels})++        return config, input_dict++    def prepare_config_and_inputs_for_multiple_choice(self):+        input_ids = ids_tensor([self.batch_size, self.num_choices, self.seq_length], self.vocab_size)+        visual_embeds = floats_tensor(+            [self.batch_size, self.num_choices, self.visual_seq_length, self.visual_embedding_dim]+        )++        attention_mask = None+        if self.use_attention_mask:+            attention_mask = torch.ones(+                (self.batch_size, self.num_choices, self.seq_length), dtype=torch.long, device=torch_device+            )++        visual_attention_mask = None+        if 
self.use_visual_attention_mask:+            visual_attention_mask = torch.ones(+                (self.batch_size, self.num_choices, self.visual_seq_length), dtype=torch.long, device=torch_device+            )++        token_type_ids = None+        if self.use_token_type_ids:+            token_type_ids = ids_tensor([self.batch_size, self.num_choices, self.seq_length], self.type_vocab_size)++        visual_token_type_ids = None+        if self.use_visual_token_type_ids:+            visual_token_type_ids = ids_tensor(+                [self.batch_size, self.num_choices, self.visual_seq_length], self.type_vocab_size+            )++        labels = None++        if self.use_labels:+            labels = ids_tensor([self.batch_size], self.num_choices)++        config = self.prepare_config()+        return config, {+            "input_ids": input_ids,+            "token_type_ids": token_type_ids,+            "attention_mask": attention_mask,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask,+            "labels": labels,+        }++    def prepare_config_and_inputs_for_vqa(self):+        vqa_labels = None++        if self.use_labels:+            vqa_labels = ids_tensor([self.batch_size, self.num_labels], self.num_labels)++        config, input_dict = self.prepare_config_and_inputs_for_common()++        input_dict.update({"labels": vqa_labels})+        return config, input_dict++    def prepare_config_and_inputs_for_vqa_advanced(self):+        vqa_labels = None++        if self.use_labels:+            vqa_labels = ids_tensor([self.batch_size, self.seq_length + self.visual_seq_length], self.vocab_size)++        config, input_dict = self.prepare_config_and_inputs_for_common()++        input_dict.update({"labels": vqa_labels})+        return config, input_dict++    def prepare_config_and_inputs_for_nlvr(self):+        nlvr_labels = None++        if self.use_labels:+            nlvr_labels = ids_tensor([self.batch_size], self.num_labels)++        config, input_dict = self.prepare_config_and_inputs_for_common()++        input_dict.update({"labels": nlvr_labels})+        return config, input_dict++    def prepare_config_and_inputs_for_flickr(self):+        region_to_phrase_position = torch.cat(+            (+                ids_tensor([self.batch_size, self.seq_length], self.visual_seq_length),+                torch.ones(self.batch_size, self.visual_seq_length, dtype=torch.long, device=torch_device) * -1,+            ),+            dim=-1,+        )+        flickr_labels = None+        if self.use_labels:+            flickr_labels = ids_tensor(+                [self.batch_size, self.seq_length + self.visual_seq_length, self.visual_seq_length], 2+            )++        config, input_dict = self.prepare_config_and_inputs_for_common()++        input_dict.update({"region_to_phrase_position": region_to_phrase_position, "labels": flickr_labels})+        return config, input_dict++    def create_and_check_model(self, config, input_dict):+        model = VisualBertModel(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(+            result.last_hidden_state.shape,+            (self.batch_size, self.seq_length + self.visual_seq_length, self.hidden_size),+        )++    def create_and_check_for_pretraining(self, config, input_dict):+        model = VisualBertForPreTraining(config=config)+        model.to(torch_device)+   
     model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(+            result.prediction_logits.shape,+            (self.batch_size, self.seq_length + self.visual_seq_length, self.vocab_size),+        )++    def create_and_check_for_vqa(self, config, input_dict):+        model = VisualBertForQuestionAnswering(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))++    def create_and_check_for_vqa_advanced(self, config, input_dict):+        model = VisualBertForQuestionAnsweringAdvanced(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(+            result.logits.shape, (self.batch_size, self.seq_length + self.visual_seq_length, self.vocab_size)+        )++    def create_and_check_for_multiple_choice(self, config, input_dict):+        model = VisualBertForMultipleChoice(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))++    def create_and_check_for_nlvr(self, config, input_dict):+        model = VisualBertForVisualReasoning(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))++    def create_and_check_for_flickr(self, config, input_dict):+        model = VisualBertForRegionToPhraseAlignment(config=config)+        model.to(torch_device)+        model.eval()+        result = model(**input_dict)+        self.parent.assertEqual(+            result.logits.shape, (self.batch_size, self.seq_length + self.visual_seq_length, self.visual_seq_length)+        )+++@require_torch+class VisualBertModelTest(ModelTesterMixin, unittest.TestCase):++    all_model_classes = (+        (+            VisualBertModel,+            VisualBertForMultipleChoice,+            VisualBertForVisualReasoning,+            VisualBertForRegionToPhraseAlignment,+            VisualBertForQuestionAnswering,+            VisualBertForQuestionAnsweringAdvanced,+            VisualBertForPreTraining,+        )+        if is_torch_available()+        else ()+    )+    test_torchscript = False+    test_pruning = False++    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):+        inputs_dict = copy.deepcopy(inputs_dict)+        if model_class == VisualBertForMultipleChoice:+            for key in inputs_dict.keys():+                value = inputs_dict[key]+                if isinstance(value, torch.Tensor) and value.ndim > 1:+                    if key != "visual_embeds":+                        inputs_dict[key] = (+                            inputs_dict[key].unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()+                        )+                    else:+                        inputs_dict[key] = (+                            inputs_dict[key]+                            .unsqueeze(1)+                            .expand(-1, self.model_tester.num_choices, -1, self.model_tester.visual_embedding_dim)+                            .contiguous()

Done.

gchhablani

comment created time in an hour
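
The test file quoted above expands single-choice inputs to the multiple-choice shape inside _prepare_for_class. A toy-shape sketch of that expansion (numbers mirror the tester defaults; this is not a real test run):

import torch

batch_size, seq_length, visual_seq_length, visual_embedding_dim, num_choices = 13, 7, 5, 20, 4

input_ids = torch.randint(0, 99, (batch_size, seq_length))
visual_embeds = torch.rand(batch_size, visual_seq_length, visual_embedding_dim)

# Ordinary inputs are repeated across the num_choices dimension; visual_embeds
# additionally keeps its trailing feature dimension.
mc_input_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
mc_visual_embeds = (
    visual_embeds.unsqueeze(1).expand(-1, num_choices, -1, visual_embedding_dim).contiguous()
)

print(mc_input_ids.shape)      # torch.Size([13, 4, 7])
print(mc_visual_embeds.shape)  # torch.Size([13, 4, 5, 20])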

Pull request review commenthuggingface/transformers

Add VisualBERT

+        # For Visual Features+        # Token type and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        if config.special_visual_initialize:+            self.visual_token_type_embeddings.weight = torch.nn.Parameter(+                deepcopy(self.token_type_embeddings.weight.data), requires_grad=True

I'm cloning the embedding weights instead of creating a new parameter. Hope that is okay.

gchhablani

comment created time in an hour
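
The comment above proposes cloning the embedding weights rather than wrapping a deepcopy in a fresh nn.Parameter. A minimal sketch of one clone-based reading of that suggestion (standalone embedding layers, made-up sizes; not necessarily the exact code that landed):

import torch
from torch import nn

hidden_size, max_position_embeddings = 32, 512

position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
visual_position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)

# In-place copy: same starting values, but the existing Parameter object
# (and anything referencing it, such as optimizer state) is kept.
visual_position_embeddings.weight.data.copy_(position_embeddings.weight.data)

assert torch.equal(visual_position_embeddings.weight, position_embeddings.weight)
assert visual_position_embeddings.weight.requires_grad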

Pull request review commenthuggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+import os+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutputWithPastAndCrossAttentions,+    BaseModelOutputWithPoolingAndCrossAttentions,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++# TO-CHECK+def load_tf_weights_in_visual_bert(model, config, tf_checkpoint_path):+    """Load tf checkpoints in a pytorch model."""+    try:+        import re++        import numpy as np+        import tensorflow as tf+    except ImportError:+        logger.error(+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see "+            "https://www.tensorflow.org/install/ for installation instructions."+        )+        raise+    tf_path = os.path.abspath(tf_checkpoint_path)+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))+    # Load weights from TF model+    init_vars = tf.train.list_variables(tf_path)+    names = []+    arrays = []+    for name, shape in init_vars:+        logger.info("Loading TF weight {} with shape {}".format(name, shape))+        array = tf.train.load_variable(tf_path, name)+        names.append(name)+        arrays.append(array)++    for name, array in zip(names, arrays):+        name = name.split("/")+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v+        # which are not required for using pretrained model+        if any(+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]+            for n in name+        ):+            logger.info("Skipping {}".format("/".join(name)))+            continue+        pointer = model+        for m_name in name:+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):+                scope_names = re.split(r"_(\d+)", m_name)+            else:+                scope_names = [m_name]+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":+                pointer = getattr(pointer, "bias")+            elif scope_names[0] == "output_weights":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "squad":+                pointer = getattr(pointer, "classifier")+            else:+                try:+                    pointer = getattr(pointer, scope_names[0])+                except AttributeError:+                    logger.info("Skipping {}".format("/".join(name)))+                    continue+            if len(scope_names) >= 2:+                num = int(scope_names[1])+                pointer = pointer[num]+        if m_name[-11:] == "_embeddings":+            pointer = getattr(pointer, "weight")+        elif m_name == "kernel":+            array = np.transpose(array)+        try:+            assert (+                pointer.shape == array.shape+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"+        except AssertionError as e:+            e.args += (pointer.shape, array.shape)+            raise+        logger.info("Initialize PyTorch weight {}".format(name))+        pointer.data = torch.from_numpy(array)+    return model+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = 
code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++# TO-CHECK: Vestige of the original code+++class BertLayerNorm(nn.Module):+    def __init__(self, hidden_size, eps=1e-12):+        """Construct a layernorm module in the TF style (epsilon inside the square root)."""+        super(BertLayerNorm, self).__init__()+        self.weight = nn.Parameter(torch.ones(hidden_size))+        self.bias = nn.Parameter(torch.zeros(hidden_size))+        self.variance_epsilon = eps++    def forward(self, x):+        u = x.mean(-1, keepdim=True)+        s = (x - u).pow(2).mean(-1, keepdim=True)+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)+        return self.weight * x + self.bias+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(+            config.vocab_size, config.hidden_size+        )  # TO-CHECK: , padding_idx=config.pad_token_id+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # TO-CHECK+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Segment and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        # TO-CHECK: Check if register buffer is needed for Visual features+        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    # TO-CHECK: Check how to incorporate this. This is being called outside the classes.+    # def special_intialize(self):+    #     ### This is a bit unorthodox. 
The better way might be to add an inititilizer to AllenNLP.+    #     # This function is used to initialize the token_type_embeddings_visual and positiona_embedding_visual, just incase.+    #     self.token_type_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.token_type_embeddings.weight.data), requires_grad = True)+    #     self.position_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.position_embeddings.weight.data), requires_grad = True)+    #     return++    def forward(+        self,+        input_ids=None,  # TO-CHECK+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # past_key_values_length=0, # TO-CHECK+    ):+        # TO-CHECK: Check if `confidence=None` and `visual_position_embeds=None` (or id) is needed.+        # `position_embeddings_visual` is not used in the original code.++        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        # TO-CHECK+        # if position_ids is None:+        #     position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]+        if input_ids is not None:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        else:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+            position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number. Each element denotes the position of the word corresponding to the image feature. 
-1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            # Concate the two:+            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)  # concat the visual embeddings++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if 
self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++        self.is_decoder = config.is_decoder++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        # If this is instantiated as a cross-attention module, the keys+        # and values come from an encoder; the attention mask needs to be+        # such that the encoder's padding tokens are not attended to.+        is_cross_attention = encoder_hidden_states is not None++        if is_cross_attention and past_key_value is not None:+            # reuse k,v, cross_attentions+            key_layer = past_key_value[0]+            value_layer = past_key_value[1]+            attention_mask = encoder_attention_mask+        elif is_cross_attention:+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))+            attention_mask = encoder_attention_mask+        elif past_key_value is not None:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)+        else:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        if self.is_decoder:+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.+            # Further calls to cross_attention layer can then reuse all cross-attention+            # key/value_states (first "if" case)+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of+            # all previous decoder key/value_states. 
Further calls to uni-directional self-attention+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)+            # if encoder bi-directional self-attention `past_key_value` is always `None`+            past_key_value = (key_layer, value_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        if self.is_decoder:+            outputs = outputs + (past_key_value,)+        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, 
input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            encoder_hidden_states,+            encoder_attention_mask,+            past_key_value,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.is_decoder = config.is_decoder+        
self.add_cross_attention = config.add_cross_attention+        if self.add_cross_attention:+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"+            self.crossattention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+            past_key_value=self_attn_past_key_value,+        )+        attention_output = self_attention_outputs[0]++        # if decoder, the last output is tuple of self-attn cache+        if self.is_decoder:+            outputs = self_attention_outputs[1:-1]+            present_key_value = self_attention_outputs[-1]+        else:+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        cross_attn_present_key_value = None+        if self.is_decoder and encoder_hidden_states is not None:+            assert hasattr(+                self, "crossattention"+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"++            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None+            cross_attention_outputs = self.crossattention(+                attention_output,+                attention_mask,+                head_mask,+                encoder_hidden_states,+                encoder_attention_mask,+                cross_attn_past_key_value,+                output_attentions,+            )+            attention_output = cross_attention_outputs[0]+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights++            # add cross-attn cache to positions 3,4 of present_key_value tuple+            cross_attn_present_key_value = cross_attention_outputs[-1]+            present_key_value = present_key_value + cross_attn_present_key_value++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        # if decoder, return the attn key/values as the last output+        if self.is_decoder:+            outputs = outputs + (present_key_value,)++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        
hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_values=None,+        use_cache=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None++        next_decoder_cache = () if use_cache else None+        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None+            past_key_value = past_key_values[i] if past_key_values is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                if use_cache:+                    logger.warn(+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "+                        "`use_cache=False`..."+                    )+                    use_cache = False++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, past_key_value, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                    past_key_value,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if use_cache:+                next_decoder_cache += (layer_outputs[-1],)+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)+                if self.config.add_cross_attention:+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    next_decoder_cache,+                    all_hidden_states,+                    all_self_attentions,+                    all_cross_attentions,+                ]+                if v is not None+            )+        return BaseModelOutputWithPastAndCrossAttentions(+            last_hidden_state=hidden_states,+            past_key_values=next_decoder_cache,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+            cross_attentions=all_cross_attentions,+        )+++class VisualBertPooler(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = 
nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    load_tf_weights = load_tf_weights_in_visual_bert+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+          
  # Slightly different from the TF version which uses truncated_normal for initialization+            # cf https://github.com/pytorch/pytorch/pull/5617+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)++        # TO-CHECK+        # elif isinstance(module, nn.LayerNorm):+        elif isinstance(module, BertLayerNorm):+            module.bias.data.zero_()+            module.weight.data.fill_(1.0)+        if isinstance(module, nn.Linear) and module.bias is not None:+            module.bias.data.zero_()+++@dataclass+class VisualBertForPreTrainingOutput(ModelOutput):+    """+    Output type of :class:`~transformers.VisualBertForPreTraining`.++    Args:+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction+            (classification) loss.+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation+            before SoftMax).+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.++            Hidden-states of the model at the output of each layer plus the initial embedding outputs.+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,+            sequence_length, sequence_length)`.++            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention+            heads.+    """++    loss: Optional[torch.FloatTensor] = None+    prediction_logits: torch.FloatTensor = None+    seq_relationship_logits: torch.FloatTensor = None+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None+    attentions: Optional[Tuple[torch.FloatTensor]] = None+++VISUAL_BERT_START_DOCSTRING = r"""+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,+    pruning heads etc.)++    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to+    general usage and behavior.++    Parameters:+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.+            Initializing with a config file does not load the weights associated with the model, only the+            configuration. 
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model+            weights.+"""++VISUAL_BERT_INPUTS_DOCSTRING = r"""+    Args:+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):+            Indices of input sequence tokens in the vocabulary.++            Indices can be obtained using :class:`~transformers.BertTokenizer`. See+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for+            details.++            `What are input IDs? <../glossary.html#input-ids>`__+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? <../glossary.html#token-type-ids>`_+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,+            config.max_position_embeddings - 1]``.++            `What are position IDs? <../glossary.html#position-ids>`_+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:++            - 1 indicates the head is **not masked**,+            - 0 indicates the head is **masked**.++        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated+            vectors than the model's internal embedding lookup matrix.++        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):+            The embedded representation of the visual inputs, generally derived using using an object detector.++        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Segment token indices to indicate different portions of the visual embeds. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? 
<../glossary.html#token-type-ids>`_++        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):+            Image-Text alignment uses to decide the position IDs of the visual embeddings.++        output_attentions (:obj:`bool`, `optional`):+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned+            tensors for more detail.+        output_hidden_states (:obj:`bool`, `optional`):+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for+            more detail.+        return_dict (:obj:`bool`, `optional`):+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.+"""++VISUAL_BERT_VQA_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor([[0.0,1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_NLVR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""+++VISUAL_BERT_VQA_ADVANCED_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+ 
       >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_FLICKR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)+        >>> flickr_position = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2]))++        >>> inputs.update({{+            "flickr_position": flickr_position,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2], visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_PRE_TRAINING_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])["input_ids"]+        >>> sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size+++        >>> outputs = model(**inputs, labels=labels, 
sentence_image_labels=sentence_image_labels)+        >>> loss = outputs.loss+        >>> prediction_logits = outputs.prediction_logits+        >>> seq_relationship_logits = outputs.seq_relationship_logits+"""++VISUAL_BERT_MODEL_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> outputs = model(**inputs)++        >>> last_hidden_states = outputs.last_hidden_state+"""++VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""+    Example::++        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."+        >>> choice0 = "It is eaten with a fork and a knife."+        >>> choice1 = "It is eaten while held in the hand."++        >>> visual_embeds = get_visual_embeddings(image)+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1++        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)+        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, visual_embeds=visual_embeds, visual_attention_mask=visual_attention_mask, visual_token_type_ids=visual_token_type_ids, labels=labels)  # batch size is 1++        >>> loss = outputs.loss+        >>> logits = outputs.logits+"""+++@add_start_docstrings(+    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertModel(VisualBertPreTrainedModel):+    """++    The model can behave as an encoder (with only self-attention) following the architecture described in `Attention is+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,+    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.+    """++    def __init__(self, config, add_pooling_layer=True):+        super().__init__(config)+        self.config = config++        self.embeddings = VisualBertEmbeddings(config)+        self.encoder = VisualBertEncoder(config)++        self.pooler = (+            VisualBertPooler(config) if add_pooling_layer else None+        )  # TO-DO: Check if pooler is needed necessarily.++        self.bypass_transformer = config.bypass_transformer++        if self.bypass_transformer:+            self.additional_layer = VisualBertLayer(config)++        # TO-CHECK: This next line is from old BERT code, which is not used anymore.+        # self.output_attention_weights = config.output_attention_weights++        self.init_weights()  # self.apply(self.init_bert_weights) #Vestiges of old code++    # TO-CHECK+    def get_input_embeddings(self):+        return self.embeddings.word_embeddings++    # TO-CHECK+    def set_input_embeddings(self, value):+        self.embeddings.word_embeddings = value++    # TO-CHECK+    def _prune_heads(self, heads_to_prune):+        """+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base+        class PreTrainedModel+        """+        for layer, heads in heads_to_prune.items():+            self.encoder.layer[layer].attention.prune_heads(heads)++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_MODEL_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # encoder_hidden_states=None,+        # encoder_attention_mask=None,+        # past_key_values=None,+        # use_cache=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+    ):++        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions+        output_hidden_states = (+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states+        )+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        # if self.config.is_decoder:+        #     use_cache = use_cache if use_cache is not None else self.config.use_cache+        # else:+        use_cache = False++        if input_ids is not None and inputs_embeds is not None:+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")+        elif input_ids is not None:+            input_shape = input_ids.size()+            batch_size, seq_length = input_shape+        elif inputs_embeds is not None:+            input_shape = inputs_embeds.size()[:-1]+            batch_size, seq_length = input_shape+        else:+            raise 
ValueError("You have to specify either input_ids or inputs_embeds")++        device = input_ids.device if input_ids is not None else inputs_embeds.device++        # past_key_values_length+        # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0++        if attention_mask is None:+            attention_mask = torch.ones(((batch_size, seq_length)), device=device)+            # attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)+        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)++        if visual_embeds is not None:+            visual_input_shape = visual_embeds.size()[:-1]+            _, visual_seq_length = visual_input_shape+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.zeros(visual_input_shape, dtype=torch.long, device=device)++            if visual_attention_mask is None:+                visual_attention_mask = torch.ones(visual_input_shape, device=device)++        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]+        # ourselves in which case we just need to make it broadcastable to all heads.++        # TO-CHECK : Whether input_shape+visual_input_shape is correct.+        if visual_embeds is not None:+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device+            )+        else:+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                attention_mask, [batch_size, input_shape], device+            )++        # If a 2D or 3D attention mask is provided for the cross-attention+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]+        # if self.config.is_decoder and encoder_hidden_states is not None:+        #     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()+        #     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)+        #     if encoder_attention_mask is None:+        #         encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)+        #     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)+        # else:+        #     encoder_extended_attention_mask = None++        # Prepare head mask if needed+        # 1.0 in head_mask indicate we keep the head+        # attention_probs has shape bsz x n_heads x N x N+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)++        embedding_output = self.embeddings(+            input_ids=input_ids,+            position_ids=position_ids,+            token_type_ids=token_type_ids,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            # past_key_values_length=past_key_values_length,+        )++        if self.bypass_transformer and visual_embeds is not None:+            assert output_hidden_states is None  # 
TO-DO: Need to check if this is correct.+            text_length = input_ids.size(1)+            text_embedding_output = embedding_output[:, :text_length, :]+            visual_embedding_output = embedding_output[:, text_length:, :]++            text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]+            # text_encoder_hidden_states = encoder_hidden_states[:, :text_length, :]+            # text_encoder_attention_mask = encoder_extended_attention_mask[:, :, :text_length, :text_length]++            # TO-DO: Check how past-key values work and whether they are required to be modified and added here++            encoded_outputs = self.encoder(+                text_embedding_output,+                attention_mask=text_extended_attention_mask,+                # encoder_hidden_states=text_encoder_hidden_states,+                # encoder_attention_mask=text_encoder_attention_mask,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoded_outputs[0]+            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)+            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        else:+            encoder_outputs = self.encoder(+                embedding_output,+                attention_mask=extended_attention_mask,+                head_mask=head_mask,+                # encoder_hidden_states=encoder_hidden_states,+                # encoder_attention_mask=encoder_extended_attention_mask,+                # past_key_values=past_key_values,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoder_outputs[0]++            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        if not return_dict:+            return (sequence_output, pooled_output) + encoder_outputs[1:]++        return BaseModelOutputWithPoolingAndCrossAttentions(  # Changed+            last_hidden_state=sequence_output,+            pooler_output=pooled_output,+            past_key_values=encoder_outputs.past_key_values,+            hidden_states=encoder_outputs.hidden_states,+            attentions=encoder_outputs.attentions,+            cross_attentions=encoder_outputs.cross_attentions,+        )+++# TO-DO: Check if the case where we don't want to calculate is_random_next loss.+# The is a case where during pre-training, in the original code, it is checked if the is_random_next is None+# In this case, the next_sentence_loss is not calculated.+++@add_start_docstrings(+    """+    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a+    `sentence-image prediction (classification)` head.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForPreTraining(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.cls = VisualBertPreTrainingHeads(config)++        # UNUSED+        # self.cut_first = cut_first+        # self.hard_cap_seq_len = hard_cap_seq_len++        self.init_weights()++ 
   def get_output_embeddings(self):+        return self.cls.predictions.decoder++    def set_output_embeddings(self, new_embeddings):+        self.cls.predictions.decoder = new_embeddings++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        mask="[MASK]",+        output_type=VisualBertForPreTrainingOutput,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_PRE_TRAINING_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+        sentence_image_labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``+        sentence_image_labels (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):+            Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence pair+            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:++            - 0 indicates sequence B is a matching pair of sequence A for the given image,+            - 1 indicates sequence B is a random sequence w.r.t A for the given image.++        Returns:+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output, pooled_output = outputs[:2]+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)++        total_loss = None+        if labels is not None and sentence_image_labels is not None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. 
Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."++            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))+            total_loss = masked_lm_loss + sentence_image_loss++        # TO-CHECK+        if labels is not None and sentence_image_labels is None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."+            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            total_loss = masked_lm_loss++        if not return_dict:+            output = (prediction_scores, seq_relationship_score) + outputs[2:]+            return ((total_loss,) + output) if total_loss is not None else output++        return VisualBertForPreTrainingOutput(+            loss=total_loss,+            prediction_logits=prediction_scores,+            seq_relationship_logits=seq_relationship_score,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++class VisualBertClassificationHead(nn.Module):+    """Head for sentence-level classification tasks."""++    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)++        self.config = config++    def forward(self, features, **kwargs):+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])+        x = self.dropout(x)+        x = self.dense(x)+        x = ACT2FN[self.config.hidden_act](x)+        x = self.dropout(x)+        x = self.out_proj(x)+        return x+++@add_start_docstrings(+    """+    VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and+    a softmax) e.g. 
for VCR tasks.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForMultipleChoice(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        # TO-CHECK+        # self.sequence_summary = SequenceSummary(config)+        self.cls = nn.Linear(config.hidden_size, 1)++        self.init_weights()++    @add_start_docstrings_to_model_forward(+        VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")+    )+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vcr",+        output_type=MultipleChoiceModelOutput,+        config_class="gchhablani/visualbert-vcr",+        code_sample=VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):+            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,+            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See+            :obj:`input_ids` above)+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict+        num_choices = (+            input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]+        )  # TO-CHECK: original code uses 4 directly.++        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None+        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None+        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None+        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None+        inputs_embeds = (+            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))+            if inputs_embeds is not None+            else None+        )++        visual_embeds = (+            visual_embeds.view(-1, visual_embeds.size(-2), visual_embeds.size(-1))+            if visual_embeds is not None+            else None+        )+        visual_attention_mask = (+            visual_attention_mask.view(-1, visual_attention_mask.size(-1))+            if visual_attention_mask is not None+            else None+        )+        visual_token_type_ids = (+            visual_token_type_ids.view(-1, visual_token_type_ids.size(-1))+            if visual_token_type_ids is not None+            else None+        )++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            
image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        _, pooled_output = outputs[0], outputs[1]++        # pooled_output = self.sequence_summary(sequence_output)+        logits = self.cls(pooled_output)+        reshaped_logits = logits.view(-1, num_choices)++        loss = None+        if labels is not None:+            loss_fct = CrossEntropyLoss()+            loss = loss_fct(reshaped_logits, labels)++        if not return_dict:+            output = (reshaped_logits,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return MultipleChoiceModelOutput(+            loss=loss,+            logits=reshaped_logits,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with a classification/regression head on top (a dropout and a linear layer on top of the pooled+    output) for VQA.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForVQA(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)+        self.num_labels = config.num_labels++        self.visual_bert = VisualBertModel(config)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.cls = nn.Linear(config.hidden_size, config.num_labels)++        self.init_weights()++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa",+        output_type=SequenceClassifierOutput,+        config_class="gchhablani/visualbert-vqa",+        code_sample=VISUAL_BERT_VQA_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, total_sequence_length)`, `optional`):+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,+            config.num_labels - 1]`. 
A KLDLoss is computed between the labels and the returned logits.++        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        # TO-CHECK : Raises an error if sum <2+        index_to_gather = attention_mask.sum(1) - 2  # as in original code # need before concat++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output = outputs[0]+        # pooled_output = outputs[1]  # TO-DO: Convert this to the gather code used by the original model.++        pooled_output = torch.gather(+            sequence_output,+            1,+            index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(index_to_gather.size(0), 1, sequence_output.size(-1)),+        )++        # UNUSED+        # input_ids = torch.gather(input_ids, 1, index_to_gather.unsqueeze(-1).expand(index_to_gather.size(0), 1))++        pooled_output = self.dropout(pooled_output)+        logits = self.cls(pooled_output)+        reshaped_logits = logits.view(-1, self.num_labels)++        loss = None+        if labels is not None:+            loss_fct = torch.nn.KLDivLoss(reduction="batchmean")+            log_softmax = torch.nn.LogSoftmax(dim=-1)+            reshaped_logits = log_softmax(reshaped_logits)+            loss = loss_fct(reshaped_logits, labels.contiguous())+        if not return_dict:+            output = (reshaped_logits,) + outputs[2:]+            return ((loss,) + output) if loss is not None else output++        return SequenceClassifierOutput(  # TO-DO: Need to replace this with VQA Model Output, maybe+            loss=loss,+            logits=reshaped_logits,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++@add_start_docstrings(+    """+    VisualBert Model with a MLM head on top for VQA tasks.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForVQAAdvanced(VisualBertPreTrainedModel):

Removing VQA Advanced as the authors confirmed it isn't necessary.

gchhablani

comment created time in an hour
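The VisualBertForVQA head reviewed above gathers one hidden state per example (at index attention_mask.sum(1) - 2) and trains against soft answer scores with a KL-divergence loss. Below is a minimal, standalone sketch of just that pooling-and-loss step; the shapes and variable names are illustrative assumptions, not taken from the PR.

import torch
import torch.nn as nn

# Toy shapes (assumptions): batch of 2, 8 text tokens, hidden size 16, 4 answer labels.
batch_size, seq_len, hidden_size, num_labels = 2, 8, 16, 4
sequence_output = torch.randn(batch_size, seq_len, hidden_size)      # encoder output
attention_mask = torch.ones(batch_size, seq_len)                     # no padding in this toy case
labels = torch.softmax(torch.randn(batch_size, num_labels), dim=-1)  # soft VQA answer scores

# As in the diff: gather the hidden state at position sum(mask) - 2
# (presumably the last text token before the final [SEP]).
index_to_gather = attention_mask.sum(1).long() - 2                   # (batch_size,)
pooled_output = torch.gather(
    sequence_output,
    1,
    index_to_gather.unsqueeze(-1).unsqueeze(-1).expand(batch_size, 1, hidden_size),
)                                                                     # (batch_size, 1, hidden_size)

cls = nn.Linear(hidden_size, num_labels)                              # classification head
reshaped_logits = cls(pooled_output).view(-1, num_labels)             # (batch_size, num_labels)

# KLDivLoss expects log-probabilities as input and probabilities as target.
log_probs = nn.LogSoftmax(dim=-1)(reshaped_logits)
loss = nn.KLDivLoss(reduction="batchmean")(log_probs, labels)
print(loss.item())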

issue comment guolinke/TUPE

Discrepancy between the paper and the implementation?

Yeah, it is indeed slightly different. In the implementation, we reset the abs-pos and rel-pos terms separately, each in its own form. I think this is almost the same as resetting them in a unified way, since the difference is only a bias term (from the rel-pos part). This was done to make the implementation easier and the code cleaner.

tonyswoo

comment created time in an hour
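For context on the reply above: in TUPE-style notation (this notation is my own reading of the paper and of the comment, not something taken from the repository), the positional part of the attention score is an untied absolute-position term plus a relative-position bias,

    A^{pos}_{ij} = \frac{1}{\sqrt{2d}}\,(p_i U^Q)(p_j U^K)^{\top} + b_{j-i}.

Resetting the positional score in a unified way would replace the whole [CLS] row (and column) by a single learned scalar, A^{pos}_{1j} = \theta_1, whereas resetting abs-pos and rel-pos separately, each in its own form, gives something like A^{pos}_{1j} = \theta^{abs}_1 + b_{j-1}. The two schemes then differ only by b_{j-1}, a content-independent bias contributed by the rel-pos part, which is presumably why they are expected to behave almost identically.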

Pull request review comment huggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+import os+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutputWithPastAndCrossAttentions,+    BaseModelOutputWithPoolingAndCrossAttentions,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++# TO-CHECK+def load_tf_weights_in_visual_bert(model, config, tf_checkpoint_path):+    """Load tf checkpoints in a pytorch model."""+    try:+        import re++        import numpy as np+        import tensorflow as tf+    except ImportError:+        logger.error(+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see "+            "https://www.tensorflow.org/install/ for installation instructions."+        )+        raise+    tf_path = os.path.abspath(tf_checkpoint_path)+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))+    # Load weights from TF model+    init_vars = tf.train.list_variables(tf_path)+    names = []+    arrays = []+    for name, shape in init_vars:+        logger.info("Loading TF weight {} with shape {}".format(name, shape))+        array = tf.train.load_variable(tf_path, name)+        names.append(name)+        arrays.append(array)++    for name, array in zip(names, arrays):+        name = name.split("/")+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v+        # which are not required for using pretrained model+        if any(+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]+            for n in name+        ):+            logger.info("Skipping {}".format("/".join(name)))+            continue+        pointer = model+        for m_name in name:+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):+                scope_names = re.split(r"_(\d+)", m_name)+            else:+                scope_names = [m_name]+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":+                pointer = getattr(pointer, "bias")+            elif scope_names[0] == "output_weights":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "squad":+                pointer = getattr(pointer, "classifier")+            else:+                try:+                    pointer = getattr(pointer, scope_names[0])+                except AttributeError:+                    logger.info("Skipping {}".format("/".join(name)))+                    continue+            if len(scope_names) >= 2:+                num = int(scope_names[1])+                pointer = pointer[num]+        if m_name[-11:] == "_embeddings":+            pointer = getattr(pointer, "weight")+        elif m_name == "kernel":+            array = np.transpose(array)+        try:+            assert (+                pointer.shape == array.shape+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"+        except AssertionError as e:+            e.args += (pointer.shape, array.shape)+            raise+        logger.info("Initialize PyTorch weight {}".format(name))+        pointer.data = torch.from_numpy(array)+    return model+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = 
code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++# TO-CHECK: Vestige of the original code+++class BertLayerNorm(nn.Module):+    def __init__(self, hidden_size, eps=1e-12):+        """Construct a layernorm module in the TF style (epsilon inside the square root)."""+        super(BertLayerNorm, self).__init__()+        self.weight = nn.Parameter(torch.ones(hidden_size))+        self.bias = nn.Parameter(torch.zeros(hidden_size))+        self.variance_epsilon = eps++    def forward(self, x):+        u = x.mean(-1, keepdim=True)+        s = (x - u).pow(2).mean(-1, keepdim=True)+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)+        return self.weight * x + self.bias+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(+            config.vocab_size, config.hidden_size+        )  # TO-CHECK: , padding_idx=config.pad_token_id+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # TO-CHECK+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Segment and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        # TO-CHECK: Check if register buffer is needed for Visual features+        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    # TO-CHECK: Check how to incorporate this. This is being called outside the classes.+    # def special_intialize(self):+    #     ### This is a bit unorthodox. 
The better way might be to add an inititilizer to AllenNLP.+    #     # This function is used to initialize the token_type_embeddings_visual and positiona_embedding_visual, just incase.+    #     self.token_type_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.token_type_embeddings.weight.data), requires_grad = True)+    #     self.position_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.position_embeddings.weight.data), requires_grad = True)+    #     return++    def forward(+        self,+        input_ids=None,  # TO-CHECK+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # past_key_values_length=0, # TO-CHECK+    ):+        # TO-CHECK: Check if `confidence=None` and `visual_position_embeds=None` (or id) is needed.+        # `position_embeddings_visual` is not used in the original code.++        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        # TO-CHECK+        # if position_ids is None:+        #     position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]+        if input_ids is not None:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        else:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+            position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number. Each element denotes the position of the word corresponding to the image feature. 
-1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            # Concate the two:+            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)  # concat the visual embeddings++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if 
self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++        self.is_decoder = config.is_decoder++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        # If this is instantiated as a cross-attention module, the keys+        # and values come from an encoder; the attention mask needs to be+        # such that the encoder's padding tokens are not attended to.+        is_cross_attention = encoder_hidden_states is not None++        if is_cross_attention and past_key_value is not None:+            # reuse k,v, cross_attentions+            key_layer = past_key_value[0]+            value_layer = past_key_value[1]+            attention_mask = encoder_attention_mask+        elif is_cross_attention:+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))+            attention_mask = encoder_attention_mask+        elif past_key_value is not None:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)+        else:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        if self.is_decoder:+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.+            # Further calls to cross_attention layer can then reuse all cross-attention+            # key/value_states (first "if" case)+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of+            # all previous decoder key/value_states. 
Further calls to uni-directional self-attention+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)+            # if encoder bi-directional self-attention `past_key_value` is always `None`+            past_key_value = (key_layer, value_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        if self.is_decoder:+            outputs = outputs + (past_key_value,)+        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, 
input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            encoder_hidden_states,+            encoder_attention_mask,+            past_key_value,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.is_decoder = config.is_decoder+        
self.add_cross_attention = config.add_cross_attention+        if self.add_cross_attention:+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"+            self.crossattention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+            past_key_value=self_attn_past_key_value,+        )+        attention_output = self_attention_outputs[0]++        # if decoder, the last output is tuple of self-attn cache+        if self.is_decoder:+            outputs = self_attention_outputs[1:-1]+            present_key_value = self_attention_outputs[-1]+        else:+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        cross_attn_present_key_value = None+        if self.is_decoder and encoder_hidden_states is not None:+            assert hasattr(+                self, "crossattention"+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"++            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None+            cross_attention_outputs = self.crossattention(+                attention_output,+                attention_mask,+                head_mask,+                encoder_hidden_states,+                encoder_attention_mask,+                cross_attn_past_key_value,+                output_attentions,+            )+            attention_output = cross_attention_outputs[0]+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights++            # add cross-attn cache to positions 3,4 of present_key_value tuple+            cross_attn_present_key_value = cross_attention_outputs[-1]+            present_key_value = present_key_value + cross_attn_present_key_value++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        # if decoder, return the attn key/values as the last output+        if self.is_decoder:+            outputs = outputs + (present_key_value,)++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        
hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_values=None,+        use_cache=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None++        next_decoder_cache = () if use_cache else None+        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None+            past_key_value = past_key_values[i] if past_key_values is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                if use_cache:+                    logger.warn(+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "+                        "`use_cache=False`..."+                    )+                    use_cache = False++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, past_key_value, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                    past_key_value,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if use_cache:+                next_decoder_cache += (layer_outputs[-1],)+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)+                if self.config.add_cross_attention:+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    next_decoder_cache,+                    all_hidden_states,+                    all_self_attentions,+                    all_cross_attentions,+                ]+                if v is not None+            )+        return BaseModelOutputWithPastAndCrossAttentions(+            last_hidden_state=hidden_states,+            past_key_values=next_decoder_cache,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+            cross_attentions=all_cross_attentions,+        )+++class VisualBertPooler(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = 
nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    load_tf_weights = load_tf_weights_in_visual_bert+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+          
  # Slightly different from the TF version which uses truncated_normal for initialization+            # cf https://github.com/pytorch/pytorch/pull/5617+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)++        # TO-CHECK+        # elif isinstance(module, nn.LayerNorm):+        elif isinstance(module, BertLayerNorm):+            module.bias.data.zero_()+            module.weight.data.fill_(1.0)+        if isinstance(module, nn.Linear) and module.bias is not None:+            module.bias.data.zero_()+++@dataclass+class VisualBertForPreTrainingOutput(ModelOutput):+    """+    Output type of :class:`~transformers.VisualBertForPreTraining`.++    Args:+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction+            (classification) loss.+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation+            before SoftMax).+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.++            Hidden-states of the model at the output of each layer plus the initial embedding outputs.+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,+            sequence_length, sequence_length)`.++            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention+            heads.+    """++    loss: Optional[torch.FloatTensor] = None+    prediction_logits: torch.FloatTensor = None+    seq_relationship_logits: torch.FloatTensor = None+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None+    attentions: Optional[Tuple[torch.FloatTensor]] = None+++VISUAL_BERT_START_DOCSTRING = r"""+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,+    pruning heads etc.)++    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to+    general usage and behavior.++    Parameters:+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.+            Initializing with a config file does not load the weights associated with the model, only the+            configuration. 
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model+            weights.+"""++VISUAL_BERT_INPUTS_DOCSTRING = r"""+    Args:+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):+            Indices of input sequence tokens in the vocabulary.++            Indices can be obtained using :class:`~transformers.BertTokenizer`. See+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for+            details.++            `What are input IDs? <../glossary.html#input-ids>`__+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? <../glossary.html#token-type-ids>`_+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,+            config.max_position_embeddings - 1]``.++            `What are position IDs? <../glossary.html#position-ids>`_+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:++            - 1 indicates the head is **not masked**,+            - 0 indicates the head is **masked**.++        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated+            vectors than the model's internal embedding lookup matrix.++        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):+            The embedded representation of the visual inputs, generally derived using using an object detector.++        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Segment token indices to indicate different portions of the visual embeds. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? 
<../glossary.html#token-type-ids>`_++        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):+            Image-Text alignment uses to decide the position IDs of the visual embeddings.++        output_attentions (:obj:`bool`, `optional`):+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned+            tensors for more detail.+        output_hidden_states (:obj:`bool`, `optional`):+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for+            more detail.+        return_dict (:obj:`bool`, `optional`):+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.+"""++VISUAL_BERT_VQA_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor([[0.0,1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_NLVR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""+++VISUAL_BERT_VQA_ADVANCED_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+ 
       >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_FLICKR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)+        >>> flickr_position = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2]))++        >>> inputs.update({{+            "flickr_position": flickr_position,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2], visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_PRE_TRAINING_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])["input_ids"]+        >>> sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size+++        >>> outputs = model(**inputs, labels=labels, 
sentence_image_labels=sentence_image_labels)+        >>> loss = outputs.loss+        >>> prediction_logits = outputs.prediction_logits+        >>> seq_relationship_logits = outputs.seq_relationship_logits+"""++VISUAL_BERT_MODEL_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> outputs = model(**inputs)++        >>> last_hidden_states = outputs.last_hidden_state+"""++VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""+    Example::++        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."+        >>> choice0 = "It is eaten with a fork and a knife."+        >>> choice1 = "It is eaten while held in the hand."++        >>> visual_embeds = get_visual_embeddings(image)+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1++        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)+        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, visual_embeds=visual_embeds, visual_attention_mask=visual_attention_mask, visual_token_type_ids=visual_token_type_ids, labels=labels)  # batch size is 1++        >>> loss = outputs.loss+        >>> logits = outputs.logits+"""+++@add_start_docstrings(+    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertModel(VisualBertPreTrainedModel):+    """++    The model can behave as an encoder (with only self-attention) following the architecture described in `Attention is+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,+    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.+    """++    def __init__(self, config, add_pooling_layer=True):+        super().__init__(config)+        self.config = config++        self.embeddings = VisualBertEmbeddings(config)+        self.encoder = VisualBertEncoder(config)++        self.pooler = (+            VisualBertPooler(config) if add_pooling_layer else None+        )  # TO-DO: Check if pooler is needed necessarily.++        self.bypass_transformer = config.bypass_transformer++        if self.bypass_transformer:+            self.additional_layer = VisualBertLayer(config)++        # TO-CHECK: This next line is from old BERT code, which is not used anymore.+        # self.output_attention_weights = config.output_attention_weights++        self.init_weights()  # self.apply(self.init_bert_weights) #Vestiges of old code++    # TO-CHECK+    def get_input_embeddings(self):+        return self.embeddings.word_embeddings++    # TO-CHECK+    def set_input_embeddings(self, value):+        self.embeddings.word_embeddings = value++    # TO-CHECK+    def _prune_heads(self, heads_to_prune):+        """+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base+        class PreTrainedModel+        """+        for layer, heads in heads_to_prune.items():+            self.encoder.layer[layer].attention.prune_heads(heads)++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_MODEL_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # encoder_hidden_states=None,+        # encoder_attention_mask=None,+        # past_key_values=None,+        # use_cache=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+    ):++        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions+        output_hidden_states = (+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states+        )+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        # if self.config.is_decoder:+        #     use_cache = use_cache if use_cache is not None else self.config.use_cache+        # else:+        use_cache = False++        if input_ids is not None and inputs_embeds is not None:+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")+        elif input_ids is not None:+            input_shape = input_ids.size()+            batch_size, seq_length = input_shape+        elif inputs_embeds is not None:+            input_shape = inputs_embeds.size()[:-1]+            batch_size, seq_length = input_shape+        else:+            raise 
ValueError("You have to specify either input_ids or inputs_embeds")++        device = input_ids.device if input_ids is not None else inputs_embeds.device++        # past_key_values_length+        # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0++        if attention_mask is None:+            attention_mask = torch.ones(((batch_size, seq_length)), device=device)+            # attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)+        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)++        if visual_embeds is not None:+            visual_input_shape = visual_embeds.size()[:-1]+            _, visual_seq_length = visual_input_shape+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.zeros(visual_input_shape, dtype=torch.long, device=device)++            if visual_attention_mask is None:+                visual_attention_mask = torch.ones(visual_input_shape, device=device)++        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]+        # ourselves in which case we just need to make it broadcastable to all heads.++        # TO-CHECK : Whether input_shape+visual_input_shape is correct.+        if visual_embeds is not None:+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device+            )+        else:+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                attention_mask, [batch_size, input_shape], device+            )++        # If a 2D or 3D attention mask is provided for the cross-attention+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]+        # if self.config.is_decoder and encoder_hidden_states is not None:+        #     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()+        #     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)+        #     if encoder_attention_mask is None:+        #         encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)+        #     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)+        # else:+        #     encoder_extended_attention_mask = None++        # Prepare head mask if needed+        # 1.0 in head_mask indicate we keep the head+        # attention_probs has shape bsz x n_heads x N x N+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)++        embedding_output = self.embeddings(+            input_ids=input_ids,+            position_ids=position_ids,+            token_type_ids=token_type_ids,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            # past_key_values_length=past_key_values_length,+        )++        if self.bypass_transformer and visual_embeds is not None:+            assert output_hidden_states is None  # 
TO-DO: Need to check if this is correct.+            text_length = input_ids.size(1)+            text_embedding_output = embedding_output[:, :text_length, :]+            visual_embedding_output = embedding_output[:, text_length:, :]++            text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]+            # text_encoder_hidden_states = encoder_hidden_states[:, :text_length, :]+            # text_encoder_attention_mask = encoder_extended_attention_mask[:, :, :text_length, :text_length]++            # TO-DO: Check how past-key values work and whether they are required to be modified and added here++            encoded_outputs = self.encoder(+                text_embedding_output,+                attention_mask=text_extended_attention_mask,+                # encoder_hidden_states=text_encoder_hidden_states,+                # encoder_attention_mask=text_encoder_attention_mask,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoded_outputs[0]+            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)+            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        else:+            encoder_outputs = self.encoder(+                embedding_output,+                attention_mask=extended_attention_mask,+                head_mask=head_mask,+                # encoder_hidden_states=encoder_hidden_states,+                # encoder_attention_mask=encoder_extended_attention_mask,+                # past_key_values=past_key_values,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoder_outputs[0]++            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        if not return_dict:+            return (sequence_output, pooled_output) + encoder_outputs[1:]++        return BaseModelOutputWithPoolingAndCrossAttentions(  # Changed+            last_hidden_state=sequence_output,+            pooler_output=pooled_output,+            past_key_values=encoder_outputs.past_key_values,+            hidden_states=encoder_outputs.hidden_states,+            attentions=encoder_outputs.attentions,+            cross_attentions=encoder_outputs.cross_attentions,+        )+++# TO-DO: Check if the case where we don't want to calculate is_random_next loss.+# The is a case where during pre-training, in the original code, it is checked if the is_random_next is None+# In this case, the next_sentence_loss is not calculated.+++@add_start_docstrings(+    """+    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a+    `sentence-image prediction (classification)` head.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForPreTraining(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.cls = VisualBertPreTrainingHeads(config)++        # UNUSED+        # self.cut_first = cut_first+        # self.hard_cap_seq_len = hard_cap_seq_len++        self.init_weights()++ 
   def get_output_embeddings(self):+        return self.cls.predictions.decoder++    def set_output_embeddings(self, new_embeddings):+        self.cls.predictions.decoder = new_embeddings++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        mask="[MASK]",+        output_type=VisualBertForPreTrainingOutput,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_PRE_TRAINING_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+        sentence_image_labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``+        sentence_image_labels (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):+            Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence pair+            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:++            - 0 indicates sequence B is a matching pair of sequence A for the given image,+            - 1 indicates sequence B is a random sequence w.r.t A for the given image.++        Returns:+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output, pooled_output = outputs[:2]+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)++        total_loss = None+        if labels is not None and sentence_image_labels is not None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. 
Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."++            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))+            total_loss = masked_lm_loss + sentence_image_loss++        # TO-CHECK+        if labels is not None and sentence_image_labels is None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."+            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            total_loss = masked_lm_loss++        if not return_dict:+            output = (prediction_scores, seq_relationship_score) + outputs[2:]+            return ((total_loss,) + output) if total_loss is not None else output++        return VisualBertForPreTrainingOutput(+            loss=total_loss,+            prediction_logits=prediction_scores,+            seq_relationship_logits=seq_relationship_score,+            hidden_states=outputs.hidden_states,+            attentions=outputs.attentions,+        )+++class VisualBertClassificationHead(nn.Module):+    """Head for sentence-level classification tasks."""++    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.dropout = nn.Dropout(config.hidden_dropout_prob)+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)++        self.config = config++    def forward(self, features, **kwargs):+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])+        x = self.dropout(x)+        x = self.dense(x)+        x = ACT2FN[self.config.hidden_act](x)+        x = self.dropout(x)+        x = self.out_proj(x)+        return x+++@add_start_docstrings(+    """+    VisualBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and+    a softmax) e.g. for VCR tasks.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForMultipleChoice(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        # TO-CHECK+        # self.sequence_summary = SequenceSummary(config)

Removing.

gchhablani

comment created time in an hour
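
Both reviewed snippets' doctest examples lean on a `get_visual_embeddings(image)` helper that is only assumed, never defined in the PR. A minimal, purely illustrative stand-in is sketched below; the region count of 36 and the feature size of 2048 are assumptions chosen to resemble common detector outputs, not values taken from the code under review.

    import torch

    # Hypothetical stand-in for the assumed `get_visual_embeddings(image)` helper.
    # A real pipeline would run an object detector and return pooled region
    # features; here we simply fake `num_regions` vectors of size `visual_embedding_dim`.
    def get_visual_embeddings(image, num_regions=36, visual_embedding_dim=2048):
        return torch.randn(num_regions, visual_embedding_dim)

    image = None  # placeholder; the fake helper above ignores it
    visual_embeds = get_visual_embeddings(image).unsqueeze(0)  # (1, num_regions, visual_embedding_dim)
    visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
    visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)

    print(visual_embeds.shape, visual_token_type_ids.shape, visual_attention_mask.shape)

With tensors shaped this way, the `inputs.update({...})` calls in the samples above yield the `visual_embeds`, `visual_token_type_ids` and `visual_attention_mask` arguments that the model's forward pass concatenates with the text encoding.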

Pull request review comment huggingface/transformers

Add VisualBERT
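
The reviewed snippet below carries a `TO-CHECK` note asking whether its TF-style `BertLayerNorm` (epsilon inside the square root) can be replaced by `nn.LayerNorm`. A rough numerical sanity check is sketched here, assuming a hidden size of 768 and the module's default `eps=1e-12`; the class name `TFStyleLayerNorm` is illustrative and not part of the PR.

    import torch
    from torch import nn

    class TFStyleLayerNorm(nn.Module):
        # Mirrors the BertLayerNorm in the reviewed snippet below:
        # the epsilon sits inside the square root of the variance term.
        def __init__(self, hidden_size, eps=1e-12):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
            self.variance_epsilon = eps

        def forward(self, x):
            u = x.mean(-1, keepdim=True)
            s = (x - u).pow(2).mean(-1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.variance_epsilon)
            return self.weight * x + self.bias

    hidden = torch.randn(2, 5, 768)
    tf_style = TFStyleLayerNorm(768)
    native = nn.LayerNorm(768, eps=1e-12)
    # Freshly initialized, both use weight=1 and bias=0, so the outputs
    # should agree up to floating-point tolerance.
    print(torch.allclose(tf_style(hidden), native(hidden), atol=1e-5))

Since `nn.LayerNorm` likewise uses the biased variance with `eps` added before the square root, the two normalizations should agree to floating-point tolerance, which is presumably why the commented-out `nn.LayerNorm` alternative appears next to the custom class.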

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+import os+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutputWithPastAndCrossAttentions,+    BaseModelOutputWithPoolingAndCrossAttentions,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++# TO-CHECK+def load_tf_weights_in_visual_bert(model, config, tf_checkpoint_path):+    """Load tf checkpoints in a pytorch model."""+    try:+        import re++        import numpy as np+        import tensorflow as tf+    except ImportError:+        logger.error(+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see "+            "https://www.tensorflow.org/install/ for installation instructions."+        )+        raise+    tf_path = os.path.abspath(tf_checkpoint_path)+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))+    # Load weights from TF model+    init_vars = tf.train.list_variables(tf_path)+    names = []+    arrays = []+    for name, shape in init_vars:+        logger.info("Loading TF weight {} with shape {}".format(name, shape))+        array = tf.train.load_variable(tf_path, name)+        names.append(name)+        arrays.append(array)++    for name, array in zip(names, arrays):+        name = name.split("/")+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v+        # which are not required for using pretrained model+        if any(+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]+            for n in name+        ):+            logger.info("Skipping {}".format("/".join(name)))+            continue+        pointer = model+        for m_name in name:+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):+                scope_names = re.split(r"_(\d+)", m_name)+            else:+                scope_names = [m_name]+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":+                pointer = getattr(pointer, "bias")+            elif scope_names[0] == "output_weights":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "squad":+                pointer = getattr(pointer, "classifier")+            else:+                try:+                    pointer = getattr(pointer, scope_names[0])+                except AttributeError:+                    logger.info("Skipping {}".format("/".join(name)))+                    continue+            if len(scope_names) >= 2:+                num = int(scope_names[1])+                pointer = pointer[num]+        if m_name[-11:] == "_embeddings":+            pointer = getattr(pointer, "weight")+        elif m_name == "kernel":+            array = np.transpose(array)+        try:+            assert (+                pointer.shape == array.shape+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"+        except AssertionError as e:+            e.args += (pointer.shape, array.shape)+            raise+        logger.info("Initialize PyTorch weight {}".format(name))+        pointer.data = torch.from_numpy(array)+    return model+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = 
code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++# TO-CHECK: Vestige of the original code+++class BertLayerNorm(nn.Module):+    def __init__(self, hidden_size, eps=1e-12):+        """Construct a layernorm module in the TF style (epsilon inside the square root)."""+        super(BertLayerNorm, self).__init__()+        self.weight = nn.Parameter(torch.ones(hidden_size))+        self.bias = nn.Parameter(torch.zeros(hidden_size))+        self.variance_epsilon = eps++    def forward(self, x):+        u = x.mean(-1, keepdim=True)+        s = (x - u).pow(2).mean(-1, keepdim=True)+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)+        return self.weight * x + self.bias+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(+            config.vocab_size, config.hidden_size+        )  # TO-CHECK: , padding_idx=config.pad_token_id+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # TO-CHECK+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Segment and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        # TO-CHECK: Check if register buffer is needed for Visual features+        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    # TO-CHECK: Check how to incorporate this. This is being called outside the classes.+    # def special_intialize(self):+    #     ### This is a bit unorthodox. 
The better way might be to add an inititilizer to AllenNLP.+    #     # This function is used to initialize the token_type_embeddings_visual and positiona_embedding_visual, just incase.+    #     self.token_type_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.token_type_embeddings.weight.data), requires_grad = True)+    #     self.position_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.position_embeddings.weight.data), requires_grad = True)+    #     return++    def forward(+        self,+        input_ids=None,  # TO-CHECK+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # past_key_values_length=0, # TO-CHECK+    ):+        # TO-CHECK: Check if `confidence=None` and `visual_position_embeds=None` (or id) is needed.+        # `position_embeddings_visual` is not used in the original code.++        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        # TO-CHECK+        # if position_ids is None:+        #     position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]+        if input_ids is not None:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        else:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+            position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number. Each element denotes the position of the word corresponding to the image feature. 
-1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            # Concate the two:+            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)  # concat the visual embeddings++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if 
self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++        self.is_decoder = config.is_decoder++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        # If this is instantiated as a cross-attention module, the keys+        # and values come from an encoder; the attention mask needs to be+        # such that the encoder's padding tokens are not attended to.+        is_cross_attention = encoder_hidden_states is not None++        if is_cross_attention and past_key_value is not None:+            # reuse k,v, cross_attentions+            key_layer = past_key_value[0]+            value_layer = past_key_value[1]+            attention_mask = encoder_attention_mask+        elif is_cross_attention:+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))+            attention_mask = encoder_attention_mask+        elif past_key_value is not None:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)+        else:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        if self.is_decoder:+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.+            # Further calls to cross_attention layer can then reuse all cross-attention+            # key/value_states (first "if" case)+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of+            # all previous decoder key/value_states. 
Further calls to uni-directional self-attention+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)+            # if encoder bi-directional self-attention `past_key_value` is always `None`+            past_key_value = (key_layer, value_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        if self.is_decoder:+            outputs = outputs + (past_key_value,)+        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, 
input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            encoder_hidden_states,+            encoder_attention_mask,+            past_key_value,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.is_decoder = config.is_decoder+        
self.add_cross_attention = config.add_cross_attention+        if self.add_cross_attention:+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"+            self.crossattention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+            past_key_value=self_attn_past_key_value,+        )+        attention_output = self_attention_outputs[0]++        # if decoder, the last output is tuple of self-attn cache+        if self.is_decoder:+            outputs = self_attention_outputs[1:-1]+            present_key_value = self_attention_outputs[-1]+        else:+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        cross_attn_present_key_value = None+        if self.is_decoder and encoder_hidden_states is not None:+            assert hasattr(+                self, "crossattention"+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"++            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None+            cross_attention_outputs = self.crossattention(+                attention_output,+                attention_mask,+                head_mask,+                encoder_hidden_states,+                encoder_attention_mask,+                cross_attn_past_key_value,+                output_attentions,+            )+            attention_output = cross_attention_outputs[0]+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights++            # add cross-attn cache to positions 3,4 of present_key_value tuple+            cross_attn_present_key_value = cross_attention_outputs[-1]+            present_key_value = present_key_value + cross_attn_present_key_value++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        # if decoder, return the attn key/values as the last output+        if self.is_decoder:+            outputs = outputs + (present_key_value,)++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        
hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_values=None,+        use_cache=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None++        next_decoder_cache = () if use_cache else None+        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None+            past_key_value = past_key_values[i] if past_key_values is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                if use_cache:+                    logger.warn(+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "+                        "`use_cache=False`..."+                    )+                    use_cache = False++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, past_key_value, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                    past_key_value,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if use_cache:+                next_decoder_cache += (layer_outputs[-1],)+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)+                if self.config.add_cross_attention:+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    next_decoder_cache,+                    all_hidden_states,+                    all_self_attentions,+                    all_cross_attentions,+                ]+                if v is not None+            )+        return BaseModelOutputWithPastAndCrossAttentions(+            last_hidden_state=hidden_states,+            past_key_values=next_decoder_cache,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+            cross_attentions=all_cross_attentions,+        )+++class VisualBertPooler(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = 
nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    load_tf_weights = load_tf_weights_in_visual_bert+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+          
  # Slightly different from the TF version which uses truncated_normal for initialization+            # cf https://github.com/pytorch/pytorch/pull/5617+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)++        # TO-CHECK+        # elif isinstance(module, nn.LayerNorm):+        elif isinstance(module, BertLayerNorm):+            module.bias.data.zero_()+            module.weight.data.fill_(1.0)+        if isinstance(module, nn.Linear) and module.bias is not None:+            module.bias.data.zero_()+++@dataclass+class VisualBertForPreTrainingOutput(ModelOutput):+    """+    Output type of :class:`~transformers.VisualBertForPreTraining`.++    Args:+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction+            (classification) loss.+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation+            before SoftMax).+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.++            Hidden-states of the model at the output of each layer plus the initial embedding outputs.+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,+            sequence_length, sequence_length)`.++            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention+            heads.+    """++    loss: Optional[torch.FloatTensor] = None+    prediction_logits: torch.FloatTensor = None+    seq_relationship_logits: torch.FloatTensor = None+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None+    attentions: Optional[Tuple[torch.FloatTensor]] = None+++VISUAL_BERT_START_DOCSTRING = r"""+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic+    methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,+    pruning heads etc.)++    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to+    general usage and behavior.++    Parameters:+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.+            Initializing with a config file does not load the weights associated with the model, only the+            configuration. 
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model+            weights.+"""++VISUAL_BERT_INPUTS_DOCSTRING = r"""+    Args:+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):+            Indices of input sequence tokens in the vocabulary.++            Indices can be obtained using :class:`~transformers.BertTokenizer`. See+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for+            details.++            `What are input IDs? <../glossary.html#input-ids>`__+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? <../glossary.html#token-type-ids>`_+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,+            config.max_position_embeddings - 1]``.++            `What are position IDs? <../glossary.html#position-ids>`_+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:++            - 1 indicates the head is **not masked**,+            - 0 indicates the head is **masked**.++        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated+            vectors than the model's internal embedding lookup matrix.++        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):+            The embedded representation of the visual inputs, generally derived using using an object detector.++        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:++            - 1 for tokens that are **not masked**,+            - 0 for tokens that are **masked**.++            `What are attention masks? <../glossary.html#attention-mask>`__+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):+            Segment token indices to indicate different portions of the visual embeds. Indices are selected in ``[0,+            1]``:++            - 0 corresponds to a `sentence A` token,+            - 1 corresponds to a `sentence B` token.++            `What are token type IDs? 
<../glossary.html#token-type-ids>`_++        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):+            Image-Text alignment uses to decide the position IDs of the visual embeddings.++        output_attentions (:obj:`bool`, `optional`):+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned+            tensors for more detail.+        output_hidden_states (:obj:`bool`, `optional`):+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for+            more detail.+        return_dict (:obj:`bool`, `optional`):+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.+"""++VISUAL_BERT_VQA_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor([[0.0,1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_NLVR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""+++VISUAL_BERT_VQA_ADVANCED_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+ 
       >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_FLICKR_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> text = "Who is eating the apple?"+        >>> inputs = tokenizer(text, return_tensors='pt')+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)+        >>> flickr_position = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2]))++        >>> inputs.update({{+            "flickr_position": flickr_position,+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1]+visual_embeds.shape[-2], visual_embeds.shape[-2])) # Batch size 1++        >>> outputs = model(**inputs, labels=labels)+        >>> loss = outputs.loss+        >>> scores = outputs.logits+"""++VISUAL_BERT_PRE_TRAINING_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=inputs["input_ids"].shape[-1]+visual_embeds.shape[-2])["input_ids"]+        >>> sentence_image_labels = torch.tensor(1).unsqueeze(0) # Batch_size+++        >>> outputs = model(**inputs, labels=labels, 
sentence_image_labels=sentence_image_labels)+        >>> loss = outputs.loss+        >>> prediction_logits = outputs.prediction_logits+        >>> seq_relationship_logits = outputs.seq_relationship_logits+"""++VISUAL_BERT_MODEL_SAMPLE = r"""+    Example::+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.+        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long) #example+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> inputs.update({{+            "visual_embeds": visual_embeds,+            "visual_token_type_ids": visual_token_type_ids,+            "visual_attention_mask": visual_attention_mask+        }})++        >>> outputs = model(**inputs)++        >>> last_hidden_states = outputs.last_hidden_state+"""++VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""+    Example::++        >>> from transformers import {tokenizer_class}, {model_class}+        >>> import torch++        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')+        >>> model = {model_class}.from_pretrained('{checkpoint}')++        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."+        >>> choice0 = "It is eaten with a fork and a knife."+        >>> choice1 = "It is eaten while held in the hand."++        >>> visual_embeds = get_visual_embeddings(image)+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape) # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)++        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1++        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)+        >>> outputs = model(**{{k: v.unsqueeze(0) for k,v in encoding.items()}}, visual_embeds=visual_embeds, visual_attention_mask=visual_attention_mask, visual_token_type_ids=visual_token_type_ids, labels=labels)  # batch size is 1++        >>> loss = outputs.loss+        >>> logits = outputs.logits+"""+++@add_start_docstrings(+    "The bare VisualBert Model transformer outputting raw hidden-states without any specific head on top.",+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertModel(VisualBertPreTrainedModel):+    """++    The model can behave as an encoder (with only self-attention) following the architecture described in `Attention is+    all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,+    Llion Jones, Aidan N. 
Gomez, Lukasz Kaiser and Illia Polosukhin.+    """++    def __init__(self, config, add_pooling_layer=True):+        super().__init__(config)+        self.config = config++        self.embeddings = VisualBertEmbeddings(config)+        self.encoder = VisualBertEncoder(config)++        self.pooler = (+            VisualBertPooler(config) if add_pooling_layer else None+        )  # TO-DO: Check if pooler is needed necessarily.++        self.bypass_transformer = config.bypass_transformer++        if self.bypass_transformer:+            self.additional_layer = VisualBertLayer(config)++        # TO-CHECK: This next line is from old BERT code, which is not used anymore.+        # self.output_attention_weights = config.output_attention_weights++        self.init_weights()  # self.apply(self.init_bert_weights) #Vestiges of old code++    # TO-CHECK+    def get_input_embeddings(self):+        return self.embeddings.word_embeddings++    # TO-CHECK+    def set_input_embeddings(self, value):+        self.embeddings.word_embeddings = value++    # TO-CHECK+    def _prune_heads(self, heads_to_prune):+        """+        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base+        class PreTrainedModel+        """+        for layer, heads in heads_to_prune.items():+            self.encoder.layer[layer].attention.prune_heads(heads)++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))+    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_MODEL_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # encoder_hidden_states=None,+        # encoder_attention_mask=None,+        # past_key_values=None,+        # use_cache=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+    ):++        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions+        output_hidden_states = (+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states+        )+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        # if self.config.is_decoder:+        #     use_cache = use_cache if use_cache is not None else self.config.use_cache+        # else:+        use_cache = False++        if input_ids is not None and inputs_embeds is not None:+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")+        elif input_ids is not None:+            input_shape = input_ids.size()+            batch_size, seq_length = input_shape+        elif inputs_embeds is not None:+            input_shape = inputs_embeds.size()[:-1]+            batch_size, seq_length = input_shape+        else:+            raise 
ValueError("You have to specify either input_ids or inputs_embeds")++        device = input_ids.device if input_ids is not None else inputs_embeds.device++        # past_key_values_length+        # past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0++        if attention_mask is None:+            attention_mask = torch.ones(((batch_size, seq_length)), device=device)+            # attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)+        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)++        if visual_embeds is not None:+            visual_input_shape = visual_embeds.size()[:-1]+            _, visual_seq_length = visual_input_shape+            if visual_token_type_ids is None:+                visual_token_type_ids = torch.zeros(visual_input_shape, dtype=torch.long, device=device)++            if visual_attention_mask is None:+                visual_attention_mask = torch.ones(visual_input_shape, device=device)++        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]+        # ourselves in which case we just need to make it broadcastable to all heads.++        # TO-CHECK : Whether input_shape+visual_input_shape is correct.+        if visual_embeds is not None:+            combined_attention_mask = torch.cat((attention_mask, visual_attention_mask), dim=-1)+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                combined_attention_mask, [batch_size, input_shape + visual_input_shape], device+            )+        else:+            extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(+                attention_mask, [batch_size, input_shape], device+            )++        # If a 2D or 3D attention mask is provided for the cross-attention+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]+        # if self.config.is_decoder and encoder_hidden_states is not None:+        #     encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()+        #     encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)+        #     if encoder_attention_mask is None:+        #         encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)+        #     encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)+        # else:+        #     encoder_extended_attention_mask = None++        # Prepare head mask if needed+        # 1.0 in head_mask indicate we keep the head+        # attention_probs has shape bsz x n_heads x N x N+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)++        embedding_output = self.embeddings(+            input_ids=input_ids,+            position_ids=position_ids,+            token_type_ids=token_type_ids,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            # past_key_values_length=past_key_values_length,+        )++        if self.bypass_transformer and visual_embeds is not None:+            assert output_hidden_states is None  # 
TO-DO: Need to check if this is correct.+            text_length = input_ids.size(1)+            text_embedding_output = embedding_output[:, :text_length, :]+            visual_embedding_output = embedding_output[:, text_length:, :]++            text_extended_attention_mask = extended_attention_mask[:, :, text_length, :text_length]+            # text_encoder_hidden_states = encoder_hidden_states[:, :text_length, :]+            # text_encoder_attention_mask = encoder_extended_attention_mask[:, :, :text_length, :text_length]++            # TO-DO: Check how past-key values work and whether they are required to be modified and added here++            encoded_outputs = self.encoder(+                text_embedding_output,+                attention_mask=text_extended_attention_mask,+                # encoder_hidden_states=text_encoder_hidden_states,+                # encoder_attention_mask=text_encoder_attention_mask,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoded_outputs[0]+            concatenated_input = torch.cat((sequence_output, visual_embedding_output), dim=1)+            sequence_output = self.additional_layer(concatenated_input, extended_attention_mask)+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        else:+            encoder_outputs = self.encoder(+                embedding_output,+                attention_mask=extended_attention_mask,+                head_mask=head_mask,+                # encoder_hidden_states=encoder_hidden_states,+                # encoder_attention_mask=encoder_extended_attention_mask,+                # past_key_values=past_key_values,+                use_cache=use_cache,+                output_attentions=output_attentions,+                output_hidden_states=output_hidden_states,+                return_dict=return_dict,+            )+            sequence_output = encoder_outputs[0]++            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None++        if not return_dict:+            return (sequence_output, pooled_output) + encoder_outputs[1:]++        return BaseModelOutputWithPoolingAndCrossAttentions(  # Changed+            last_hidden_state=sequence_output,+            pooler_output=pooled_output,+            past_key_values=encoder_outputs.past_key_values,+            hidden_states=encoder_outputs.hidden_states,+            attentions=encoder_outputs.attentions,+            cross_attentions=encoder_outputs.cross_attentions,+        )+++# TO-DO: Check if the case where we don't want to calculate is_random_next loss.+# The is a case where during pre-training, in the original code, it is checked if the is_random_next is None+# In this case, the next_sentence_loss is not calculated.+++@add_start_docstrings(+    """+    VisualBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a+    `sentence-image prediction (classification)` head.+    """,+    VISUAL_BERT_START_DOCSTRING,+)+class VisualBertForPreTraining(VisualBertPreTrainedModel):+    def __init__(self, config):+        super().__init__(config)++        self.visual_bert = VisualBertModel(config)+        self.cls = VisualBertPreTrainingHeads(config)++        # UNUSED+        # self.cut_first = cut_first+        # self.hard_cap_seq_len = hard_cap_seq_len++        self.init_weights()++ 
   def get_output_embeddings(self):+        return self.cls.predictions.decoder++    def set_output_embeddings(self, new_embeddings):+        self.cls.predictions.decoder = new_embeddings++    @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))+    @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)+    @add_code_sample_docstrings(+        tokenizer_class=_TOKENIZER_FOR_DOC,+        tokenizer_checkpoint=_TOKENIZER_CHECKPOINT,+        checkpoint="gchhablani/visualbert-vqa-coco-pre",+        mask="[MASK]",+        output_type=VisualBertForPreTrainingOutput,+        config_class="gchhablani/visualbert-vqa-coco-pre",+        code_sample=VISUAL_BERT_PRE_TRAINING_SAMPLE,+    )+    def forward(+        self,+        input_ids=None,+        attention_mask=None,+        token_type_ids=None,+        position_ids=None,+        head_mask=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_attention_mask=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        output_attentions=None,+        output_hidden_states=None,+        return_dict=None,+        labels=None,+        sentence_image_labels=None,+    ):+        r"""+        labels (:obj:`torch.LongTensor` of shape ``(batch_size, total_sequence_length)``, `optional`):+            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,+            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored+            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``+        sentence_image_labels (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):+            Labels for computing the sentence-image prediction (classification) loss. Input should be a sequence pair+            (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:++            - 0 indicates sequence B is a matching pair of sequence A for the given image,+            - 1 indicates sequence B is a random sequence w.r.t A for the given image.++        Returns:+        """+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict++        outputs = self.visual_bert(+            input_ids,+            attention_mask=attention_mask,+            token_type_ids=token_type_ids,+            position_ids=position_ids,+            head_mask=head_mask,+            inputs_embeds=inputs_embeds,+            visual_embeds=visual_embeds,+            visual_attention_mask=visual_attention_mask,+            visual_token_type_ids=visual_token_type_ids,+            image_text_alignment=image_text_alignment,+            output_attentions=output_attentions,+            output_hidden_states=output_hidden_states,+            return_dict=return_dict,+        )++        sequence_output, pooled_output = outputs[:2]+        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)++        total_loss = None+        if labels is not None and sentence_image_labels is not None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. 
Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."++            loss_fct = CrossEntropyLoss()+            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))+            sentence_image_loss = loss_fct(seq_relationship_score.view(-1, 2), sentence_image_labels.view(-1))+            total_loss = masked_lm_loss + sentence_image_loss++        # TO-CHECK+        if labels is not None and sentence_image_labels is None:+            assert labels.size(-1) == attention_mask.size(-1) + visual_attention_mask.size(+                -1+            ), f"The labels provided should have same sequence length as total attention mask. Found labels with sequence length {labels.size(-1)}, expected {attention_mask.size(-1)+ visual_attention_mask.size(-1)}."
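Note on the code-sample docstrings in this diff: every sample assumes a user-supplied `get_visual_embeddings(image)` helper that is not part of the PR. Below is a minimal sketch of such a helper; the detector-free stub, the `num_regions` value, and the `feature_dim` value are assumptions for illustration only. The only real contract is that it returns a float tensor of shape `(visual_seq_length, visual_embedding_dim)`, with the last dimension equal to `config.visual_embedding_dim` of the checkpoint being loaded.

import torch

def get_visual_embeddings(image, num_regions=36, feature_dim=2048):
    # Hypothetical stand-in for a region-feature extractor (e.g. an object
    # detector backbone). A real implementation would run the detector on
    # `image` and return its pooled region features; random values are used
    # here only to illustrate the expected output shape.
    # feature_dim must match config.visual_embedding_dim of the checkpoint.
    return torch.randn(num_regions, feature_dim)

The docstring samples then call `.unsqueeze(0)` on this output to add the batch dimension and build `visual_token_type_ids` / `visual_attention_mask` from `visual_embeds.shape[:-1]`, so any extractor that honors the same output shape can be dropped in.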

Will do.

gchhablani

comment created an hour ago

Pull request review comment huggingface/transformers

Add VisualBERT

+# coding=utf-8+# Copyright 2021 The HuggingFace Inc. team. All rights reserved.+#+# Licensed under the Apache License, Version 2.0 (the "License");+# you may not use this file except in compliance with the License.+# You may obtain a copy of the License at+#+#     http://www.apache.org/licenses/LICENSE-2.0+#+# Unless required by applicable law or agreed to in writing, software+# distributed under the License is distributed on an "AS IS" BASIS,+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.+# See the License for the specific language governing permissions and+# limitations under the License.+""" PyTorch VisualBERT model. """+++import math+import os+from dataclasses import dataclass+from typing import Optional, Tuple++import torch+import torch.utils.checkpoint+from torch import nn+from torch.nn import CrossEntropyLoss, KLDivLoss, LogSoftmax++from ...activations import ACT2FN+from ...file_utils import (+    ModelOutput,+    _prepare_output_docstrings,+    add_start_docstrings,+    add_start_docstrings_to_model_forward,+    replace_return_docstrings,+)+from ...modeling_outputs import (+    BaseModelOutputWithPastAndCrossAttentions,+    BaseModelOutputWithPoolingAndCrossAttentions,+    MaskedLMOutput,+    MultipleChoiceModelOutput,+    SequenceClassifierOutput,+)+from ...modeling_utils import (+    PreTrainedModel,+    apply_chunking_to_forward,+    find_pruneable_heads_and_indices,+    prune_linear_layer,+)+from ...utils import logging+from .configuration_visual_bert import VisualBertConfig+++logger = logging.get_logger(__name__)++_CONFIG_FOR_DOC = "VisualBertConfig"+_TOKENIZER_FOR_DOC = "BertTokenizer"+_TOKENIZER_CHECKPOINT = "bert-base-uncased"++VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [+    "gchhablani/visualbert-vqa",+    "gchhablani/visualbert-vqa-pre",+    "gchhablani/visualbert-vqa-coco-pre",+    "gchhablani/visualbert-vcr",+    "gchhablani/visualbert-vcr-pre",+    "gchhablani/visualbert-vcr-coco-pre",+    "gchhablani/visualbert-nlvr2",+    "gchhablani/visualbert-nlvr2-pre",+    "gchhablani/visualbert-nlvr2-coco-pre"+    # See all VisualBERT models at https://huggingface.co/models?filter=visual_bert+]+++# TO-CHECK+def load_tf_weights_in_visual_bert(model, config, tf_checkpoint_path):+    """Load tf checkpoints in a pytorch model."""+    try:+        import re++        import numpy as np+        import tensorflow as tf+    except ImportError:+        logger.error(+            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. 
Please see "+            "https://www.tensorflow.org/install/ for installation instructions."+        )+        raise+    tf_path = os.path.abspath(tf_checkpoint_path)+    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))+    # Load weights from TF model+    init_vars = tf.train.list_variables(tf_path)+    names = []+    arrays = []+    for name, shape in init_vars:+        logger.info("Loading TF weight {} with shape {}".format(name, shape))+        array = tf.train.load_variable(tf_path, name)+        names.append(name)+        arrays.append(array)++    for name, array in zip(names, arrays):+        name = name.split("/")+        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v+        # which are not required for using pretrained model+        if any(+            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]+            for n in name+        ):+            logger.info("Skipping {}".format("/".join(name)))+            continue+        pointer = model+        for m_name in name:+            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):+                scope_names = re.split(r"_(\d+)", m_name)+            else:+                scope_names = [m_name]+            if scope_names[0] == "kernel" or scope_names[0] == "gamma":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":+                pointer = getattr(pointer, "bias")+            elif scope_names[0] == "output_weights":+                pointer = getattr(pointer, "weight")+            elif scope_names[0] == "squad":+                pointer = getattr(pointer, "classifier")+            else:+                try:+                    pointer = getattr(pointer, scope_names[0])+                except AttributeError:+                    logger.info("Skipping {}".format("/".join(name)))+                    continue+            if len(scope_names) >= 2:+                num = int(scope_names[1])+                pointer = pointer[num]+        if m_name[-11:] == "_embeddings":+            pointer = getattr(pointer, "weight")+        elif m_name == "kernel":+            array = np.transpose(array)+        try:+            assert (+                pointer.shape == array.shape+            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"+        except AssertionError as e:+            e.args += (pointer.shape, array.shape)+            raise+        logger.info("Initialize PyTorch weight {}".format(name))+        pointer.data = torch.from_numpy(array)+    return model+++def mish(x):+    return x * torch.tanh(nn.functional.softplus(x))+++def add_code_sample_docstrings(+    *docstr,+    tokenizer_class=None,+    tokenizer_checkpoint=None,+    checkpoint=None,+    output_type=None,+    config_class=None,+    mask=None,+    model_cls=None,+    code_sample=None+):+    def docstring_decorator(fn):+        # model_class defaults to function's class if not specified otherwise+        model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls++        doc_kwargs = dict(+            model_class=model_class,+            tokenizer_class=tokenizer_class,+            checkpoint=checkpoint,+            mask=mask,+            tokenizer_checkpoint=tokenizer_checkpoint,+        )++        output_doc = _prepare_output_docstrings(output_type, config_class) if output_type is not None else ""++        built_doc = 
code_sample.format(**doc_kwargs)+        fn.__doc__ = (fn.__doc__ or "") + "".join(docstr) + output_doc + built_doc+        return fn++    return docstring_decorator+++# TO-CHECK: Vestige of the original code+++class BertLayerNorm(nn.Module):+    def __init__(self, hidden_size, eps=1e-12):+        """Construct a layernorm module in the TF style (epsilon inside the square root)."""+        super(BertLayerNorm, self).__init__()+        self.weight = nn.Parameter(torch.ones(hidden_size))+        self.bias = nn.Parameter(torch.zeros(hidden_size))+        self.variance_epsilon = eps++    def forward(self, x):+        u = x.mean(-1, keepdim=True)+        s = (x - u).pow(2).mean(-1, keepdim=True)+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)+        return self.weight * x + self.bias+++class VisualBertEmbeddings(nn.Module):+    """Construct the embeddings from word, position and token_type embeddings and visual embeddings."""++    def __init__(self, config):+        super().__init__()+        self.word_embeddings = nn.Embedding(+            config.vocab_size, config.hidden_size+        )  # TO-CHECK: , padding_idx=config.pad_token_id+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)++        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load+        # any TensorFlow checkpoint file++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++        # TO-CHECK+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized+        # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")++        # For Visual Features+        # Segment and position embedding for image features+        self.visual_token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)+        self.visual_position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)++        # TO-CHECK: Check if register buffer is needed for Visual features+        self.visual_projection = nn.Linear(config.visual_embedding_dim, config.hidden_size)++    # TO-CHECK: Check how to incorporate this. This is being called outside the classes.+    # def special_intialize(self):+    #     ### This is a bit unorthodox. 
The better way might be to add an inititilizer to AllenNLP.+    #     # This function is used to initialize the token_type_embeddings_visual and positiona_embedding_visual, just incase.+    #     self.token_type_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.token_type_embeddings.weight.data), requires_grad = True)+    #     self.position_embeddings_visual.weight = torch.nn.Parameter(deepcopy(self.position_embeddings.weight.data), requires_grad = True)+    #     return++    def forward(+        self,+        input_ids=None,  # TO-CHECK+        token_type_ids=None,+        position_ids=None,+        inputs_embeds=None,+        visual_embeds=None,+        visual_token_type_ids=None,+        image_text_alignment=None,+        # past_key_values_length=0, # TO-CHECK+    ):+        # TO-CHECK: Check if `confidence=None` and `visual_position_embeds=None` (or id) is needed.+        # `position_embeddings_visual` is not used in the original code.++        if input_ids is not None:+            input_shape = input_ids.size()+        else:+            input_shape = inputs_embeds.size()[:-1]++        seq_length = input_shape[1]++        # TO-CHECK+        # if position_ids is None:+        #     position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]+        if input_ids is not None:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)+        else:+            position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)+            position_ids = position_ids.unsqueeze(0).expand(input_shape)++        if token_type_ids is None:+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)++        if inputs_embeds is None:+            inputs_embeds = self.word_embeddings(input_ids)++        token_type_embeddings = self.token_type_embeddings(token_type_ids)++        embeddings = inputs_embeds + token_type_embeddings+        if self.position_embedding_type == "absolute":+            position_embeddings = self.position_embeddings(position_ids)+            embeddings += position_embeddings++        if visual_embeds is not None:+            visual_embeds = self.visual_projection(visual_embeds)+            visual_token_type_embeddings = self.visual_token_type_embeddings(visual_token_type_ids)++            if image_text_alignment is not None:++                # TO-DO: Find a way to handle this in a better way.+                # image_text_alignment = Batch x image_length x alignment_number. Each element denotes the position of the word corresponding to the image feature. 
-1 is the padding value.+                image_text_alignment_mask = (image_text_alignment != -1).long()+                # Get rid of the -1.+                image_text_alignment = image_text_alignment_mask * image_text_alignment++                # Batch x image_length x alignment length x dim+                visual_position_embeddings = self.position_embeddings(+                    image_text_alignment+                ) * image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).unsqueeze(-1)+                visual_position_embeddings = visual_position_embeddings.sum(2)++                # We want to averge along the alignment_number dimension.+                image_text_alignment_mask = image_text_alignment_mask.to(dtype=next(self.parameters()).dtype).sum(2)+                image_text_alignment_mask[image_text_alignment_mask == 0] = 1  # Avoid devide by zero error+                visual_position_embeddings = visual_position_embeddings / image_text_alignment_mask.unsqueeze(-1)++                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.++                # When fine-tuning the detector , the image_text_alignment is sometimes padded too long.+                if visual_position_embeddings.size(1) != visual_embeds.size(1):+                    assert visual_position_embeddings.size(1) >= visual_embeds.size(1)+                    visual_position_embeddings = visual_position_embeddings[:, : visual_embeds.size(1), :]++                visual_position_embeddings = visual_position_embeddings + self.visual_position_embeddings(+                    visual_position_ids+                )+            else:+                visual_position_ids = torch.zeros(+                    *visual_embeds.size()[:-1], dtype=torch.long, device=visual_embeds.device+                )  # They use .cuda() but I believe this will be same as visual_embeds device.+                visual_position_embeddings = self.visual_position_embeddings(visual_position_ids)++            visual_embeddings = visual_embeds + visual_position_embeddings + visual_token_type_embeddings++            # Concate the two:+            embeddings = torch.cat((embeddings, visual_embeddings), dim=1)  # concat the visual embeddings++        embeddings = self.LayerNorm(embeddings)+        embeddings = self.dropout(embeddings)+        return embeddings+++class VisualBertSelfAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):+            raise ValueError(+                "The hidden size (%d) is not a multiple of the number of attention "+                "heads (%d)" % (config.hidden_size, config.num_attention_heads)+            )++        self.num_attention_heads = config.num_attention_heads+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)+        self.all_head_size = self.num_attention_heads * self.attention_head_size++        self.query = nn.Linear(config.hidden_size, self.all_head_size)+        self.key = nn.Linear(config.hidden_size, self.all_head_size)+        self.value = nn.Linear(config.hidden_size, self.all_head_size)++        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)+        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")+        if 
self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            self.max_position_embeddings = config.max_position_embeddings+            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)++        self.is_decoder = config.is_decoder++    def transpose_for_scores(self, x):+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)+        x = x.view(*new_x_shape)+        return x.permute(0, 2, 1, 3)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        mixed_query_layer = self.query(hidden_states)++        # If this is instantiated as a cross-attention module, the keys+        # and values come from an encoder; the attention mask needs to be+        # such that the encoder's padding tokens are not attended to.+        is_cross_attention = encoder_hidden_states is not None++        if is_cross_attention and past_key_value is not None:+            # reuse k,v, cross_attentions+            key_layer = past_key_value[0]+            value_layer = past_key_value[1]+            attention_mask = encoder_attention_mask+        elif is_cross_attention:+            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))+            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))+            attention_mask = encoder_attention_mask+        elif past_key_value is not None:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))+            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)+            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)+        else:+            key_layer = self.transpose_for_scores(self.key(hidden_states))+            value_layer = self.transpose_for_scores(self.value(hidden_states))++        query_layer = self.transpose_for_scores(mixed_query_layer)++        if self.is_decoder:+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.+            # Further calls to cross_attention layer can then reuse all cross-attention+            # key/value_states (first "if" case)+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of+            # all previous decoder key/value_states. 
Further calls to uni-directional self-attention+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)+            # if encoder bi-directional self-attention `past_key_value` is always `None`+            past_key_value = (key_layer, value_layer)++        # Take the dot product between "query" and "key" to get the raw attention scores.+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))++        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":+            seq_length = hidden_states.size()[1]+            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)+            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)+            distance = position_ids_l - position_ids_r+            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)+            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility++            if self.position_embedding_type == "relative_key":+                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores+            elif self.position_embedding_type == "relative_key_query":+                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)+                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)+                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key++        attention_scores = attention_scores / math.sqrt(self.attention_head_size)+        if attention_mask is not None:+            # Apply the attention mask is (precomputed for all layers in VisualBertModel forward() function)+            attention_scores = attention_scores + attention_mask++        # Normalize the attention scores to probabilities.+        attention_probs = nn.Softmax(dim=-1)(attention_scores)++        # This is actually dropping out entire tokens to attend to, which might+        # seem a bit unusual, but is taken from the original Transformer paper.+        attention_probs = self.dropout(attention_probs)++        # Mask heads if we want to+        if head_mask is not None:+            attention_probs = attention_probs * head_mask++        context_layer = torch.matmul(attention_probs, value_layer)++        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)+        context_layer = context_layer.view(*new_context_layer_shape)++        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)++        if self.is_decoder:+            outputs = outputs + (past_key_value,)+        return outputs+++class VisualBertSelfOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, 
input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertAttention(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.self = VisualBertSelfAttention(config)+        self.output = VisualBertSelfOutput(config)+        self.pruned_heads = set()++    def prune_heads(self, heads):+        if len(heads) == 0:+            return+        heads, index = find_pruneable_heads_and_indices(+            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads+        )++        # Prune linear layers+        self.self.query = prune_linear_layer(self.self.query, index)+        self.self.key = prune_linear_layer(self.self.key, index)+        self.self.value = prune_linear_layer(self.self.value, index)+        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)++        # Update hyper params and store pruned heads+        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)+        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads+        self.pruned_heads = self.pruned_heads.union(heads)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        self_outputs = self.self(+            hidden_states,+            attention_mask,+            head_mask,+            encoder_hidden_states,+            encoder_attention_mask,+            past_key_value,+            output_attentions,+        )+        attention_output = self.output(self_outputs[0], hidden_states)+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them+        return outputs+++class VisualBertIntermediate(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)+        if isinstance(config.hidden_act, str):+            self.intermediate_act_fn = ACT2FN[config.hidden_act]+        else:+            self.intermediate_act_fn = config.hidden_act++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.intermediate_act_fn(hidden_states)+        return hidden_states+++class VisualBertOutput(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)+        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)  # original eps=1e-12+        self.dropout = nn.Dropout(config.hidden_dropout_prob)++    def forward(self, hidden_states, input_tensor):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.dropout(hidden_states)+        hidden_states = self.LayerNorm(hidden_states + input_tensor)+        return hidden_states+++class VisualBertLayer(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.chunk_size_feed_forward = config.chunk_size_feed_forward+        self.seq_len_dim = 1+        self.attention = VisualBertAttention(config)+        self.is_decoder = config.is_decoder+        
self.add_cross_attention = config.add_cross_attention+        if self.add_cross_attention:+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"+            self.crossattention = VisualBertAttention(config)+        self.intermediate = VisualBertIntermediate(config)+        self.output = VisualBertOutput(config)++    def forward(+        self,+        hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_value=None,+        output_attentions=False,+    ):+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None+        self_attention_outputs = self.attention(+            hidden_states,+            attention_mask,+            head_mask,+            output_attentions=output_attentions,+            past_key_value=self_attn_past_key_value,+        )+        attention_output = self_attention_outputs[0]++        # if decoder, the last output is tuple of self-attn cache+        if self.is_decoder:+            outputs = self_attention_outputs[1:-1]+            present_key_value = self_attention_outputs[-1]+        else:+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights++        cross_attn_present_key_value = None+        if self.is_decoder and encoder_hidden_states is not None:+            assert hasattr(+                self, "crossattention"+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"++            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None+            cross_attention_outputs = self.crossattention(+                attention_output,+                attention_mask,+                head_mask,+                encoder_hidden_states,+                encoder_attention_mask,+                cross_attn_past_key_value,+                output_attentions,+            )+            attention_output = cross_attention_outputs[0]+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights++            # add cross-attn cache to positions 3,4 of present_key_value tuple+            cross_attn_present_key_value = cross_attention_outputs[-1]+            present_key_value = present_key_value + cross_attn_present_key_value++        layer_output = apply_chunking_to_forward(+            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output+        )+        outputs = (layer_output,) + outputs++        # if decoder, return the attn key/values as the last output+        if self.is_decoder:+            outputs = outputs + (present_key_value,)++        return outputs++    def feed_forward_chunk(self, attention_output):+        intermediate_output = self.intermediate(attention_output)+        layer_output = self.output(intermediate_output, attention_output)+        return layer_output+++class VisualBertEncoder(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.config = config+        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])++    def forward(+        self,+        
hidden_states,+        attention_mask=None,+        head_mask=None,+        encoder_hidden_states=None,+        encoder_attention_mask=None,+        past_key_values=None,+        use_cache=None,+        output_attentions=False,+        output_hidden_states=False,+        return_dict=True,+    ):+        all_hidden_states = () if output_hidden_states else None+        all_self_attentions = () if output_attentions else None+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None++        next_decoder_cache = () if use_cache else None+        for i, layer_module in enumerate(self.layer):+            if output_hidden_states:+                all_hidden_states = all_hidden_states + (hidden_states,)++            layer_head_mask = head_mask[i] if head_mask is not None else None+            past_key_value = past_key_values[i] if past_key_values is not None else None++            if getattr(self.config, "gradient_checkpointing", False) and self.training:++                if use_cache:+                    logger.warn(+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "+                        "`use_cache=False`..."+                    )+                    use_cache = False++                def create_custom_forward(module):+                    def custom_forward(*inputs):+                        return module(*inputs, past_key_value, output_attentions)++                    return custom_forward++                layer_outputs = torch.utils.checkpoint.checkpoint(+                    create_custom_forward(layer_module),+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                )+            else:+                layer_outputs = layer_module(+                    hidden_states,+                    attention_mask,+                    layer_head_mask,+                    encoder_hidden_states,+                    encoder_attention_mask,+                    past_key_value,+                    output_attentions,+                )++            hidden_states = layer_outputs[0]+            if use_cache:+                next_decoder_cache += (layer_outputs[-1],)+            if output_attentions:+                all_self_attentions = all_self_attentions + (layer_outputs[1],)+                if self.config.add_cross_attention:+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)++        if output_hidden_states:+            all_hidden_states = all_hidden_states + (hidden_states,)++        if not return_dict:+            return tuple(+                v+                for v in [+                    hidden_states,+                    next_decoder_cache,+                    all_hidden_states,+                    all_self_attentions,+                    all_cross_attentions,+                ]+                if v is not None+            )+        return BaseModelOutputWithPastAndCrossAttentions(+            last_hidden_state=hidden_states,+            past_key_values=next_decoder_cache,+            hidden_states=all_hidden_states,+            attentions=all_self_attentions,+            cross_attentions=all_cross_attentions,+        )+++class VisualBertPooler(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        self.activation = 
nn.Tanh()++    def forward(self, hidden_states):+        # We "pool" the model by simply taking the hidden state corresponding+        # to the first token.+        first_token_tensor = hidden_states[:, 0]+        pooled_output = self.dense(first_token_tensor)+        pooled_output = self.activation(pooled_output)+        return pooled_output+++class VisualBertPredictionHeadTransform(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)+        if isinstance(config.hidden_act, str):+            self.transform_act_fn = ACT2FN[config.hidden_act]+        else:+            self.transform_act_fn = config.hidden_act++        # TO-CHECK+        # self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)+        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)++    def forward(self, hidden_states):+        hidden_states = self.dense(hidden_states)+        hidden_states = self.transform_act_fn(hidden_states)+        hidden_states = self.LayerNorm(hidden_states)+        return hidden_states+++class VisualBertLMPredictionHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.transform = VisualBertPredictionHeadTransform(config)++        # The output weights are the same as the input embeddings, but there is+        # an output-only bias for each token.+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)++        self.bias = nn.Parameter(torch.zeros(config.vocab_size))++        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`+        self.decoder.bias = self.bias++    def forward(self, hidden_states):+        hidden_states = self.transform(hidden_states)+        hidden_states = self.decoder(hidden_states)+        return hidden_states+++class VisualBertOnlyMLMHead(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)++    def forward(self, sequence_output):+        prediction_scores = self.predictions(sequence_output)+        return prediction_scores+++class VisualBertOnlySIPHead(nn.Module):  # Sentence-Image Prediction+    def __init__(self, config):+        super().__init__()+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, pooled_output):+        seq_relationship_score = self.seq_relationship(pooled_output)+        return seq_relationship_score+++class VisualBertPreTrainingHeads(nn.Module):+    def __init__(self, config):+        super().__init__()+        self.predictions = VisualBertLMPredictionHead(config)+        self.seq_relationship = nn.Linear(config.hidden_size, 2)++    def forward(self, sequence_output, pooled_output):+        prediction_scores = self.predictions(sequence_output)+        seq_relationship_score = self.seq_relationship(pooled_output)+        return prediction_scores, seq_relationship_score+++class VisualBertPreTrainedModel(PreTrainedModel):+    """+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained+    models.+    """++    config_class = VisualBertConfig+    load_tf_weights = load_tf_weights_in_visual_bert+    base_model_prefix = "visual_bert"+    _keys_to_ignore_on_load_missing = [r"position_ids"]++    def _init_weights(self, module):+        """Initialize the weights"""+        if isinstance(module, (nn.Linear, nn.Embedding)):+          
+
+
+class VisualBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = VisualBertConfig
+    load_tf_weights = load_tf_weights_in_visual_bert
+    base_model_prefix = "visual_bert"
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+@dataclass
+class VisualBertForPreTrainingOutput(ModelOutput):
+    """
+    Output type of :class:`~transformers.VisualBertForPreTraining`.
+
+    Args:
+        loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
+            Total loss as the sum of the masked language modeling loss and the sentence-image prediction
+            (classification) loss.
+        prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+        seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
+            Prediction scores of the sentence-image prediction (classification) head (scores of True/False continuation
+            before SoftMax).
+        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
+            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
+
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
+            heads.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    prediction_logits: torch.FloatTensor = None
+    seq_relationship_logits: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
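For reference, a minimal sketch with dummy tensors of how a total loss of this form (masked language modeling loss plus sentence-image prediction loss) is typically combined for BERT-style pre-training heads; the shapes and the use of `CrossEntropyLoss` are assumptions for illustration, not necessarily the exact implementation in this PR:

import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_length, vocab_size = 2, 8, 30522  # assumed sizes
prediction_logits = torch.randn(batch_size, seq_length, vocab_size)
seq_relationship_logits = torch.randn(batch_size, 2)
labels = torch.randint(0, vocab_size, (batch_size, seq_length))   # MLM targets
sentence_image_labels = torch.randint(0, 2, (batch_size,))        # image/text match targets

loss_fct = CrossEntropyLoss()  # positions set to -100 in `labels` would be ignored by default
masked_lm_loss = loss_fct(prediction_logits.view(-1, vocab_size), labels.view(-1))
sentence_image_loss = loss_fct(seq_relationship_logits.view(-1, 2), sentence_image_labels.view(-1))
total_loss = masked_lm_loss + sentence_image_loss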
+
+
+VISUAL_BERT_START_DOCSTRING = r"""
+    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
+    methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
+    pruning heads etc.)
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+    Parameters:
+        config (:class:`~transformers.VisualBertConfig`): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+            weights.
+"""
+
+VISUAL_BERT_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`~transformers.BertTokenizer`. See
+            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
+            details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+            1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range ``[0,
+            config.max_position_embeddings - 1]``.
+
+            `What are position IDs? <../glossary.html#position-ids>`_
+        head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
+            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
+            vectors than the model's internal embedding lookup matrix.
+
+        visual_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length, visual_embedding_dim)`, `optional`):
+            The embedded representation of the visual inputs, generally derived using an object detector.
+
+        visual_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):
+            Mask to avoid performing attention on visual embeddings. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        visual_token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length)`, `optional`):
+            Segment token indices to indicate different portions of the visual embeds. Indices are selected in ``[0,
+            1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+
+        image_text_alignment (:obj:`torch.LongTensor` of shape :obj:`(batch_size, visual_seq_length, alignment_number)`, `optional`):
+            Image-text alignment pairs, used to decide the position IDs of the visual embeddings.
+
+        output_attentions (:obj:`bool`, `optional`):
+            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+            tensors for more detail.
+        output_hidden_states (:obj:`bool`, `optional`):
+            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+            more detail.
+        return_dict (:obj:`bool`, `optional`):
+            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
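All of the usage examples below assume a helper `get_visual_embeddings(image)` that is not part of this PR. The released checkpoints expect region features from an object detector (Faster R-CNN in the original VisualBERT setup). As a rough, hypothetical stand-in that only illustrates the expected `(visual_seq_length, visual_embedding_dim)` output shape, one could pool features from a torchvision backbone; the 2048-dimensional output here is an assumption and must match `config.visual_embedding_dim` of the checkpoint actually used:

import torch
import torchvision.models as models
import torchvision.transforms as T

# Hypothetical stand-in for the detector features the checkpoints were trained on.
backbone = models.resnet50(pretrained=True)
backbone.fc = torch.nn.Identity()  # keep the 2048-d pooled feature
backbone.eval()

transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])


def get_visual_embeddings(image, num_regions=36):
    """`image` is a PIL.Image; a real pipeline would extract one feature per detected region."""
    pixel_values = transform(image).unsqueeze(0)  # (1, 3, 224, 224)
    with torch.no_grad():
        features = backbone(pixel_values)  # (1, 2048)
    # Repeat the global feature as a crude substitute for `num_regions` region features.
    return features.repeat(num_regions, 1)  # (num_regions, 2048)

In practice the features fed to these checkpoints should come from the same detector setup the models were pre-trained with; this sketch only makes the doc examples runnable end to end.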
+
+VISUAL_BERT_VQA_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> text = "Who is eating the apple?"
+        >>> inputs = tokenizer(text, return_tensors='pt')
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)  # example
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> inputs.update({{
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> labels = torch.tensor([[0.0, 1.0]]).unsqueeze(0)  # Batch size 1, Num labels 2
+
+        >>> outputs = model(**inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> scores = outputs.logits
+"""
+
+VISUAL_BERT_NLVR_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> text = "Who is eating the apple?"
+        >>> inputs = tokenizer(text, return_tensors='pt')
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)  # example
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> inputs.update({{
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> labels = torch.tensor(1).unsqueeze(0)  # Batch size 1, Num choices 2
+
+        >>> outputs = model(**inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> scores = outputs.logits
+"""
+
+
+VISUAL_BERT_VQA_ADVANCED_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> text = "Who is eating the apple?"
+        >>> inputs = tokenizer(text, return_tensors='pt')
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> inputs.update({{
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]))  # Batch size 1
+
+        >>> outputs = model(**inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> scores = outputs.logits
+"""
+
+VISUAL_BERT_FLICKR_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> text = "Who is eating the apple?"
+        >>> inputs = tokenizer(text, return_tensors='pt')
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)  # example
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+        >>> flickr_position = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2]))
+
+        >>> inputs.update({{
+        ...     "flickr_position": flickr_position,
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> labels = torch.ones((1, inputs["input_ids"].shape[-1] + visual_embeds.shape[-2], visual_embeds.shape[-2]))  # Batch size 1
+
+        >>> outputs = model(**inputs, labels=labels)
+        >>> loss = outputs.loss
+        >>> scores = outputs.logits
+"""
+
+VISUAL_BERT_PRE_TRAINING_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image in the batch.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)  # example
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> inputs.update({{
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt", padding="max_length", max_length=inputs["input_ids"].shape[-1] + visual_embeds.shape[-2])["input_ids"]
+        >>> sentence_image_labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
+
+        >>> outputs = model(**inputs, labels=labels, sentence_image_labels=sentence_image_labels)
+        >>> loss = outputs.loss
+        >>> prediction_logits = outputs.prediction_logits
+        >>> seq_relationship_logits = outputs.seq_relationship_logits
+"""
+
+VISUAL_BERT_MODEL_SAMPLE = r"""
+    Example::
+
+        >>> # Assumption: `get_visual_embeddings(image)` gets the visual embeddings of the image.
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> inputs = tokenizer("The capital of France is Paris.", return_tensors="pt")
+        >>> visual_embeds = get_visual_embeddings(image).unsqueeze(0)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)  # example
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> inputs.update({{
+        ...     "visual_embeds": visual_embeds,
+        ...     "visual_token_type_ids": visual_token_type_ids,
+        ...     "visual_attention_mask": visual_attention_mask
+        ... }})
+
+        >>> outputs = model(**inputs)
+
+        >>> last_hidden_states = outputs.last_hidden_state
+"""
+
+VISUAL_BERT_MULTIPLE_CHOICE_SAMPLE = r"""
+    Example::
+
+        >>> from transformers import {tokenizer_class}, {model_class}
+        >>> import torch
+
+        >>> tokenizer = {tokenizer_class}.from_pretrained('{tokenizer_checkpoint}')
+        >>> model = {model_class}.from_pretrained('{checkpoint}')
+
+        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
+        >>> choice0 = "It is eaten with a fork and a knife."
+        >>> choice1 = "It is eaten while held in the hand."
+
+        >>> visual_embeds = get_visual_embeddings(image)
+        >>> visual_embeds = visual_embeds.expand(1, 2, *visual_embeds.shape)  # (batch_size, num_choices, visual_seq_length, visual_embedding_dim)
+        >>> visual_token_type_ids = torch.ones(visual_embeds.shape[:-1], dtype=torch.long)
+        >>> visual_attention_mask = torch.ones(visual_embeds.shape[:-1], dtype=torch.float)
+
+        >>> labels = torch.tensor(0).unsqueeze(0)  # choice0 is correct (according to Wikipedia ;)), batch size 1
+
+        >>> encoding = tokenizer([[prompt, prompt], [choice0, choice1]], return_tensors='pt', padding=True)