common.sample
Sample and SampleList are data structures for arbitrary data returned from a
dataset. To work with MMF, the minimum requirement is that datasets return
an object of the Sample class and that models accept an object of type SampleList
as an argument.

Sample represents an arbitrary sample from a dataset, while SampleList
is a list of Sample objects combined in an efficient way to be used by the model.
In simple terms, SampleList is a batch of Samples, but it allows easy access to the
attributes of each Sample while taking care of batching them properly.
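The contract described above (a dataset returns one sample per index, the samples are merged into a batch, and the model consumes the batch) can be sketched without MMF or PyTorch. This is a hypothetical, dependency-free illustration: `ToySample`, `toy_collate`, `ToyDataset` and `toy_model` are illustrative names, not MMF classes, and plain Python lists stand in for tensors.

```python
# Hypothetical sketch of the dataset -> batch -> model contract.
# None of these names come from MMF; lists stand in for tensors.


class ToySample(dict):
    """Stand-in for Sample: a dict whose keys can also be read as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None


def toy_collate(samples):
    """Stand-in for SampleList: merge per-sample fields into per-batch fields."""
    batch = ToySample()
    for field in samples[0]:
        batch[field] = [sample[field] for sample in samples]
    return batch


class ToyDataset:
    """Returns one ToySample per index, as an MMF dataset returns a Sample."""

    def __init__(self, texts, labels):
        self.texts, self.labels = texts, labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return ToySample(text=self.texts[idx], label=self.labels[idx])


def toy_model(batch):
    """Consumes the whole batch, reading fields by attribute name."""
    return [len(text) for text in batch.text]


dataset = ToyDataset(["a cat", "a dog sat"], [0, 1])
batch = toy_collate([dataset[i] for i in range(len(dataset))])
print(batch.label)       # [0, 1]
print(toy_model(batch))  # [5, 9]
```

The model never sees individual samples; it reads each field once, already batched, which is the division of labor the paragraph above describes.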
-
class mmf.common.sample.Sample(init_dict=None)
Sample represents some arbitrary data. All datasets in MMF must return an object of type
Sample.
Parameters: init_dict (Dict) – Dictionary to initialize the Sample with.
Usage:
>>> import torch
>>> from mmf.common.sample import Sample
>>> sample = Sample({"text": torch.tensor(2)})
>>> sample.text.zero_()
tensor(0)
>>> # Custom attributes can be added to ``Sample`` after initialization
>>> sample.context = torch.tensor(4)
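Attribute-style access of the kind used in the usage example can be sketched with a dict subclass. This is a hypothetical `ToySample`, not MMF's source, and plain integers stand in for tensors. In this toy, an attribute set after initialization becomes an ordinary dict entry, so any code that iterates over the fields sees it.

```python
class ToySample(dict):
    """Hypothetical Sample-like dict: attribute reads and writes map to keys."""

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Store new "attributes" as dict entries so they travel with the data.
        self[key] = value


sample = ToySample({"text": 2})
sample.context = 4                  # attribute added after initialization
print(sample.text, sample.context)  # 2 4
print(list(sample))                 # ['text', 'context']
```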
-
class mmf.common.sample.SampleList(samples=None)
SampleList is used to collate a list of Sample objects into a batch during batch preparation. It can be thought of as a merger of a list of dicts into a single dict.
If each Sample contains an attribute 'text' of size (2) and there are 10 samples in the list, the returned SampleList will have an attribute 'text' which is a tensor of size (10, 2).
Parameters: samples (List[Sample]) – List of Sample from which the SampleList will be created.
Usage:
>>> import torch
>>> from mmf.common.sample import Sample, SampleList
>>> sample_list = SampleList([
...     Sample({"text": torch.tensor(2)}),
...     Sample({"text": torch.tensor(2)}),
... ])
>>> sample_list.text
tensor([2, 2])
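The size rule above (each sample's field of size (2), 10 samples, batched field of size (10, 2)) amounts to adding a new leading batch dimension. The following is a hypothetical, dependency-free sketch of that merge; `toy_collate` is an illustrative name, and nested lists stand in for tensors, so "stacking" is simply gathering values in order.

```python
def toy_collate(samples):
    """Merge a list of dicts into one dict with a leading batch dimension.

    Hypothetical sketch: nested lists stand in for tensors, so a field of
    "size (2)" in each of N samples becomes a field of "size (N, 2)".
    """
    if not samples:
        return {}
    batch = {}
    for field in samples[0]:
        # Put the i-th sample's value at position i, like stacking along dim 0.
        batch[field] = [sample[field] for sample in samples]
    return batch


samples = [{"text": [i, i + 1], "id": f"q{i}"} for i in range(10)]
batch = toy_collate(samples)
print(len(batch["text"]), len(batch["text"][0]))  # 10 2
print(batch["id"][:3])                            # ['q0', 'q1', 'q2']
```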
-
add_field(field, data)
Add an attribute field with value data to the SampleList.
Parameters:
- field (str) – Key under which the data will be added.
- data (object) – Data to be added; can be a torch.Tensor, list or Sample.
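Registering a field so that it is readable by attribute can be sketched as follows. `ToyBatch` is a hypothetical SampleList-like container, not MMF's implementation, and it performs no validation of `data`.

```python
class ToyBatch(dict):
    """Hypothetical SampleList-like container, not MMF's implementation."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

    def add_field(self, field, data):
        # Register `data` under the key `field`; it is then readable
        # both as batch[field] and as batch.<field>.
        self[field] = data


batch = ToyBatch(text=[[1, 2], [3, 4]])
batch.add_field("image_id", ["img0", "img1"])
print(batch.image_id)  # ['img0', 'img1']
print(list(batch))     # ['text', 'image_id']
```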
-
copy()
Get a copy of the current SampleList.
Returns: Copy of the current SampleList.
Return type: SampleList
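The description above does not say whether the copy is shallow or deep. The hypothetical `ToyBatch` below makes a shallow copy and shows what that implies: adding a field to the copy leaves the original untouched, but both containers still share the same field values.

```python
class ToyBatch(dict):
    """Hypothetical SampleList-like container, not MMF's implementation."""

    def copy(self):
        # dict.copy() would return a plain dict; return the subclass instead.
        # Shallow copy: the container is new, the field values are shared.
        return ToyBatch(self)


original = ToyBatch(text=[1, 2])
clone = original.copy()
clone["label"] = [0, 1]                   # add a field to the copy only
print("label" in original)                # False
print(clone["text"] is original["text"])  # True
```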
-
fields()
Get the attributes/fields currently registered under the SampleList.
Returns: List of attributes of the SampleList.
Return type: List[str]
-
get_batch_size()
Get the batch size of the current SampleList. There must currently be a tensor field present in the SampleList.
Returns: Size of the batch in SampleList.
Return type: int
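Why a tensor field is required can be seen from a small sketch: the batch size is read from a tensor's leading dimension, and a non-tensor field such as a list of ids does not carry that information in a reliable form. This is a hypothetical illustration. `FakeTensor` stands in for `torch.Tensor`, `toy_batch_size` is an illustrative name, and it assumes that dim 0 of any tensor field is the batch dimension.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only carries a shape."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


def toy_batch_size(batch):
    """Return the leading dimension of the first tensor-like field.

    Assumption: the batch dimension is dim 0 of any tensor field, so a
    non-tensor field (such as a list of ids) is never consulted.
    """
    for value in batch.values():
        if isinstance(value, FakeTensor):
            return value.shape[0]
    raise ValueError("no tensor field present; batch size is undefined")


batch = {"question_id": ["q0", "q1", "q2"], "text": FakeTensor(3, 128)}
print(toy_batch_size(batch))  # 3
```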
-
get_field(field)
Get the value of a particular attribute.
Parameters: field (str) – Attribute whose value is to be returned.
-
get_fields(fields)
Get a new SampleList generated from the current SampleList that contains only the attributes passed in the fields argument.
Parameters: fields (List[str]) – Attributes to include in the new SampleList.
Returns: SampleList containing only the attribute values of the fields which were passed.
Return type: SampleList
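Selecting a subset of fields into a new container, without modifying the original, can be sketched as below. `toy_get_fields` is a hypothetical helper operating on plain dicts. Raising `KeyError` on an unknown field is this toy's own choice, not documented MMF behavior.

```python
def toy_get_fields(batch, fields):
    """Return a new dict holding only the requested fields of `batch`.

    Hypothetical sketch; raising on unknown fields is this toy's choice.
    """
    missing = [field for field in fields if field not in batch]
    if missing:
        raise KeyError(f"fields not in batch: {missing}")
    return {field: batch[field] for field in fields}


batch = {"text": [[1, 2], [3, 4]], "label": [0, 1], "image_id": ["a", "b"]}
subset = toy_get_fields(batch, ["text", "label"])
print(list(subset))  # ['text', 'label']
print(len(batch))    # 3: the original batch is unchanged
```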
-
get_item_list(key)
Get a SampleList of only one particular attribute that is present in the SampleList.
Parameters: key (str) – Attribute from which the SampleList will be made.
Returns: SampleList containing only the attribute value of the key which was passed.
Return type: SampleList
-
pin_memory()
A custom batch object needs to define a pin_memory method so that PyTorch can actually apply memory pinning to it. This method individually pins all of the tensor fields.
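When a PyTorch `DataLoader` is created with `pin_memory=True`, it pins a custom batch type by calling that type's `pin_memory()` method and using the returned object. The sketch below mimics that protocol without PyTorch: `FakeTensor` and `ToyBatch` are hypothetical stand-ins, not MMF or PyTorch classes.

```python
class FakeTensor:
    """Hypothetical tensor stand-in that records whether it is pinned."""

    def __init__(self, data, pinned=False):
        self.data, self.pinned = data, pinned

    def pin_memory(self):
        # Like torch.Tensor.pin_memory(), return a pinned copy instead of mutating.
        return FakeTensor(self.data, pinned=True)


class ToyBatch(dict):
    """Hypothetical SampleList-like batch, not MMF's implementation."""

    def pin_memory(self):
        # Pin each tensor field individually; non-tensor fields are left alone.
        for key, value in list(self.items()):
            if isinstance(value, FakeTensor):
                self[key] = value.pin_memory()
        # The loader uses the returned object, so return the batch itself.
        return self


batch = ToyBatch(text=FakeTensor([1, 2]), ids=["q0", "q1"])
pinned = batch.pin_memory()
print(pinned["text"].pinned, pinned["ids"])  # True ['q0', 'q1']
```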
-
to(device, non_blocking=True)
Similar to the .to function on a torch.Tensor. Moves all of the tensors present inside the SampleList to a particular device. If an attribute's value is not a tensor, it is ignored and kept as is.
Parameters:
- device (str|torch.device) – Device to which the SampleList should be moved.
- non_blocking (bool) – Whether the move should be non-blocking. Default: True
Returns: a SampleList moved to the device.
Return type: SampleList
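The rule "move tensors, keep everything else as is" can be sketched with hypothetical stand-ins: `FakeTensor` only records a device name, and `ToyBatch` is not MMF's implementation. Returning a new batch instead of mutating the original in place is this toy's choice.

```python
class FakeTensor:
    """Hypothetical tensor stand-in that records which device it lives on."""

    def __init__(self, data, device="cpu"):
        self.data, self.device = data, device

    def to(self, device, non_blocking=False):
        # torch.Tensor.to also accepts non_blocking; this toy ignores it.
        return FakeTensor(self.data, device=device)


class ToyBatch(dict):
    """Hypothetical SampleList-like batch, not MMF's implementation."""

    def to(self, device, non_blocking=True):
        # Build a new batch: move tensor fields, keep everything else as is.
        moved = ToyBatch()
        for key, value in self.items():
            if isinstance(value, FakeTensor):
                moved[key] = value.to(device, non_blocking=non_blocking)
            else:
                moved[key] = value
        return moved


batch = ToyBatch(text=FakeTensor([1, 2]), ids=["q0", "q1"])
on_gpu = batch.to("cuda")
print(on_gpu["text"].device, batch["text"].device)  # cuda cpu
print(on_gpu["ids"] is batch["ids"])                # True
```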
-