andy's blog

tiny-BPE, a minimal parallelized implementation of BPE

Before your favorite transformer model can understand language, it first has to tokenize it.

BPE, or Byte Pair Encoding, is the tokenization algorithm behind GPT-2. Here we will build a minimal, parallelized implementation of it from scratch in less than 200 lines of Python code.

a gentle intro to BPE

before we dive into any code, let's walk through the algorithm, along with a few of its advantages and disadvantages, using a small example.

1. pre-tokenization

input : raw text corpus (or chunked text if you have very big text files / data)

process : break down the raw text into an initial sequence of "base units" before BPE merging begins. these base units are typically chosen to be linguistically sensible (like words, punctuation, or individual characters/bytes).

this is where decisions about the pre-tokenization scheme are made: byte-level, character-level, or whitespace-based splitting, and so on.

the choice of pre-tokenization is crucial because it defines the initial set of symbols that BPE will start merging from. GPT-2, for example, uses a regex pattern like the one below (the \p{...} Unicode classes require the third-party regex package, not python's built-in re):

        PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

output : a training corpus represented as a sequence of these base units obtained after pre-tokenization.
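
to make this concrete, here's a tiny sketch of what the GPT-2-style pre-tokenizer (the PAT regex shown above) produces on a toy sentence. it assumes the third-party regex package, since python's built-in re doesn't understand \p{...}:

import regex

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# split a toy sentence into base units (pre-tokens)
print(regex.findall(PAT, "Andy's tokenizer handles 123 numbers, too!"))
# ['Andy', "'s", ' tokenizer', ' handles', ' 123', ' numbers', ',', ' too', '!']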



2. initial vocabulary creation

input : the sequences obtained after pre-tokenization of the corpus.

process : collect all unique base units from the input.

output : an initial vocabulary (set of unique tokens) and typically an initial mapping of these tokens to integer IDs.
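
for byte-level BPE (the variant implemented below), a common choice is to seed the vocabulary with every possible byte value, so that any text can be represented with no unknown tokens. a minimal sketch (note: the implementation later in this post takes a shortcut and only registers the bytes it actually sees in the corpus):

# seed a byte-level vocabulary: one ID per possible byte value
# each token is stored as a tuple of byte values, matching the format used later
vocab = {(b,): b for b in range(256)}
next_id = 256  # IDs for merged tokens start here

print(len(vocab))     # 256
print(vocab[(104,)])  # 104 (the byte for "h" in UTF-8)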



3. calculate initial pair frequencies

input : the corpus (as sequence of initial tokens after pre-tokenization)

process : iterate through the corpus and count the occurrences of every adjacent pair of tokens.

output : a collection of all unique adjacent pairs and their frequencies.
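
as a rough sketch of this step, counting adjacent pairs over a toy corpus (each word kept as a tuple of characters for readability, mapped to its frequency) might look like:

from collections import Counter

# toy corpus: each word is a sequence of base units, mapped to how often it occurs
corpus = {("l", "o", "w"): 5, ("l", "o", "w", "e", "r"): 2, ("n", "e", "w"): 3}

pair_counts = Counter()
for word, freq in corpus.items():
    for pair in zip(word, word[1:]):  # every adjacent pair in the word
        pair_counts[pair] += freq

print(pair_counts.most_common(2))
# both ('l', 'o') and ('o', 'w') appear 7 times, so they top the list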



4. iterative merging (the core BPE loop)

this is where most of the meat is.

input : the current corpus, current vocabulary, and current pair frequencies.

process (repeated until the desired vocab_size or number of merges is reached; a tiny worked sketch follows this list) :

  1. find most frequent pair : identify the pair (A, B) with the highest frequency.

  2. create new token : form a new token AB (representing the sequence A followed by B).

  3. add to vocabulary & assign ID : add AB to the vocabulary and assign it a new, unique integer ID.

  4. update corpus : replace all occurrences of the pair (A, B) in the corpus with the new token AB.

  5. update pair frequencies : recalculate (or incrementally update) the frequencies of all adjacent pairs in the modified corpus. this is crucial because replacing (A, B) might create new adjacent pairs (e.g., (X, AB) and (AB, Y)) and remove others.
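
here's the promised worked sketch: one full merge step on the toy corpus from step 3, written as plain python (the variable names are just for illustration):

from collections import Counter

corpus = {("l", "o", "w"): 5, ("l", "o", "w", "e", "r"): 2, ("n", "e", "w"): 3}

# steps 1 + 2: find the most frequent adjacent pair and form the new token
pair_counts = Counter()
for word, freq in corpus.items():
    for pair in zip(word, word[1:]):
        pair_counts[pair] += freq
best_pair = max(pair_counts, key=pair_counts.get)  # ('l', 'o') wins here (ties with ('o', 'w') but is seen first)
new_token = best_pair[0] + best_pair[1]            # 'lo'

# step 4: rewrite every word, replacing the pair with the merged token
def merge(word, pair, merged):
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)

corpus = {merge(w, best_pair, new_token): f for w, f in corpus.items()}
print(corpus)  # {('lo', 'w'): 5, ('lo', 'w', 'e', 'r'): 2, ('n', 'e', 'w'): 3}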



5. final vocabulary and encoding map :

input : the final vocabulary and all learned merges.

process : store the final set of tokens and their assigned integer IDs. the sequence of merges learned implicitly defines how any new text will be tokenized.

output : a tokenizer model that can apply the learned merges to new input text, converting it into a sequence of token IDs.

implementing the algorithm

chunking a file to obtain initial pre-token frequencies



but first, let's import the dependencies and define our special token(s):

import os
import regex as re  # third-party "regex" package; needed for the \p{...} Unicode classes in PAT
from collections import defaultdict
from typing import BinaryIO
from multiprocessing import Pool

# GPT-2 regex pre-tokenizer
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
TOKENIZER = re.compile(PAT, re.IGNORECASE)
SPECIAL_TOKEN = "<|endofdoc|>"  # marker for splitting documents



chunking function : splits the input file into byte ranges that can be processed independently.

when dealing with huge corpora, pre-tokenizing and counting over the entire dataset in a single process is slow. instead, we divide the file into smaller chunks. each chunk is just a slice of the corpus (like splitting a big book into chapters).

this makes processing more manageable and avoids memory bottlenecks.

parallel counting function : each worker counts pre-token frequencies inside its chunk.

once we have chunks, we can distribute them across multiple CPU cores using Python’s multiprocessing.Pool. each worker process independently applies regex-based pre-tokenization and counts how often each pre-token appears. workers don’t interfere with each other; they just work on their assigned chunk.

after finishing, they return intermediate results.

reduction step : merge all results into a global counter. after parallel execution, we need to merge the results back into a single global state. pre-token counts from all chunks are aggregated.

pair frequencies are then computed from these global counts, and the most frequent pair is selected for each merge iteration.

this mimics MapReduce-style processing: Map phase → each worker counts pre-tokens in a chunk. Reduce phase → the counts are aggregated.

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes
) -> list[int]:
    """
    Find safe chunk boundaries in a file for parallel pre-tokenization.
    Ensures we don't split in the middle of a document.
    """
    assert isinstance(split_special_token, bytes), "split_special_token must be bytes"

    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # read-ahead buffer

    for bi in range(1, len(chunk_boundaries) - 1):
        pos = chunk_boundaries[bi]
        file.seek(pos)
        while True:
            mini_chunk = file.read(mini_chunk_size)
            if not mini_chunk:  # reached EOF
                chunk_boundaries[bi] = file_size
                break
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = pos + found_at
                break
            pos += mini_chunk_size

    return sorted(set(chunk_boundaries))


def process_chunk(args):
    """Worker function: tokenize a file chunk and count token frequencies."""
    filename, start, end = args
    counts = defaultdict(int)

    with open(filename, "rb") as f:
        f.seek(start)
        chunk = f.read(end - start).decode("utf-8", errors="ignore")

    docs = re.split(re.escape(SPECIAL_TOKEN), chunk)
    for doc in docs:
        for match in TOKENIZER.finditer(doc):
            token = match.group()
            counts[token] += 1

    return counts


def merge_counts(dicts):
    """Reducer: merge token frequency dictionaries from different workers."""
    merged = defaultdict(int)
    for d in dicts:
        for k, v in d.items():
            merged[k] += v
    return merged


def run_pre_tokenization_parallel(filename: str, num_processes: int = 4):
    """Run pre-tokenization in parallel across file chunks."""
    with open(filename, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, SPECIAL_TOKEN.encode("utf-8"))

    args = [(filename, s, e) for s, e in zip(boundaries[:-1], boundaries[1:])]

    with Pool(num_processes) as pool:
        partial_counts = pool.map(process_chunk, args)

    return merge_counts(partial_counts)
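
a quick usage sketch (the file path is just a placeholder; the corpus is assumed to contain <|endofdoc|> markers between documents):

if __name__ == "__main__":  # required for multiprocessing on some platforms
    token_counts = run_pre_tokenization_parallel("corpus.txt", num_processes=8)
    # peek at the five most frequent pre-tokens
    print(sorted(token_counts.items(), key=lambda kv: -kv[1])[:5])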



writing a function to obtain pair frequencies

once we’ve got our initial vocabulary and tokenized corpus, the next step is to look for the most promising merge candidates.

this step is crucial because BPE’s core idea is to merge the most frequent symbol pairs to build subword units.

def get_pair_counts(self, freqs: dict[tuple[tuple[int, ...], ...], int]) -> dict[tuple[tuple[int, ...], tuple[int, ...]], int]:
        """Count frequency of all adjacent byte pairs."""
        pair_counts = defaultdict(int)
        for sequence, count in freqs.items():
            for i in range(len(sequence) - 1):
                pair_counts[(sequence[i], sequence[i + 1])] += count
        return pair_counts
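
to make the input and output shapes concrete, here's a standalone copy of the same logic run on two toy words ("hi" seen 3 times, "ha" seen 2 times), each stored as a tuple of single-byte tokens:

from collections import defaultdict

def get_pair_counts_demo(freqs):
    """Standalone copy of get_pair_counts, just to show the data shapes."""
    pair_counts = defaultdict(int)
    for sequence, count in freqs.items():
        for i in range(len(sequence) - 1):
            pair_counts[(sequence[i], sequence[i + 1])] += count
    return pair_counts

freqs = {((104,), (105,)): 3, ((104,), (97,)): 2}  # b"hi" x3, b"ha" x2
print(dict(get_pair_counts_demo(freqs)))
# {((104,), (105,)): 3, ((104,), (97,)): 2}  i.e. pair (h, i) x3 and pair (h, a) x2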



the bpe merging (or training) loop

this is the heart of the BPE algorithm as this loop is where BPE learns its vocabulary.

we start with raw bytes as our base units.

on each iteration we count all adjacent pairs, pick the most frequent one, merge it into a new token, add that token to the vocabulary, and rewrite the sequences.

over many iterations, the tokenizer builds a compact vocabulary: common patterns like th, ing, or even full words get merged into single tokens. rare words remain split into smaller subwords.

this iterative process is why BPE balances efficiency (smaller vocabularies) with flexibility (still handles unseen words or out-of-vocabulary words).

def train(self, filename: str, vocab_size: int, num_processes: int = 4):
        """Train BPE merges and vocab from corpus."""
        pre_token_counts = self.run_pre_tokenization_parallel(filename, num_processes)

        # Initialize vocab with byte sequences
        freqs: dict[tuple[tuple[int, ...], ...], int] = {}
        for token, count in pre_token_counts.items():
            token_bytes = token.encode("utf-8")
            token_seq = tuple((b,) for b in token_bytes)
            freqs[token_seq] = count
            for b in token_seq:
                if b not in self.vocab:
                    self.vocab[b] = self.next_id
                    self.next_id += 1

        # Iteratively learn merges
        while len(self.vocab) < vocab_size:
            pair_counts = self.get_pair_counts(freqs)
            if not pair_counts:
                break

            max_pair = max(pair_counts.items(), key=lambda x: (x[1], x[0]))[0]
            merged_token = max_pair[0] + max_pair[1]

            # two different pairs can concatenate to the same byte sequence, so the
            # merged token may already have an ID; either way we record the merge and
            # rewrite the sequences below, otherwise the same pair would stay the most
            # frequent forever and the loop would never terminate
            if merged_token not in self.vocab:
                self.vocab[merged_token] = self.next_id
                self.next_id += 1
            self.merges.append(max_pair)

            # Replace max pair in sequences
            new_freqs = {}
            for seq, freq in freqs.items():
                new_seq = []
                i = 0
                while i < len(seq):
                    if i < len(seq) - 1 and (seq[i], seq[i + 1]) == max_pair:
                        new_seq.append(merged_token)
                        i += 2
                    else:
                        new_seq.append(seq[i])
                        i += 1
                new_freqs[tuple(new_seq)] = freq
            freqs = new_freqs

        # Build inverse vocab for decoding
        self.inv_vocab = {idx: tok for tok, idx in self.vocab.items()}



the encode method

once training is done, we can tokenize new text. note that this encoder is a simplification: rather than applying merges strictly in the order they were learned, it repeatedly sweeps left to right and applies any learned merge it finds, until no merge applies anymore.

def encode(self, text: str) -> list[int]:
        """Encode text into BPE token IDs."""
        pre_tokens = self.TOKENIZER.findall(text)  # same pre-tokenizer as in training
        token_ids = []
        merge_set = set(self.merges)

        for token in pre_tokens:
            byte_seq = [(b,) for b in token.encode("utf-8")]

            while True:
                merged = False
                i = 0
                new_seq = []
                while i < len(byte_seq):
                    if i < len(byte_seq) - 1 and (byte_seq[i], byte_seq[i+1]) in merge_set:
                        new_seq.append(byte_seq[i] + byte_seq[i+1])
                        i += 2
                        merged = True
                    else:
                        new_seq.append(byte_seq[i])
                        i += 1
                byte_seq = new_seq
                if not merged:
                    break

            token_ids.extend(self.vocab[tok] for tok in byte_seq)

        return token_ids



the decode method

decoding is the reverse mapping: each ID is looked up in the inverse vocabulary, the resulting byte tuples are flattened into one byte sequence, and the bytes are decoded as UTF-8.

def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back into string."""
        byte_seq = []
        for tid in token_ids:
            tok_bytes = self.inv_vocab[tid]
            byte_seq.extend(tok_bytes if isinstance(tok_bytes, tuple) else [tok_bytes])

        # Convert list of ints back into bytes → decode utf-8
        return bytes(byte_seq).decode("utf-8", errors="ignore")



combining everything into a single class

here’s the full implementation tied together in a single BPETokenizer class, so you can train, encode, and decode with just a few calls.

import os
import regex as re  # third-party "regex" package; needed for the \p{...} Unicode classes in PAT
from collections import defaultdict
from typing import BinaryIO
from multiprocessing import Pool


class BPETokenizer:
    # GPT-2 regex pre-tokenizer
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    TOKENIZER = re.compile(PAT, re.IGNORECASE)
    SPECIAL_TOKEN = "<|endofdoc|>"

    def __init__(self):
        self.vocab: dict[tuple[int, ...], int] = {}
        self.inv_vocab: dict[int, tuple[int, ...]] = {}
        self.merges: list[tuple] = []
        self.next_id: int = 0

    # ---------- Chunking for parallel pre-tokenization ----------
    def find_chunk_boundaries(
        self, file: BinaryIO, desired_num_chunks: int, split_special_token: bytes
    ) -> list[int]:
        """Find safe chunk boundaries in a file for parallel pre-tokenization."""
        assert isinstance(split_special_token, bytes), "split_special_token must be bytes"

        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)

        chunk_size = file_size // desired_num_chunks
        chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
        chunk_boundaries[-1] = file_size

        mini_chunk_size = 4096

        for bi in range(1, len(chunk_boundaries) - 1):
            pos = chunk_boundaries[bi]
            file.seek(pos)
            while True:
                mini_chunk = file.read(mini_chunk_size)
                if not mini_chunk:  # EOF
                    chunk_boundaries[bi] = file_size
                    break
                found_at = mini_chunk.find(split_special_token)
                if found_at != -1:
                    chunk_boundaries[bi] = pos + found_at
                    break
                pos += mini_chunk_size

        return sorted(set(chunk_boundaries))

    def _process_chunk(self, args):
        """Worker: tokenize file chunk and count token frequencies."""
        filename, start, end = args
        counts = defaultdict(int)

        with open(filename, "rb") as f:
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")

        docs = re.split(re.escape(self.SPECIAL_TOKEN), chunk)
        for doc in docs:
            for match in self.TOKENIZER.finditer(doc):
                token = match.group()
                counts[token] += 1
        return counts

    def _merge_counts(self, dicts):
        merged = defaultdict(int)
        for d in dicts:
            for k, v in d.items():
                merged[k] += v
        return merged

    def run_pre_tokenization_parallel(self, filename: str, num_processes: int = 4):
        """Parallel pre-tokenization across file chunks."""
        with open(filename, "rb") as f:
            boundaries = self.find_chunk_boundaries(
                f, num_processes, self.SPECIAL_TOKEN.encode("utf-8")
            )

        args = [(filename, s, e) for s, e in zip(boundaries[:-1], boundaries[1:])]

        with Pool(num_processes) as pool:
            partial_counts = pool.map(self._process_chunk, args)

        return self._merge_counts(partial_counts)

    # ---------- Training ----------
    def get_pair_counts(self, freqs: dict[tuple[tuple[int, ...], ...], int]) -> dict[tuple[tuple[int, ...], tuple[int, ...]], int]:
        """Count frequency of all adjacent byte pairs."""
        pair_counts = defaultdict(int)
        for sequence, count in freqs.items():
            for i in range(len(sequence) - 1):
                pair_counts[(sequence[i], sequence[i + 1])] += count
        return pair_counts

    def train(self, filename: str, vocab_size: int, num_processes: int = 4):
        """Train BPE merges and vocab from corpus."""
        pre_token_counts = self.run_pre_tokenization_parallel(filename, num_processes)

        # Initialize vocab with byte sequences
        freqs: dict[tuple[tuple[int, ...], ...], int] = {}
        for token, count in pre_token_counts.items():
            token_bytes = token.encode("utf-8")
            token_seq = tuple((b,) for b in token_bytes)
            freqs[token_seq] = count
            for b in token_seq:
                if b not in self.vocab:
                    self.vocab[b] = self.next_id
                    self.next_id += 1

        # Iteratively learn merges
        while len(self.vocab) < vocab_size:
            pair_counts = self.get_pair_counts(freqs)
            if not pair_counts:
                break

            max_pair = max(pair_counts.items(), key=lambda x: (x[1], x[0]))[0]
            merged_token = max_pair[0] + max_pair[1]

            # two different pairs can concatenate to the same byte sequence, so the
            # merged token may already have an ID; either way we record the merge and
            # rewrite the sequences below, otherwise the same pair would stay the most
            # frequent forever and the loop would never terminate
            if merged_token not in self.vocab:
                self.vocab[merged_token] = self.next_id
                self.next_id += 1
            self.merges.append(max_pair)

            # Replace max pair in sequences
            new_freqs = {}
            for seq, freq in freqs.items():
                new_seq = []
                i = 0
                while i < len(seq):
                    if i < len(seq) - 1 and (seq[i], seq[i + 1]) == max_pair:
                        new_seq.append(merged_token)
                        i += 2
                    else:
                        new_seq.append(seq[i])
                        i += 1
                new_freqs[tuple(new_seq)] = freq
            freqs = new_freqs

        # Build inverse vocab for decoding
        self.inv_vocab = {idx: tok for tok, idx in self.vocab.items()}

    # ---------- Encoding / Decoding ----------
    def encode(self, text: str) -> list[int]:
        """Encode text into BPE token IDs."""
        pre_tokens = self.TOKENIZER.findall(text)  # same pre-tokenizer as in training
        token_ids = []
        merge_set = set(self.merges)

        for token in pre_tokens:
            byte_seq = [(b,) for b in token.encode("utf-8")]

            while True:
                merged = False
                i = 0
                new_seq = []
                while i < len(byte_seq):
                    if i < len(byte_seq) - 1 and (byte_seq[i], byte_seq[i+1]) in merge_set:
                        new_seq.append(byte_seq[i] + byte_seq[i+1])
                        i += 2
                        merged = True
                    else:
                        new_seq.append(byte_seq[i])
                        i += 1
                byte_seq = new_seq
                if not merged:
                    break

            token_ids.extend(self.vocab[tok] for tok in byte_seq)

        return token_ids

    def decode(self, token_ids: list[int]) -> str:
        """Decode token IDs back into string."""
        byte_seq = []
        for tid in token_ids:
            tok_bytes = self.inv_vocab[tid]
            byte_seq.extend(tok_bytes if isinstance(tok_bytes, tuple) else [tok_bytes])

        # Convert list of ints back into bytes → decode utf-8
        return bytes(byte_seq).decode("utf-8", errors="ignore")
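
and a minimal end-to-end usage sketch (the file path, vocab size, and sample text are placeholders; any UTF-8 corpus with <|endofdoc|> separators between documents will do):

if __name__ == "__main__":  # required for multiprocessing on some platforms
    tokenizer = BPETokenizer()
    tokenizer.train("corpus.txt", vocab_size=1000, num_processes=4)

    ids = tokenizer.encode("hello tokenizer")
    print(ids)                    # the actual IDs depend on the training corpus
    print(tokenizer.decode(ids))  # round-trips back, for bytes seen during training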



byte pair encoding (BPE) is deceptively simple, yet incredibly powerful in how it builds compact token vocabularies. in our implementation, we walked through pair frequency calculation, iterative merging, and encoding/decoding, seeing firsthand how repeated subword merges emerge into meaningful tokens.

that said, there’s always room for optimization, especially in the merge step. right now, pair frequencies are recomputed from scratch each iteration, which isn’t the most efficient.

using techniques like an inverted index (mapping each pair to the sequences that contain it, so counts can be updated incrementally) or a priority queue to track the most frequent pairs can significantly speed things up for larger corpora, but that's for another blog, or you can try to implement it yourself!

thank you for reading; i hope you enjoyed the breakdown and picked up something new today. i’d love to hear your thoughts, suggestions, or any cool ideas for improvement.

also you can find the complete code here.