models.base_model
Models built on top of Pythia need to inherit the BaseModel class and follow
a fixed format. To create a model for MMF, follow this quick cheatsheet.
- Inherit the BaseModel class, and make sure to call super().__init__() in your class’s __init__ function.
- Implement a build function for your model. If you build everything in __init__, you can just return in this function.
- Write a forward function which takes in a SampleList as an argument and returns a dict.
- Register using the @registry.register_model("key") decorator on top of the class.
If you are doing logits based predictions, the dict you return from your model
should contain a scores field. Losses are automatically calculated by the
BaseModel class and added to this dict if not present.
Example:
    import torch

    from mmf.common.registry import registry
    from mmf.models.base_model import BaseModel


    @registry.register_model("pythia")
    class Pythia(BaseModel):
        # config is model_config from global config
        def __init__(self, config):
            super().__init__(config)

        def build(self):
            ...

        def forward(self, sample_list):
            # Random placeholder scores: one row per sample, 3127 columns
            scores = torch.rand(sample_list.get_batch_size(), 3127)
            return {"scores": scores}
class mmf.models.base_model.BaseModel(config)
    For integration with Pythia’s trainer, datasets and other features, models need to inherit this class, call super, write a build function, write a forward function taking a SampleList as input and returning a dict as output and, finally, register it using @registry.register_model.

    Parameters: config (DictConfig) – model_config configuration from global config.

    build()
        Function to be implemented by the child class, in case it needs to build its model separately from __init__. All model-related downloads should also happen here.

    format_for_prediction(results, report)
        Implement this method in a model if it needs to modify prediction results using report fields. Note that the required fields should already have been gathered in report.

    classmethod format_state_key(key)
        Can be implemented if something special needs to be done to a key when a pretrained model is being loaded. This will adapt and return keys accordingly. Useful for backwards compatibility. See updated load_state_dict below. For an example, see VisualBERT model’s code.

        Parameters: key (string) – key to be formatted
        Returns: formatted key
        Return type: string
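As a rough illustration of the backwards-compatibility use case, the sketch below renames keys from an older checkpoint layout before they would be loaded. The class name, the prefix mapping, and the key names are all hypothetical; they are not taken from MMF or from the VisualBERT model.

```python
class ToyModel:
    # Hypothetical mapping from old checkpoint prefixes to current ones.
    _RENAMED_PREFIXES = {"bert.encoder.": "model.bert.encoder."}

    @classmethod
    def format_state_key(cls, key):
        """Return `key` rewritten to the current naming scheme."""
        for old, new in cls._RENAMED_PREFIXES.items():
            if key.startswith(old):
                return new + key[len(old):]
        # Keys without a renamed prefix pass through unchanged.
        return key


# Values are placeholders; a real state dict would hold tensors.
old_state = {"bert.encoder.layer.0.weight": 1, "classifier.bias": 2}
new_state = {ToyModel.format_state_key(k): v for k, v in old_state.items()}
```

Keeping the renaming in a classmethod means it can be applied to every key of a checkpoint without first constructing the model.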

    forward(sample_list, *args, **kwargs)
        To be implemented by the child class. Takes in a SampleList and returns a dict.

        Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration.
        Returns: Dict containing scores object.
        Return type: Dict

    init_losses()
        Initializes the losses for the model based on the losses key. Automatically called by MMF internally after building the model.

    load_state_dict(state_dict, *args, **kwargs)
        Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

        Parameters:
        - state_dict (dict) – a dict containing parameters and persistent buffers.
        - strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

        Returns:
        - missing_keys is a list of str containing the missing keys
        - unexpected_keys is a list of str containing the unexpected keys

        Return type: NamedTuple with missing_keys and unexpected_keys fields
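The relationship between strict and the two returned lists can be sketched with plain key sets. This is a simplified model of the key comparison only, with no tensors and no PyTorch; `IncompatibleKeys` and `compare_keys` are hypothetical names chosen for this example.

```python
from typing import Iterable, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def compare_keys(
    model_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    strict: bool = True,
) -> IncompatibleKeys:
    """Report keys the checkpoint lacks and keys the model does not expect."""
    model_set, checkpoint_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - checkpoint_set)
    unexpected = sorted(checkpoint_set - model_set)
    # In strict mode any mismatch is an error rather than a report.
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing={missing}, unexpected={unexpected}")
    return IncompatibleKeys(missing, unexpected)


report = compare_keys(
    model_keys=["encoder.weight", "classifier.bias"],
    checkpoint_keys=["encoder.weight", "head.bias"],
    strict=False,
)
```

With strict=False the mismatch is returned for inspection, which is useful when a pretrained checkpoint is expected to cover only part of the model.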