models.base_model¶
Models built on top of Pythia need to inherit the BaseModel class and adhere to a common format. To create a model for MMF, follow this quick cheatsheet.
- Inherit the BaseModel class, and make sure to call super().__init__() in your class’s __init__ function.
- Implement a build function for your model. If you build everything in __init__, you can just return in this function.
- Write a forward function which takes in a SampleList as an argument and returns a dict.
- Register the model using the @registry.register_model("key") decorator on top of the class.
If you are doing logits-based predictions, the dict you return from your model should contain a scores field. Losses are automatically calculated by the BaseModel class and added to this dict if not present.
Example:
import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
@registry.register_model("pythia")
class Pythia(BaseModel):
    # config is model_config from global config
    def __init__(self, config):
        super().__init__(config)

    def build(self):
        ...

    def forward(self, sample_list):
        scores = torch.rand(sample_list.get_batch_size(), 3127)
        return {"scores": scores}
-
class mmf.models.base_model.BaseModel(config)[source]¶ For integration with Pythia’s trainer, datasets and other features, models need to inherit this class, call super, write a build function, write a forward function taking a SampleList as input and returning a dict as output and, finally, register it using @registry.register_model.
Parameters: config (DictConfig) – model_config configuration from global config.
-
build()[source]¶ Function to be implemented by the child class, in case it needs to build its model separately from __init__. All model-related downloads should also happen here.
-
format_for_prediction(results, report)[source]¶ Implement this method in models that need to modify prediction results using report fields. Note that the required fields should already have been gathered into the report.
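A hedged sketch of what an override might look like. The standalone stub below does not inherit BaseModel so it runs in isolation, and the question_id field is an assumed name for illustration, not MMF’s actual report schema:

```python
from types import SimpleNamespace

class MyModel:
    # Hypothetical override: copy an id field from the gathered report onto
    # each prediction result so downstream evaluation can match them up.
    # "question_id" is an assumed field name, not guaranteed by MMF.
    def format_for_prediction(self, results, report):
        for idx, result in enumerate(results):
            result["question_id"] = report.question_id[idx]
        return results

# Example: a report gathered with two question ids.
report = SimpleNamespace(question_id=[101, 102])
results = [{"answer": "yes"}, {"answer": "no"}]
formatted = MyModel().format_for_prediction(results, report)
```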
-
classmethod
format_state_key(key)[source]¶ Can be implemented if something special needs to be done to a key when a pretrained model is being loaded. This will adapt the keys accordingly and return them. Useful for backwards compatibility. See the updated load_state_dict below. For an example, see the VisualBERT model’s code.
Parameters: key (string) – key to be formatted
Returns: formatted key
Return type: string
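A minimal sketch of such an override, assuming a hypothetical rename of checkpoint keys; the "bert." to "model.bert." mapping is invented for illustration and is not the actual VisualBERT mapping:

```python
class MyModel:
    @classmethod
    def format_state_key(cls, key):
        # Suppose older checkpoints stored weights under "bert.*" while the
        # current module hierarchy expects "model.bert.*" (an assumed rename).
        if key.startswith("bert."):
            return "model." + key
        return key

# Old checkpoint keys get remapped; already-current keys pass through unchanged.
remapped = MyModel.format_state_key("bert.embeddings.weight")
unchanged = MyModel.format_state_key("classifier.weight")
```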
-
forward(sample_list, *args, **kwargs)[source]¶ To be implemented by the child class. Takes in a SampleList and returns back a dict.
Parameters: sample_list (SampleList) – SampleList returned by the DataLoader for the current iteration
Returns: Dict containing a scores object.
Return type: Dict
-
init_losses()[source]¶ Initializes losses for the model based on the losses key. Automatically called by MMF internally after building the model.
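The losses key lives under the model’s section of the global config. A hedged sketch of what such a config fragment might look like, assuming MMF’s YAML layout; treat the logit_bce loss name as an assumption rather than a guaranteed built-in:

```yaml
model_config:
  pythia:
    losses:
      - type: logit_bce
```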
-
load_state_dict(state_dict, *args, **kwargs)[source]¶ Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Parameters:
- state_dict (dict) – a dict containing parameters and persistent buffers.
- strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
Returns:
- missing_keys is a list of str containing the missing keys
- unexpected_keys is a list of str containing the unexpected keys
Return type: NamedTuple with missing_keys and unexpected_keys fields
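The key-matching behavior behind strict mode can be illustrated with a minimal pure-Python sketch. This mirrors the comparison conceptually; it is not MMF’s or PyTorch’s implementation, and the key names are invented for the example:

```python
from typing import List, NamedTuple

class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]

def check_keys(model_keys, state_dict_keys):
    # Keys the module expects but the checkpoint lacks.
    missing = [k for k in model_keys if k not in state_dict_keys]
    # Keys the checkpoint carries but the module does not know about.
    unexpected = [k for k in state_dict_keys if k not in model_keys]
    return IncompatibleKeys(missing, unexpected)

result = check_keys(
    ["fc.weight", "fc.bias"],          # what the module declares
    ["fc.weight", "old_head.weight"],  # what the checkpoint contains
)
```

With strict=True, any non-empty missing_keys or unexpected_keys would raise an error; with strict=False, they are simply returned for the caller to inspect.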