# Source code for mmf.datasets.processors.bert_processors
# Copyright (c) Facebook, Inc. and its affiliates.
import random
import torch
from transformers.tokenization_auto import AutoTokenizer
from mmf.common.registry import registry
from mmf.datasets.processors.processors import BaseProcessor
@registry.register_processor("masked_token")
class MaskedTokenProcessor(BaseProcessor):
    """Tokenize text and build BERT-style masked-language-model inputs.

    Config fields read by ``__init__``:

    * ``tokenizer_config.type`` and ``tokenizer_config.params`` -- forwarded
      to ``AutoTokenizer.from_pretrained``.
    * ``max_seq_length`` -- fixed length of every returned tensor.
    * ``mask_probability`` -- optional per-token selection probability,
      defaults to ``0.15``.
    """

    _CLS_TOKEN = "[CLS]"
    _SEP_TOKEN = "[SEP]"
    _MASK_TOKEN = "[MASK]"

    def __init__(self, config, *args, **kwargs):
        tokenizer_config = config.tokenizer_config
        self._tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_config.type, **tokenizer_config.params
        )
        self._max_seq_length = config.max_seq_length
        self._probability = getattr(config, "mask_probability", 0.15)

    def get_vocab_size(self):
        """Return ``len()`` of the underlying tokenizer."""
        return len(self._tokenizer)

    def _random_word(self, tokens, probability=0.15):
        """Randomly corrupt ``tokens`` in place and build the MLM labels.

        Each token is selected with chance ``probability``. A selected token
        is replaced by the mask token (80%), replaced by a random vocabulary
        token (10%) or left untouched (10%); its label is the id of the
        ORIGINAL token. Tokens that are not selected get label ``-1``.

        Args:
            tokens: list of string tokens; modified in place.
            probability: per-token selection probability.

        Returns:
            Tuple ``(tokens, labels)`` where ``tokens`` is the same list
            object that was passed in and ``labels`` has one entry per token.
        """
        labels = []
        for idx, token in enumerate(tokens):
            prob = random.random()
            if prob >= probability:
                # Not selected for prediction.
                labels.append(-1)
                continue

            # Rescale so ``prob`` is uniform in [0, 1) among selected tokens.
            prob /= probability
            if prob < 0.8:
                # 80%: replace with the mask token.
                tokens[idx] = self._MASK_TOKEN
            elif prob < 0.9:
                # 10%: replace with a random token from the vocabulary.
                tokens[idx] = self._tokenizer.convert_ids_to_tokens(
                    torch.randint(len(self._tokenizer), (1,), dtype=torch.long)
                )[0]
            # Remaining 10%: keep the original token as it is.

            # ``token`` still holds the original value, so the label is the
            # id the model is expected to predict.
            labels.append(self._tokenizer.convert_tokens_to_ids(token))
        return tokens, labels

    def _truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncate a sequence pair in place to at most ``max_length`` tokens.

        This is a simple heuristic which always truncates the longer sequence
        one token at a time. This makes more sense than truncating an equal
        percent of tokens from each, since if one sequence is very short then
        each token that's truncated likely contains more information than a
        longer sequence.

        Args:
            tokens_a: first token list; truncated in place.
            tokens_b: second token list (truncated in place) or ``None``.
            max_length: maximum combined number of tokens to keep.

        Raises:
            IndexError: if ``max_length`` is negative (both lists are emptied
                before the failing ``pop``).
        """
        if tokens_b is None:
            tokens_b = []
        while len(tokens_a) + len(tokens_b) > max_length:
            # On a tie the second sequence is shortened.
            longer = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
            longer.pop()

    def _convert_to_indices(self, tokens_a, tokens_b=None, probability=0.15):
        """Mask, add special tokens, convert to ids and pad.

        Layout is ``[CLS] A [SEP]`` for a single segment and
        ``[CLS] A [SEP] B [SEP]`` when ``tokens_b`` is non-empty. Callers
        must truncate beforehand so that the layout, including the special
        tokens, fits in ``max_seq_length``.

        Args:
            tokens_a: tokens of the first segment (masked in place).
            tokens_b: optional tokens of the second segment (masked in place).
            probability: masking probability passed to ``_random_word``.

        Returns:
            Dict with ``input_ids``, ``input_mask``, ``segment_ids`` and
            ``lm_label_ids`` as ``torch.long`` tensors of length
            ``max_seq_length``, plus ``tokens``, the unpadded token list.
        """
        tokens_a, label_a = self._random_word(tokens_a, probability=probability)
        tokens = [self._CLS_TOKEN] + tokens_a + [self._SEP_TOKEN]
        segment_ids = [0] * len(tokens)
        # Special tokens are never prediction targets, hence the -1 labels.
        lm_label_ids = [-1] + label_a + [-1]

        if tokens_b:
            tokens_b, label_b = self._random_word(tokens_b, probability=probability)
            tokens += tokens_b + [self._SEP_TOKEN]
            segment_ids += [1] * (len(tokens_b) + 1)
            lm_label_ids += label_b + [-1]

        input_ids = self._tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length. A negative count pads nothing,
        # and the asserts below then report the over-long sequence.
        pad = self._max_seq_length - len(input_ids)
        input_ids += [0] * pad
        input_mask += [0] * pad
        segment_ids += [0] * pad
        lm_label_ids += [-1] * pad

        assert len(input_ids) == self._max_seq_length
        assert len(input_mask) == self._max_seq_length
        assert len(segment_ids) == self._max_seq_length
        assert len(lm_label_ids) == self._max_seq_length

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "input_mask": torch.tensor(input_mask, dtype=torch.long),
            "segment_ids": torch.tensor(segment_ids, dtype=torch.long),
            "lm_label_ids": torch.tensor(lm_label_ids, dtype=torch.long),
            "tokens": tokens,
        }

    def __call__(self, item):
        """Process one sample into masked-LM tensors.

        Args:
            item: mapping with ``"text_a"``, ``"is_correct"`` and optionally
                ``"text_b"``.

        Returns:
            The dict from ``_convert_to_indices`` with an extra
            ``"is_correct"`` ``torch.long`` tensor.
        """
        text_a = item["text_a"]
        text_b = item.get("text_b", None)

        tokens_a = self._tokenizer.tokenize(text_a)
        tokens_b = self._tokenizer.tokenize(text_b) if text_b else None

        # Reserve room for the special tokens that _convert_to_indices adds:
        # [CLS] A [SEP] needs 2, [CLS] A [SEP] B [SEP] needs 3. Reserving
        # only 2 for a pair would overflow max_seq_length by one token.
        num_special_tokens = 3 if tokens_b else 2
        self._truncate_seq_pair(
            tokens_a, tokens_b, self._max_seq_length - num_special_tokens
        )

        output = self._convert_to_indices(
            tokens_a, tokens_b, probability=self._probability
        )
        output["is_correct"] = torch.tensor(item["is_correct"], dtype=torch.long)
        return output
@registry.register_processor("bert_tokenizer")
class BertTokenizer(MaskedTokenProcessor):
    """Single-segment variant of ``MaskedTokenProcessor`` with masking off.

    The masking probability is forced to ``0`` after the parent constructor
    runs, whatever the config specifies.
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        # Override whatever probability the parent read from the config.
        self._probability = 0

    def __call__(self, item):
        """Tokenize ``item["text"]`` or, if absent, the joined ``item["tokens"]``.

        Returns the dict built by ``_convert_to_indices`` with the token list
        also stored under the ``"text"`` key.
        """
        raw_text = item["text"] if "text" in item else " ".join(item["tokens"])
        pieces = self._tokenizer.tokenize(raw_text)

        # Two slots are kept free for the [CLS] and [SEP] tokens.
        budget = self._max_seq_length - 2
        self._truncate_seq_pair(pieces, None, budget)

        encoded = self._convert_to_indices(
            pieces, None, probability=self._probability
        )
        encoded["text"] = encoded["tokens"]
        return encoded