from collections import defaultdict, Counter


class Vocab:
    """Maps between tokens and integer ids; unknown tokens fall back to <unk>."""

    def __init__(self, tokens=None):
        self.idx_to_token = list()
        self.token_to_idx = dict()

        if tokens is not None:
            if "<unk>" not in tokens:
                tokens = tokens + ["<unk>"]
            for token in tokens:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1
            self.unk = self.token_to_idx["<unk>"]

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        # Count token frequencies over a corpus of tokenized sentences.
        token_freqs = defaultdict(int)
        for sentence in text:
            for token in sentence:
                token_freqs[token] += 1
        # <unk> and any reserved tokens come first, then tokens at or above min_freq.
        uniq_tokens = ["<unk>"] + (reserved_tokens if reserved_tokens else [])
        uniq_tokens += [token for token, freq in token_freqs.items()
                        if freq >= min_freq and token != "<unk>"]
        return cls(uniq_tokens)

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, token):
        # Unknown tokens map to the id of <unk>.
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        return [self[token] for token in tokens]

    def convert_ids_to_tokens(self, indices):
        return [self.idx_to_token[index] for index in indices]


def save_vocab(vocab, path):
    # Persist the vocabulary as one token per line, in id order.
    with open(path, 'w') as writer:
        writer.write("\n".join(vocab.idx_to_token))


def read_vocab(path):
    # Rebuild a Vocab from a file written by save_vocab.
    with open(path, 'r') as f:
        tokens = f.read().split('\n')
    return Vocab(tokens)
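

# Example usage (an illustrative sketch, not part of the original listing):
# the toy corpus and file name below are made up for demonstration. It builds
# a vocabulary, maps tokens to ids and back, and round-trips it through disk.
if __name__ == "__main__":
    corpus = [["time", "flies", "like", "an", "arrow"],
              ["fruit", "flies", "like", "a", "banana"]]

    vocab = Vocab.build(corpus, min_freq=1, reserved_tokens=["<pad>"])
    print(len(vocab))                        # vocabulary size, including <unk> and <pad>

    ids = vocab.convert_tokens_to_ids(["fruit", "flies", "caterpillar"])
    print(ids)                               # unseen "caterpillar" maps to vocab.unk
    print(vocab.convert_ids_to_tokens(ids))  # ['fruit', 'flies', '<unk>']

    save_vocab(vocab, "vocab.txt")
    restored = read_vocab("vocab.txt")
    assert restored.convert_tokens_to_ids(["fruit"]) == vocab.convert_tokens_to_ids(["fruit"])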