Question about Encoder Logic #87

Open

JackxTong opened this issue Jul 20, 2024 · 2 comments

Comments

JackxTong commented Jul 20, 2024

I noticed the encode() method has extra logic with a while loop to find the lowest merge index:

    def encode(self, text):
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255
        while len(ids) >= 2:
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged anymore
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

Can we simplify it like this:

    def encode(self, text):
        tokens = text.encode("utf-8") # raw bytes
        tokens = list(map(int, tokens)) # list of integers in range 0..255
        # apply every learned merge, in the order it was added during training
        for pair, index in self.merges.items():
            tokens = merge(tokens, pair, index)
        return tokens

Since merge() merges all occurrences, it seems a simple for loop suffices. Is there a reason for the more complex logic?
I trained my tokenizer and the BasicTokenizer on some text data and got exactly the same vocab and encoding.
Maybe I missed something. Could you clarify?

Thanks!

Update:
I added a pytest to my forked repo to show that my version is also correct, for anyone interested in trying it out.
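
The test is roughly along these lines (a minimal sketch, not the exact file in the fork; it assumes the BasicTokenizer class and the merge helper from minbpe.base, and the training/test strings are made up):

    import pytest
    from minbpe import BasicTokenizer
    from minbpe.base import merge

    def simple_encode(tokenizer, text):
        # proposed simplification: apply every learned merge in insertion order
        tokens = list(text.encode("utf-8"))
        for pair, idx in tokenizer.merges.items():
            tokens = merge(tokens, pair, idx)
        return tokens

    @pytest.mark.parametrize("text", ["", "a", "hello world", "the lazy dog"])
    def test_simple_encode_matches_original(text):
        tok = BasicTokenizer()
        tok.train("the quick brown fox jumps over the lazy dog " * 20, vocab_size=270)
        assert simple_encode(tok, text) == tok.encode(text)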

@202030481266

I think it's a little different, but the effect should be the same (as long as your Python version is at least 3.7). Your implementation always iterates over all of the merge items, whereas the original code can break out early. I think another reason the original was written this way was to guard against the dictionary not preserving the order items were added in; Karpathy seems to mention this in the video, but dict insertion order has been guaranteed since Python 3.7.
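
A tiny illustration of the ordering point (the pair values here are made up): dict insertion order is a language guarantee from Python 3.7 onward, so iterating self.merges.items() visits the merges in the order they were learned.

    merges = {}
    merges[(101, 32)] = 256   # first merge learned during training
    merges[(116, 104)] = 257  # second merge learned
    # on Python >= 3.7 iteration order matches insertion order
    assert list(merges.items()) == [((101, 32), 256), ((116, 104), 257)]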

@alexandermorgan

As @202030481266 mentioned, your simpler version iterates over all of the merges in the vocabulary. For a realistic tokenizer that is a lot of merges (~50k for GPT-2, ~200k for GPT-4o), so at practical scale your approach would do much more work, and most of the merges applied would not even occur in the chunk of text being processed.
But there is one slightly over-complex thing in Karpathy's code here. He calls get_stats inside encode but only uses the keys of the dictionary it returns. Since only the keys are needed, there's no point in computing the counts (which is the whole purpose of get_stats). So instead of get_stats(ids) it would be much less work to line up the consecutive pairs with zip(ids, ids[1:]). Even if the ids list is only one element long, that still works correctly without raising an out-of-range error.
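
A minimal sketch of that variant (same structure as the original encode, only the get_stats call swapped for zip; merge and self.merges are as in the repo):

    def encode(self, text):
        ids = list(text.encode("utf-8")) # raw bytes as ints in range 0..255
        while len(ids) >= 2:
            # consecutive pairs only; no counts needed since just the keys were used
            pair = min(zip(ids, ids[1:]), key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged anymore
            ids = merge(ids, pair, self.merges[pair])
        return ids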
