0%

NLP字典

Snippets

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from collections import defaultdict, Counter

class Vocab(object):
    """Bidirectional mapping between tokens and integer ids.

    The vocabulary always contains the ``"<unk>"`` token; looking up a token
    that is not in the vocabulary returns the id of ``"<unk>"``.

    Attributes:
        idx_to_token: list of tokens, where the position is the token id.
        token_to_idx: dict mapping each token to its id.
        unk: id of the ``"<unk>"`` token.
    """

    def __init__(self, tokens=None):
        """Create a vocabulary from an iterable of tokens.

        Args:
            tokens: iterable of tokens in id order, or ``None`` for a
                vocabulary that only holds ``"<unk>"``. If ``"<unk>"`` is
                missing it is appended at the end. The iterable is copied,
                so the caller's object is never modified.
        """
        # Copy into a list so tuples, generators, etc. are accepted and the
        # argument itself is left untouched.
        tokens = [] if tokens is None else list(tokens)
        if "<unk>" not in tokens:
            tokens.append("<unk>")

        self.idx_to_token = tokens
        # If a token occurs more than once, its last position wins. This is
        # kept on purpose so ids of previously saved vocabularies that
        # contain duplicates do not shift.
        self.token_to_idx = {token: idx for idx, token in enumerate(tokens)}
        self.unk = self.token_to_idx["<unk>"]

    @classmethod
    def build(cls, text, min_freq=1, reserved_tokens=None):
        """Build a vocabulary from tokenized text.

        Args:
            text: iterable of sentences, each an iterable of tokens.
            min_freq: minimum number of occurrences a corpus token needs
                in order to be included.
            reserved_tokens: optional iterable of tokens that are always
                included, placed right after ``"<unk>"``.

        Returns:
            A new vocabulary whose tokens are ``"<unk>"``, then the reserved
            tokens, then the qualifying corpus tokens in first-seen order.
            Every token appears exactly once.
        """
        token_freqs = Counter(token for sentence in text for token in sentence)

        # A dict is used as an ordered set: a corpus token that equals
        # "<unk>" or a reserved token must not get a second entry.
        uniq_tokens = dict.fromkeys(["<unk>"] + list(reserved_tokens or []))
        for token, freq in token_freqs.items():
            if freq >= min_freq and token not in uniq_tokens:
                uniq_tokens[token] = None
        return cls(list(uniq_tokens))

    def __len__(self):
        """Return the number of entries in the vocabulary."""
        return len(self.idx_to_token)

    def __getitem__(self, token):
        """Return the id of ``token``, or the ``"<unk>"`` id if unknown."""
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        """Map an iterable of tokens to a list of ids."""
        return [self[token] for token in tokens]

    def convert_ids_to_tokens(self, indices):
        """Map an iterable of ids to a list of tokens.

        Raises:
            IndexError: if an id is outside the vocabulary.
        """
        return [self.idx_to_token[index] for index in indices]


def save_vocab(vocab, path):
    """Write the vocabulary to ``path``, one token per line in id order.

    The file is always written as UTF-8 so that non-ASCII tokens are stored
    the same way on every platform. No trailing newline is written.

    Args:
        vocab: object exposing ``idx_to_token``, a sequence of string tokens.
        path: destination file path; an existing file is overwritten.
    """
    with open(path, 'w', encoding='utf-8') as writer:
        writer.write("\n".join(vocab.idx_to_token))


def read_vocab(path):
    """Load a vocabulary from a file with one token per line.

    The file is decoded as UTF-8, matching what ``save_vocab`` writes.
    NOTE(review): a file written earlier with a non-UTF-8 platform default
    encoding has to be converted to UTF-8 before it can be loaded.

    Args:
        path: path of the vocabulary file.

    Returns:
        A ``Vocab`` built from the file's lines, in file order.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Split on "\n" only: a trailing newline in the file therefore
        # yields an empty-string token, exactly as before.
        tokens = f.read().split('\n')
    return Vocab(tokens)