models.base_model

Models built in MMF need to inherit the BaseModel class and adhere to a fixed format. To create a model for MMF, follow this quick cheatsheet.

  1. Inherit the BaseModel class and make sure to call super().__init__() in your class’s __init__ function.
  2. Implement a build function for your model. If you build everything in __init__, you can simply return from this function.
  3. Write a forward function that takes a SampleList as an argument and returns a dict.
  4. Register the model using the @registry.register_model("key") decorator on top of the class.

If you are doing logits-based predictions, the dict you return from your model should contain a scores field. Losses are automatically calculated by the BaseModel class and added to this dict if not present.

Example:

import torch

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        ...

    def forward(self, sample_list):
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}

class mmf.models.base_model.BaseModel(config: Union[omegaconf.dictconfig.DictConfig, mmf.models.base_model.BaseModel.Config])[source]

For integration with MMF’s trainer, datasets and other features, models need to inherit this class, call super().__init__(), write a build function, write a forward function that takes a SampleList as input and returns a dict as output, and finally register the model using @registry.register_model.

Parameters: config (DictConfig) – model_config configuration from the global config.
class Config(model: str = '???', losses: Optional[List[mmf.modules.losses.LossConfig]] = '???')[source]
build()[source]

Function to be implemented by the child class, in case it needs to build its model separately from __init__. All model-related downloads should also happen here.
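
As a hedged sketch (the layer sizes are hypothetical, not part of the API), a build implementation might look like this:

import torch

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def build(self):
        # Construct layers here instead of in __init__; MMF calls
        # build() automatically after the model is instantiated.
        # The sizes below are hypothetical placeholders.
        self.classifier = torch.nn.Linear(2048, 3129)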

configure_optimizers()[source]

Member function of PyTorch Lightning (PL) modules. Used only when PL is enabled.
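
A minimal sketch of this hook, in case you want to override MMF’s config-driven optimizer setup (the optimizer choice and learning rate are assumptions for illustration):

import torch

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def configure_optimizers(self):
        # Standard Lightning hook: return the optimizer(s) for training.
        # The optimizer choice and lr are placeholders for this sketch.
        return torch.optim.Adam(self.parameters(), lr=1e-4)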

format_for_prediction(results, report)[source]

Implement this method in models that need to modify prediction results using report fields. Note that the required fields should already be gathered in the report.
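
A hedged sketch, assuming the report carries a question_id field (a hypothetical field used only for this illustration):

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def format_for_prediction(self, results, report):
        # Attach an id from the already-gathered report to each
        # prediction entry; question_id is a hypothetical report field.
        for idx, result in enumerate(results):
            result["question_id"] = report.question_id[idx].item()
        return results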

classmethod format_state_key(key)[source]

Can be implemented if something special needs to be done to a key when a pretrained model is being loaded; this adapts and returns the key accordingly. Useful for backwards compatibility. See the updated load_state_dict below. For an example, see the VisualBERT model’s code.

Parameters: key (string) – key to be formatted
Returns: formatted key
Return type: string
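
A minimal sketch, assuming an older checkpoint used an encoder. prefix that a newer version of the model renamed (both prefixes are hypothetical):

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    @classmethod
    def format_state_key(cls, key):
        # Remap a hypothetical legacy prefix to the current attribute
        # name so that old pretrained checkpoints keep loading.
        return key.replace("encoder.", "text_encoder.")
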
forward(sample_list, *args, **kwargs)[source]

To be implemented by the child class. Takes in a SampleList and returns a dict.

Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns: Dict containing a scores object.
Return type: Dict

init_losses()[source]

Initializes losses for the model based on the losses key in the model config. Automatically called by MMF internally after building the model.
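
For reference, a minimal sketch of a model_config carrying the losses key that init_losses reads; logit_bce is one of MMF’s registered losses, and the other values are illustrative:

from omegaconf import OmegaConf

# Equivalent YAML under the model's model_config entry:
#   losses:
#     - type: logit_bce
config = OmegaConf.create({"model": "pythia", "losses": [{"type": "logit_bce"}]})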

load_state_dict(state_dict, *args, **kwargs)[source]

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters:
  • state_dict (dict) – a dict containing parameters and persistent buffers.
  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
Returns:
  • missing_keys is a list of str containing the missing keys
  • unexpected_keys is a list of str containing the unexpected keys
Return type: NamedTuple with missing_keys and unexpected_keys fields
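
Usage follows standard PyTorch. A sketch, assuming model is an already-built BaseModel instance and the checkpoint path is a placeholder:

import torch

# Assumes `model` is an already-built BaseModel instance and that
# "model.ckpt" (a placeholder path) stores a plain state_dict.
state_dict = torch.load("model.ckpt", map_location="cpu")
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print("missing:", missing, "unexpected:", unexpected)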

on_load_checkpoint(checkpoint: Dict[str, Any]) → None[source]

This is called by PyTorch Lightning before the model’s checkpoint is loaded.

on_save_checkpoint(checkpoint: Dict[str, Any]) → None[source]

Give the model a chance to add something to the checkpoint. state_dict is already there.

Parameters: checkpoint – A dictionary in which you can store variables to save in the checkpoint. Contents need to be pickleable.
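
A hedged sketch; the extra key stored below is purely illustrative:

from typing import Any, Dict

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # state_dict is already in `checkpoint`; anything extra stored
        # here must be pickleable. The key below is purely illustrative.
        checkpoint["my_extra_state"] = {"version": 2}
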
test_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and a batch_idx and returns a dict.

Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns: Dict

training_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and a batch_idx and returns a dict. See the sketch after validation_step below.

Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns: Dict containing loss.
Return type: Dict

validation_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)[source]

Member function of PL modules. Used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and a batch_idx and returns a dict.

Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
Returns: Dict
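
The three step hooks above share the same shape. A hedged sketch of a training_step, assuming forward returns the usual dict with a losses mapping attached by BaseModel (validation_step and test_step follow the same pattern without the loss bookkeeping):

from mmf.models.base_model import BaseModel


class MyModel(BaseModel):
    def training_step(self, batch, batch_idx, *args, **kwargs):
        # Run the regular MMF forward pass; BaseModel attaches a
        # "losses" dict to the output if the model did not add one.
        output = self(batch)
        # Lightning expects a dict containing the loss to optimize.
        loss = sum(output["losses"].values())
        return {"loss": loss}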
