
common.sample

Sample and SampleList are data structures for arbitrary data returned from a dataset. To work with MMF, the minimum requirement is that datasets return an object of type Sample and that models accept an object of type SampleList as an argument.

Sample represents an arbitrary sample from a dataset, while SampleList combines a list of Samples in an efficient way for use by the model. In simple terms, a SampleList is a batch of Samples: it takes care of batching each attribute properly while still allowing easy attribute access, just as on a single Sample.
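The idea of "a batch of Samples with attribute access" can be illustrated with a minimal, stdlib-only sketch. It is not MMF's implementation: `collate` is a hypothetical helper, `types.SimpleNamespace` stands in for SampleList's attribute access, and plain Python lists stand in for the tensors that SampleList would actually batch.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


def collate(samples: List[Dict[str, Any]]) -> SimpleNamespace:
    """Hypothetical helper: merge per-sample dicts into one object of per-field lists."""
    fields = samples[0].keys()
    # One list per field, holding that field's value from every sample, in order.
    merged = {field: [sample[field] for sample in samples] for field in fields}
    # SimpleNamespace exposes each merged field as an attribute.
    return SimpleNamespace(**merged)


batch = collate([
    {"text": [1, 2], "image_id": "a"},
    {"text": [3, 4], "image_id": "b"},
])
print(batch.text)      # [[1, 2], [3, 4]]
print(batch.image_id)  # ['a', 'b']
```

MMF's SampleList goes further than this sketch: tensor fields are combined into a single tensor with a leading batch dimension rather than kept as a Python list.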

class mmf.common.sample.Sample(init_dict=None)

Sample represents some arbitrary data. All datasets in MMF must return an object of type Sample.

Parameters: init_dict (Dict) – Dictionary to initialize the Sample with.

Usage:

>>> import torch
>>> from mmf.common.sample import Sample
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
tensor(0)
>>> # Custom attributes can be added to ``Sample`` after initialization
>>> sample.context = torch.tensor(4)
fields()

Get current attributes/fields registered under the sample.

Returns: Attributes registered under the Sample.
Return type: List[str]
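The behaviour described for Sample (initialize from a dict, read and set fields as attributes, list them with fields()) can be sketched as an OrderedDict subclass. This is a simplified stand-in, not MMF's source; the class name `SampleSketch` is hypothetical.

```python
from collections import OrderedDict
from typing import Any, List, Optional


class SampleSketch(OrderedDict):
    """Hypothetical Sample-like container: an ordered dict with attribute access."""

    def __init__(self, init_dict: Optional[dict] = None) -> None:
        super().__init__(init_dict or {})

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails, so fall back to the dict.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key: str, value: Any) -> None:
        # Store attributes as dict entries so they are reported by fields().
        self[key] = value

    def fields(self) -> List[str]:
        """Return the registered field names in insertion order."""
        return list(self.keys())


sample = SampleSketch({"text": [101, 2023]})
sample.context = [4]    # added after initialization
print(sample.fields())  # ['text', 'context']
```

Routing `__setattr__` into the dict is what lets attributes added after construction show up in fields() alongside the keys passed at init time.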
class mmf.common.sample.SampleList(samples=None)

SampleList is used to collate a list of Sample objects into a batch during batch preparation. It can be thought of as merging a list of dicts into a single dict.

If each Sample contains an attribute ‘text’ of size (2) and there are 10 samples in the list, the returned SampleList will have an attribute ‘text’ which is a tensor of size (10, 2).

Parameters: samples (List[Sample]) – List of Sample from which the SampleList will be created.

Usage:

>>> import torch
>>> from mmf.common.sample import Sample, SampleList
>>> sample_list = SampleList([
...     Sample({"text": torch.tensor(2)}),
...     Sample({"text": torch.tensor(2)})
... ])
>>> sample_list.text
tensor([2, 2])
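The shape rule for collation (N samples whose field has shape S give a batched field of shape (N, *S)) can be checked with a small stdlib-only sketch. Nested Python lists stand in for tensors, and the helpers `shape` and `stack` are hypothetical; MMF performs the equivalent operation on torch tensors.

```python
from typing import Any, List, Tuple


def shape(x: Any) -> Tuple[int, ...]:
    """Shape of a nested list, assuming it is rectangular (not ragged)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def stack(items: List[Any]) -> List[Any]:
    """Stack equally shaped items along a new leading batch dimension."""
    first = shape(items[0])
    for item in items:
        if shape(item) != first:
            raise ValueError(f"shape mismatch: {shape(item)} vs {first}")
    return list(items)


# Ten samples, each with a 'text' field of shape (2,).
samples = [{"text": [i, i + 1]} for i in range(10)]
text = stack([s["text"] for s in samples])
print(shape(samples[0]["text"]), "->", shape(text))  # (2,) -> (10, 2)
```

The same rule covers the scalar case in the usage example: two samples whose field has shape () batch into shape (2,).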
add_field(field, data)

Add an attribute named field with value data to the SampleList.

Parameters:
  • field (str) – Key under which the data will be added.
  • data (object) – Data to be added, can be a torch.Tensor, list or Sample
copy()

Get a copy of the current SampleList.

Returns: Copy of current SampleList.
Return type: SampleList
fields()

Get current attributes/fields registered under the SampleList.

Returns: List of attributes of the SampleList.
Return type: List[str]
get_batch_size()

Get the batch size of the current SampleList. There must be a tensor present inside the SampleList to use this function.

Returns: Size of the batch in SampleList.
Return type: int

get_field(field)

Get the value of a particular attribute.

Parameters: field (str) – Attribute whose value is to be returned.
get_fields(fields)

Get a new SampleList generated from the current SampleList, containing only the attributes passed in the fields argument.

Parameters: fields (List[str]) – Attributes whose SampleList will be made.
Returns: SampleList containing only the attribute values of the fields which were passed.
Return type: SampleList
get_item_list(key)

Get a SampleList of only one particular attribute that is present in the SampleList.

Parameters: key (str) – Attribute whose SampleList will be made.
Returns: SampleList containing only the attribute value of the key which was passed.
Return type: SampleList
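Both get_fields and get_item_list build a new container holding only some of the fields, leaving the original untouched. A dict-based sketch of that idea follows; `select_fields` is a hypothetical helper, and raising KeyError for a missing field is an assumption of the sketch, not documented MMF behaviour.

```python
from typing import Any, Dict, Iterable


def select_fields(batch: Dict[str, Any], fields: Iterable[str]) -> Dict[str, Any]:
    """Hypothetical helper: return a new dict with only the requested fields."""
    # Indexing (rather than .get) makes a misspelled field name fail loudly.
    return {field: batch[field] for field in fields}


batch = {"text": [[1, 2], [3, 4]], "image_id": ["a", "b"], "label": [0, 1]}
subset = select_fields(batch, ["text", "label"])  # like get_fields
single = select_fields(batch, ["image_id"])       # one field, like get_item_list
print(sorted(subset))  # ['label', 'text']
print(single)          # {'image_id': ['a', 'b']}
```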
pin_memory()

A custom batch type must define a pin_memory method for PyTorch to pin it when a DataLoader is created with pin_memory=True. This method simply pins each of the tensor fields individually.

to(device, non_blocking=True)

Similar to the .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute’s value is not a tensor, it is ignored and kept as it is.

Parameters:
  • device (str|torch.device) – Device to which the SampleList should be moved.
  • non_blocking (bool) – Whether the move should be non-blocking. Default: True
Returns: A SampleList moved to the device.
Return type: SampleList
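The "move what can be moved, keep the rest" behaviour can be sketched without PyTorch or a GPU. `FakeTensor` and `move_to` are hypothetical stand-ins; the sketch assumes a value counts as movable when it has a `.to` method, and, as a design choice of the sketch, it returns a new mapping instead of modifying the input.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""
    data: tuple
    device: str = "cpu"

    def to(self, device: str, non_blocking: bool = False) -> "FakeTensor":
        return FakeTensor(self.data, device)


def move_to(batch: Dict[str, Any], device: str, non_blocking: bool = True) -> Dict[str, Any]:
    """Return a new dict where every value that has .to() is moved to `device`."""
    return {
        field: value.to(device, non_blocking=non_blocking) if hasattr(value, "to") else value
        for field, value in batch.items()
    }


batch = {"text": FakeTensor((1, 2)), "image_id": ["a", "b"]}
moved = move_to(batch, "cuda:0")
print(moved["text"].device)  # cuda:0
print(moved["image_id"])     # ['a', 'b'] (no .to method, so kept as is)
print(batch["text"].device)  # cpu (the input dict is unchanged)
```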

to_dict() → Dict[str, Any]

Converts a SampleList to a dict. This is useful for TorchScript and for other internal API unification efforts.

Returns: A dict representation of the current SampleList.
Return type: Dict[str, Any]
mmf.common.sample.detach_tensor(tensor: Any) → Any

Detaches any element passed in that has a .detach function defined. In MMF, this can currently be a SampleList, a Report, or a tensor.

Parameters: tensor (Any) – Item to be detached.
Returns: Detached element.
Return type: Any
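The duck-typed check described above amounts to a `hasattr` test. A minimal sketch follows; `detach_any` and `Tracked` are hypothetical names, with `Tracked` standing in for a tensor so the example runs without PyTorch.

```python
from typing import Any


def detach_any(item: Any) -> Any:
    """Call .detach() on `item` if it defines one; otherwise return it unchanged."""
    if hasattr(item, "detach"):
        return item.detach()
    return item


class Tracked:
    """Hypothetical object with a detach() method, standing in for a tensor."""

    def __init__(self, value: float, requires_grad: bool = True) -> None:
        self.value = value
        self.requires_grad = requires_grad

    def detach(self) -> "Tracked":
        # Return a copy that no longer tracks gradients.
        return Tracked(self.value, requires_grad=False)


print(detach_any(Tracked(1.5)).requires_grad)  # False
print(detach_any("not detachable"))            # not detachable
```

Because the check is structural rather than a type test, any object exposing `detach()` (a tensor, or a container such as SampleList or Report) is handled the same way.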