Skip to content

Commit

Permalink
Merge branch 'main' into add-gpt4-turbo
Browse files Browse the repository at this point in the history
  • Loading branch information
Plexcalibur authored Jul 18, 2024
2 parents f7d71e4 e0bab9a commit 152d90e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 10 deletions.
19 changes: 9 additions & 10 deletions lib/src/main/java/com/knuddels/jtokkit/GptBytePairEncoding.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.knuddels.jtokkit.api.EncodingResult;
import com.knuddels.jtokkit.api.GptBytePairEncodingParams;
import com.knuddels.jtokkit.api.IntArrayList;

import java.util.regex.Matcher;
import java.util.regex.Pattern;

Expand Down Expand Up @@ -44,7 +45,7 @@ public EncodingResult encode(String text, int maxTokenCount) {

private InternalResult encodeInternal(String text, int maxTokenCount, boolean keepEncodings) {
if (text == null) {
return new InternalResult(new IntArrayList(0), false);
return new InternalResult(new IntArrayList(0), -1, false, -1);
}

specialEncoder.checkForSpecialTokens(text);
Expand All @@ -64,7 +65,7 @@ public EncodingResult encodeOrdinary(String text, int maxTokenCount) {

private InternalResult encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings) {
if (text == null) {
return new InternalResult(new IntArrayList(0), false);
return new InternalResult(new IntArrayList(0), -1, false, -1);
}

IntArrayList out = new IntArrayList();
Expand All @@ -81,12 +82,12 @@ private InternalResult encodeOrdinaryInternal(String text, int maxTokenCount, bo
String decoded = decode(tokens);
if (text.startsWith(decoded)) {
// If decoded text is equal to the head of the original text, we can safely return the tokens
return new InternalResult(tokens, text.length() > decoded.length());
return new InternalResult(tokens, -1, text.length() > decoded.length(), decoded.length() - 1);
}
}
}

return new InternalResult(out, tokenCount, false);
return new InternalResult(out, tokenCount, false, text.length() - 1);
}

int encodeOrdinaryInternal(String text, int maxTokenCount, boolean keepEncodings, IntArrayList out) {
Expand Down Expand Up @@ -140,15 +141,13 @@ private static final class InternalResult {
private final IntArrayList tokens;
private final boolean truncated;
private final int tokenCount;
private final int lastProcessedCharacterIndex; // -1 == text was null or string was empty()

private InternalResult(IntArrayList tokens, boolean truncated) {
this(tokens, -1, truncated);
}

private InternalResult(IntArrayList tokens, int tokenCount, boolean truncated) {
private InternalResult(IntArrayList tokens, int tokenCount, boolean truncated, int lastProcessedCharacterIndex) {
this.tokens = tokens;
this.truncated = truncated;
this.tokenCount = tokenCount < 0 ? tokens.size() : tokenCount;
this.lastProcessedCharacterIndex = lastProcessedCharacterIndex;
}

private EncodingResult toEncodingResult() {
Expand All @@ -158,7 +157,7 @@ private EncodingResult toEncodingResult() {
);
}

return new EncodingResult(tokens, truncated);
return new EncodingResult(tokens, truncated, lastProcessedCharacterIndex);
}

private int toTokenCount() {
Expand Down
18 changes: 18 additions & 0 deletions lib/src/main/java/com/knuddels/jtokkit/api/EncodingResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
public final class EncodingResult {
private final IntArrayList tokens;
private final boolean truncated;
private final int lastProcessedCharacterIndex;

/**
 * Creates a result that carries no position information.
 * Delegates to the three-argument constructor with a last processed character index of -1.
 *
 * @param tokens    the token list held by this result
 * @param truncated the truncation flag held by this result
 */
public EncodingResult(final IntArrayList tokens, final boolean truncated) {
this(tokens, truncated, -1);
}

/**
 * Creates a result holding the given tokens, truncation flag and input position.
 *
 * @param tokens                      the token list held by this result
 * @param truncated                   the truncation flag held by this result
 * @param lastProcessedCharacterIndex index of the last input character that was processed,
 *                                    or -1 when no character was processed
 */
public EncodingResult(IntArrayList tokens, boolean truncated, int lastProcessedCharacterIndex) {
    this.lastProcessedCharacterIndex = lastProcessedCharacterIndex;
    this.truncated = truncated;
    this.tokens = tokens;
}

/**
Expand All @@ -30,11 +36,23 @@ public boolean isTruncated() {
return truncated;
}

/**
 * Reports how far into the input string the encoder got.
 *
 * @return the zero-based index of the last processed character of the input string;
 *         -1 if the text was null or empty, or if this result was created without an index
 */
public int getLastProcessedCharacterIndex() {
    return this.lastProcessedCharacterIndex;
}



@Override
public String toString() {
return "EncodingResult{"
"tokens=" tokens
", truncated=" truncated
", lastProcessedCharacterIndex=" lastProcessedCharacterIndex
'}';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package com.knuddels.jtokkit;

import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Checks the value reported by {@code getLastProcessedCharacterIndex()} on the result of
 * {@code encode(text, 10)} with the CL100K_BASE encoding, for null, empty, short and
 * over-long inputs.
 *
 * <p>All assertions use JUnit's {@code assertEquals(expected, actual)} argument order so that
 * failure messages report the expected and actual values correctly.
 */
public class EncodingLastProcessedCharacterIndexTest {

    private static final Encoding ENCODING = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);

    /** A null input processes no characters, so the index is -1. */
    @Test
    void testNullInput() {
        var encodingResult = ENCODING.encode(null, 10);
        assertEquals(-1, encodingResult.getLastProcessedCharacterIndex());
    }

    /** An empty input processes no characters, so the index is -1. */
    @Test
    void testEmptyInput() {
        String input = "";
        var encodingResult = ENCODING.encode(input, 10);
        assertEquals(-1, encodingResult.getLastProcessedCharacterIndex());
    }

    /** The 12-character input is expected to be processed completely, ending at index 11. */
    @Test
    void testShortInput() {
        String input = "Hello World!";
        var encodingResult = ENCODING.encode(input, 10);
        assertEquals(11, encodingResult.getLastProcessedCharacterIndex());
    }

    /**
     * The input exceeds the 10-token limit; the expected index 55 is the position of the
     * period that ends the first sentence.
     */
    @Test
    void testLongInput() {
        String input = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Fusce condimentum enim ac tellus malesuada, a consectetur nibh efficitur. 🚀🚀🚀";
        var encodingResult = ENCODING.encode(input, 10);
        assertEquals(55, encodingResult.getLastProcessedCharacterIndex());
    }
}

0 comments on commit 152d90e

Please sign in to comment.