models.base_model
Models built in MMF need to inherit the BaseModel class and adhere to
a fixed format. To create a model for MMF, follow this quick cheatsheet.
- Inherit the BaseModel class, and make sure to call super().__init__() in your class’s __init__ function.
- Implement the build function for your model. If you build everything in __init__, you can just return in this function.
- Write a forward function which takes in a SampleList as an argument and returns a dict.
- Register the model using the @registry.register_model("key") decorator on top of the class.
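To illustrate what the registration step does, here is a minimal, self-contained sketch of a decorator-based registry. It is a hypothetical ToyRegistry written for illustration, not MMF’s mmf.common.registry; only the idea (a decorator that stores the class under a string key and returns it unchanged) is meant to carry over, and the method names are assumptions.

```python
from typing import Callable, Dict


class ToyRegistry:
    """Hypothetical registry mapping string keys to model classes."""

    def __init__(self) -> None:
        self._models: Dict[str, type] = {}

    def register_model(self, name: str) -> Callable[[type], type]:
        # Returns the actual decorator, closing over the registration key.
        def wrapper(cls: type) -> type:
            if name in self._models:
                raise KeyError(f"Model '{name}' is already registered")
            self._models[name] = cls
            return cls  # hand the class back unchanged so it stays usable
        return wrapper

    def get_model_class(self, name: str) -> type:
        # Look up a previously registered class by its key.
        return self._models[name]


registry = ToyRegistry()


@registry.register_model("my_model")
class MyModel:
    pass


print(registry.get_model_class("my_model") is MyModel)  # True
```

Because the decorator returns the class itself, the decorated name remains an ordinary class; the registry only adds a lookup by key, which is what lets a config refer to a model by its string name.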
If you are doing logits-based predictions, the dict returned from your model
should contain a scores field. Losses are automatically calculated by the
BaseModel class and added to this dict if not already present.
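The output-dict contract above can be sketched in plain Python. This is a hypothetical ToyBaseModel, not MMF’s implementation: the "ce" loss name, the single-example cross-entropy and the way the target is passed in are all assumptions made for illustration. What it shows is that the child only returns scores, and the base class fills in losses only when they are missing.

```python
import math
from typing import Dict, List


class ToyBaseModel:
    """Hypothetical base class: calls forward, then adds losses if absent."""

    def __call__(self, sample: Dict) -> Dict:
        output = self.forward(sample)
        # Only compute losses when the model did not already return them.
        if "losses" not in output:
            output["losses"] = {
                "ce": self.cross_entropy(output["scores"], sample["target"])
            }
        return output

    @staticmethod
    def cross_entropy(scores: List[float], target: int) -> float:
        # Softmax cross-entropy from logits, using the max-shift for stability.
        m = max(scores)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        return log_sum_exp - scores[target]

    def forward(self, sample: Dict) -> Dict:
        raise NotImplementedError


class ToyClassifier(ToyBaseModel):
    def forward(self, sample: Dict) -> Dict:
        # Return raw logits under the "scores" key.
        return {"scores": [0.0, 0.0]}


out = ToyClassifier()({"target": 1})
print(sorted(out))  # ['losses', 'scores']
print(round(out["losses"]["ce"], 4))  # 0.6931
```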
Example:
import torch

from mmf.common.registry import registry
from mmf.models.base_model import BaseModel


@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        # Build layers here (or build them in __init__ and simply return).
        pass

    def forward(self, sample_list):
        # Dummy scores of shape (batch_size, 3127).
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
class mmf.models.base_model.BaseModel(config: Union[omegaconf.dictconfig.DictConfig, mmf.models.base_model.BaseModel.Config])
For integration with MMF’s trainer, datasets and other features, models need to inherit this class, call super, write a build function, write a forward function taking a SampleList as input and returning a dict as output, and finally register it using @registry.register_model.
Parameters: config (DictConfig) – model_config configuration from the global config.
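class Config(model: str = '???', losses: Union[List[mmf.modules.losses.LossConfig], NoneType] = '???')
The string '???' is OmegaConf’s marker for a mandatory value that has not been set yet, so both fields are expected to be filled in by the model config.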
build()
Function to be implemented by the child class in case it needs to build its model separately from __init__. All model-related downloads should also happen here.
format_for_prediction(results, report)
Implement this method if the model needs to modify prediction results using report fields. Note that the required fields should already be gathered in report.
classmethod format_state_key(key)
Can be implemented if something special needs to be done to the key when a pretrained model is being loaded. It adapts and returns keys accordingly, which is useful for backwards compatibility. See the updated load_state_dict below. For an example, see the VisualBERT model’s code.
Parameters: key (string) – key to be formatted
Returns: formatted key
Return type: string
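As a sketch of the backwards-compatibility use case, the class below renames a legacy key prefix before loading. The "bert.bert" to "model.bert" rule is a hypothetical mapping chosen for illustration (it is not taken from VisualBERT), and adapt_state_dict is a hypothetical helper standing in for the key-rewriting step that the updated load_state_dict performs.

```python
from typing import Dict


class ToyModel:
    @classmethod
    def format_state_key(cls, key: str) -> str:
        # Hypothetical rule: map an old checkpoint prefix onto the new layout.
        return key.replace("bert.bert", "model.bert")

    @classmethod
    def adapt_state_dict(cls, state_dict: Dict[str, object]) -> Dict[str, object]:
        # Hypothetical helper: apply format_state_key to every key before loading.
        return {cls.format_state_key(k): v for k, v in state_dict.items()}


old = {"bert.bert.embeddings.weight": 1, "classifier.bias": 2}
print(ToyModel.adapt_state_dict(old))
# {'model.bert.embeddings.weight': 1, 'classifier.bias': 2}
```

Keys that do not match the legacy pattern pass through untouched, so the same hook can load both old and new checkpoints.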
forward(sample_list, *args, **kwargs)
To be implemented by the child class. Takes in a SampleList and returns a dict.
Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration
Returns: Dict containing the scores object.
Return type: Dict
init_losses()
Initializes losses for the model based on the losses key of the model config. Called automatically by MMF after the model is built.
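The ordering described for build and init_losses can be sketched with a hypothetical factory. The build_toy_model function, the ToyBaseModel class and the use of plain strings for loss names are assumptions for illustration; MMF’s internal build code and its LossConfig entries differ. The point is only that build() runs first and init_losses() runs once the model is fully built.

```python
from typing import Dict, List


class ToyBaseModel:
    def __init__(self, config: Dict) -> None:
        self.config = config
        self.layers: List[str] = []
        self.losses: List[str] = []

    def build(self) -> None:
        # Child classes put layer construction and downloads here.
        pass

    def init_losses(self) -> None:
        # Read loss names from the "losses" key of the model config.
        self.losses = list(self.config.get("losses", []))


class ToyModel(ToyBaseModel):
    def build(self) -> None:
        self.layers = ["encoder", "classifier"]


def build_toy_model(cls: type, config: Dict) -> ToyBaseModel:
    # Hypothetical factory: construct, build, then initialize losses.
    model = cls(config)
    model.build()
    model.init_losses()  # runs only after build() has finished
    return model


model = build_toy_model(ToyModel, {"losses": ["cross_entropy"]})
print(model.layers, model.losses)  # ['encoder', 'classifier'] ['cross_entropy']
```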
load_state_dict(state_dict, *args, **kwargs)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Parameters:
- state_dict (dict) – a dict containing parameters and persistent buffers.
- strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
Returns:
- missing_keys is a list of str containing the missing keys
- unexpected_keys is a list of str containing the unexpected keys
Return type: NamedTuple with missing_keys and unexpected_keys fields
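The strict-matching rule reduces to two set differences. The sketch below is a toy reimplementation for illustration only (PyTorch performs the real check inside torch.nn.Module.load_state_dict); the IncompatibleKeys and check_keys names are hypothetical, chosen to mirror the documented return type.

```python
from typing import Dict, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_keys(
    expected: Dict[str, object], provided: Dict[str, object], strict: bool = True
) -> IncompatibleKeys:
    # Keys the module expects but the checkpoint lacks, and vice versa.
    missing = sorted(set(expected) - set(provided))
    unexpected = sorted(set(provided) - set(expected))
    if strict and (missing or unexpected):
        raise RuntimeError(f"Missing: {missing}, unexpected: {unexpected}")
    return IncompatibleKeys(missing, unexpected)


own = {"fc.weight": 0, "fc.bias": 0}
ckpt = {"fc.weight": 1, "head.weight": 2}
print(check_keys(own, ckpt, strict=False))
# IncompatibleKeys(missing_keys=['fc.bias'], unexpected_keys=['head.weight'])
```

With strict=False the mismatch is reported rather than raised, which is the usual way to inspect a partially compatible checkpoint.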
on_load_checkpoint(checkpoint: Dict[str, Any]) → None
This is called by the pl.LightningModule before the model’s checkpoint is loaded.
on_save_checkpoint(checkpoint: Dict[str, Any]) → None
Gives the model a chance to add something to the checkpoint; state_dict is already there.
Parameters: checkpoint – a dictionary in which you can store additional variables to save in the checkpoint. Contents need to be pickleable.
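A minimal sketch of the two checkpoint hooks working together, assuming a hypothetical ToyModel that needs an extra vocab_size entry (an illustrative name, not an MMF field). The save hook adds pickleable metadata next to the existing state_dict, and the load hook reads it back before the weights would be restored.

```python
import pickle
from typing import Any, Dict


class ToyModel:
    def __init__(self, vocab_size: int) -> None:
        self.vocab_size = vocab_size

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # "state_dict" is already present; add extra, pickleable metadata.
        checkpoint["vocab_size"] = self.vocab_size

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # Restore the extra metadata before the weights are loaded.
        self.vocab_size = checkpoint["vocab_size"]


checkpoint = {"state_dict": {"fc.weight": [0.1, 0.2]}}
ToyModel(vocab_size=3000).on_save_checkpoint(checkpoint)
blob = pickle.dumps(checkpoint)  # contents must survive pickling

restored = ToyModel(vocab_size=0)
restored.on_load_checkpoint(pickle.loads(blob))
print(restored.vocab_size)  # 3000
```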
test_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL (PyTorch Lightning) modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
Parameters:
- batch (SampleList) – SampleList returned by the DataLoader for the current iteration
- batch_idx (int) – index of the current batch
Returns: Dict
training_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL (PyTorch Lightning) modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
Parameters:
- batch (SampleList) – SampleList returned by the DataLoader for the current iteration
- batch_idx (int) – index of the current batch
Returns: Dict containing the loss.
Return type: Dict
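For training_step, PyTorch Lightning looks for a "loss" key when the step returns a dict. The sketch below is hypothetical: the fixed loss values stand in for real tensors, and summing the individual losses into "loss" is an assumed aggregation strategy, not necessarily what MMF does.

```python
from typing import Any, Dict


class ToyModel:
    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Pretend the model produced two named losses for this batch.
        return {"scores": [0.2, 0.8], "losses": {"ce": 0.5, "aux": 0.25}}

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        output = self.forward(batch)
        # Lightning reads the "loss" key of the returned dict for backprop.
        output["loss"] = sum(output["losses"].values())
        return output


print(ToyModel().training_step({}, batch_idx=0)["loss"])  # 0.75
```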
validation_step(batch: mmf.common.sample.SampleList, batch_idx: int, *args, **kwargs)
Member function of PL (PyTorch Lightning) modules, used only when PL is enabled. To be implemented by the child class. Takes in a SampleList and batch_idx and returns a dict.
Parameters:
- batch (SampleList) – SampleList returned by the DataLoader for the current iteration
- batch_idx (int) – index of the current batch
Returns: Dict
class