BPE Tokenizer Implementation Exercise
A partial solution to the BPE Tokenizer Implementation Exercise from Andrej Karpathy.
See the corresponding YouTube video on the tokenizer topic.
import regex
import requests
from collections import Counter
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
SHAKESPEAR_TEXT_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

class BPETokenizer:
    def __init__(self, *, pattern=GPT4_SPLIT_PATTERN, special_tokens=()):
        self.pattern = pattern  # regex pattern to split text into words
        self.special_tokens = special_tokens  # pre-allocated special tokens
        self.vocab = self._init_vocab()  # map bytes to token id
        self.itob = {}  # the reverse map of vocab, map token id to bytes

    def _init_vocab(self) -> dict:
        vocab = {}
        for i in range(2**8):
            vocab[bytes([i])] = i
        for special_token in self.special_tokens:
            vocab[special_token.encode("utf-8")] = len(vocab)
        return vocab

    def _get_stats(self, bytes_of_words: list[list[bytes]]) -> Counter:
        # count the frequency of each adjacent byte pair; the resulting stats are used to find merge rules
        counts = Counter()
        for bytes_of_word in bytes_of_words:
            for byte_pair in zip(bytes_of_word, bytes_of_word[1:]):
                counts[byte_pair] += 1
        return counts

    # split the text into words, then further split each word into bytes
    def _parse_text(self, text: str) -> list[list[bytes]]:
        return [
            [bytes([b]) for b in word.encode("utf-8")]
            for word in regex.findall(self.pattern, text)
        ]

    def train(self, text: str, vocab_size: int, verbose: int = 0):
        bytes_of_words = self._parse_text(text)
        num_merges = vocab_size - len(self.vocab)
        if verbose:
            print(f"total {num_merges} merges to learn")
        for step in range(num_merges):
            # find the merge
            counts = self._get_stats(bytes_of_words)
            pair_to_merge = max(counts.keys(), key=counts.get)
            byte_pair = b"".join(pair_to_merge)
            self.vocab[byte_pair] = len(self.vocab)
            # apply the merge to training data
            temp_bytes_of_words = []
            for bytes_of_word in bytes_of_words:
                temp_bytes_of_word = []
                just_merged = False
                for first, second in zip(bytes_of_word, bytes_of_word[1:]):
                    if just_merged:
                        just_merged = False
                        continue
                    if (first, second) == pair_to_merge:
                        temp_bytes_of_word.append(byte_pair)
                        just_merged = True
                    else:
                        temp_bytes_of_word.append(first)
                if not just_merged:
                    temp_bytes_of_word.append(bytes_of_word[-1])
                temp_bytes_of_words.append(temp_bytes_of_word)
            bytes_of_words = temp_bytes_of_words
            if verbose and (step + 1) % verbose == 0:
                print(
                    f"merge discovered at step {step + 1} is : ",
                    f"{pair_to_merge[0]} + {pair_to_merge[1]} -> {byte_pair}",
                )

    def encode(self, text):
        bytes_of_words = self._parse_text(text)
        for word_idx, bytes_of_word in enumerate(bytes_of_words):
            # speed this up? only one instance of the lowest rank pair gets updated each time
            while True:
                min_idx = min_rank = merged_bytes = None
                # find the mergeable byte pair with the lowest rank (earliest learned merge)
                for i, (first, second) in enumerate(zip(bytes_of_word, bytes_of_word[1:])):
                    rank = self.vocab.get(first + second, None)
                    if rank is None:
                        continue
                    if min_rank is None or min_rank > rank:
                        min_rank = rank
                        min_idx = i
                        merged_bytes = first + second
                if min_rank is None:
                    break
                bytes_of_word = (
                    bytes_of_word[:min_idx] + [merged_bytes] + bytes_of_word[min_idx + 2:]
                )
            # write the merged word back; rebinding the loop variable alone does not update the list
            bytes_of_words[word_idx] = bytes_of_word
        token_ids = [
            self.vocab[b] for bytes_of_word in bytes_of_words for b in bytes_of_word
        ]
        return token_ids

    def decode(self, token_ids):
        if not self.itob:
            self.itob = {v: k for k, v in self.vocab.items()}
        return b"".join(self.itob[i] for i in token_ids).decode("utf-8")

    def save(self):
        pass

    def load(self):
        pass

def read_text_from_url(url):
    try:
        response = requests.get(url)
        response.raise_for_status()  # raise HTTPError for bad responses (4xx or 5xx)
        return response.text
    except requests.exceptions.RequestException as e:
        print(f"Error fetching data from URL: {e}")
        return None

if __name__ == "__main__":
    text = read_text_from_url(SHAKESPEAR_TEXT_URL)
    if text is None:
        raise SystemExit("failed to download the training text")
    a = BPETokenizer(special_tokens=["<bos>", "<eos>", "<pad>", "<unk>"])
    a.train(text, vocab_size=1024, verbose=100)
    encoded = a.encode(text[:512])
    print(text[:512], "\n encoded as: \n", encoded)
    decoded = a.decode(encoded)
    print("decoded: ", decoded)
    print("equal to original text? ", decoded == text[:512])

Some random afterthoughts:
- The text used to train the tokenizer should ideally match the training/inference text distribution. If the training and inference distributions are quite different, it may be worth using a separate tokenizer for each. For example, if the output is English comments and code only while the input can be multi-language and more descriptive of the code we want to generate, can we use a tokenizer with a smaller vocab size for the output?
- If a token id is never seen during the training run, its embedding stays random, and prompting the model with such a token causes undefined behavior, e.g., "SolidGoldMagikarp". Maybe run a frequency counter over the token ids seen during training and either reject bad input, or reserve an <unk> token and map unseen ids to it (see the first sketch after this list).
- For multilingual models, the tokenizer can be an important factor in which languages underperform. Balancing the language mixture in the tokenizer training data may help.
- A larger vocabulary size leads to shorter encoded sequences, which allows more information to be retained in the limited context window and therefore improves performance. On the flip side, it requires more memory for training and makes the softmax more expensive at inference (the second sketch after this list measures the sequence-length side of this trade-off).
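
A minimal sketch of the unseen-token idea above. It assumes the BPETokenizer class from this file and the "<unk>" special token registered in the __main__ block; build_known_ids, encode_with_unk, and the min_count threshold are hypothetical names introduced here for illustration only.

from collections import Counter

def build_known_ids(tokenizer, training_text, min_count=1):
    # count how often each token id occurs when the model's training corpus is encoded;
    # ids seen fewer than min_count times are treated as "never seen"
    id_counts = Counter(tokenizer.encode(training_text))
    return {token_id for token_id, count in id_counts.items() if count >= min_count}

def encode_with_unk(tokenizer, text, known_ids):
    # remap any token id outside the known set to the reserved <unk> token,
    # so the model is never prompted with an embedding it has not been trained on
    unk_id = tokenizer.vocab["<unk>".encode("utf-8")]
    return [tid if tid in known_ids else unk_id for tid in tokenizer.encode(text)]

With the __main__ example above, this would look like known_ids = build_known_ids(a, text) followed by encode_with_unk(a, some_prompt, known_ids) for an arbitrary prompt string.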
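
And a rough sketch of the sequence-length side of the last bullet: train tokenizers with a few different vocab sizes on the same text and compare bytes per token. The vocab sizes and the text slice are arbitrary choices for illustration, and the per-merge pass over the whole corpus in train() makes large inputs slow.

def compression_report(text, vocab_sizes=(512, 1024, 2048)):
    # a larger vocabulary should yield fewer tokens for the same text,
    # at the cost of a bigger embedding table and a more expensive softmax
    n_bytes = len(text.encode("utf-8"))
    for size in vocab_sizes:
        tokenizer = BPETokenizer()
        tokenizer.train(text, vocab_size=size)
        n_tokens = len(tokenizer.encode(text))
        print(f"vocab={size}: {n_tokens} tokens, {n_bytes / n_tokens:.2f} bytes/token")

# e.g. compression_report(text[:100_000]) on the Shakespeare text used above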