Implementation of Coconut in PyTorch, proposed in the paper Training Large Language Models to Reason in a Continuous Latent Space out of FAIR
Architecture-wise, the closest work to the one proposed here would be RMT, where the memory tokens could serve as the continuous latent tokens. Both directions are worth exploring.
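For intuition, the core mechanism can be sketched in a few lines: instead of decoding a discrete token at each reasoning step, the final hidden state is fed back into the sequence as a continuous "thought" embedding, and the answer is only decoded after the latent steps. The sketch below is illustrative only, written in plain PyTorch with assumed toy hyperparameters; it is not this package's internals (which also handle causal masking, multiple latents per step, and answer supervision).

```python
import torch
from torch import nn

class LatentReasoningSketch(nn.Module):
    # toy stand-in for the transformer; names and sizes here are assumptions, not this repo's API
    def __init__(self, num_tokens = 256, dim = 512, depth = 6, num_reasoning_steps = 3):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        layer = nn.TransformerEncoderLayer(dim, nhead = 8, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        self.num_reasoning_steps = num_reasoning_steps
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, prompt):
        x = self.embed(prompt)                       # (batch, seq, dim)

        for _ in range(self.num_reasoning_steps):
            hiddens = self.transformer(x)            # causal masking omitted for brevity
            latent = hiddens[:, -1:]                 # last hidden state becomes ...
            x = torch.cat((x, latent), dim = 1)      # ... the next continuous "thought" token

        return self.to_logits(self.transformer(x))   # decode only after the latent steps

sketch = LatentReasoningSketch()
logits = sketch(torch.randint(0, 256, (2, 128)))     # (2, 128 + 3, 256)
```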
```bash
$ pip install coconut-pytorch
```
```python
import torch
from coconut_pytorch import Coconut

model = Coconut(
    num_reasoning_steps = 3,
    num_latents_per_step = 1,
    transformer = dict(
        num_tokens = 256,
        dim = 512,
        depth = 6
    )
)

prompt = torch.randint(0, 256, (2, 1024))
answer = torch.randint(0, 256, (2, 64))

loss = model(prompt, answer)
loss.backward()

# after much training

answer = model.generate(prompt, max_length = 64) # (2, 64)
```
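Since the forward pass above returns a scalar loss, a standard PyTorch optimization loop is all that is needed around it. The loop below is a hedged sketch; the optimizer, learning rate, and random tensors are assumptions standing in for a real dataset, not anything prescribed by this repository.

```python
import torch
from torch.optim import Adam
from coconut_pytorch import Coconut

model = Coconut(
    num_reasoning_steps = 3,
    num_latents_per_step = 1,
    transformer = dict(num_tokens = 256, dim = 512, depth = 6)
)

optimizer = Adam(model.parameters(), lr = 3e-4)   # optimizer choice is an assumption

for _ in range(100):                              # toy number of steps
    prompt = torch.randint(0, 256, (2, 1024))     # stand-in for real prompt tokens
    answer = torch.randint(0, 256, (2, 64))       # stand-in for real answer tokens

    loss = model(prompt, answer)
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
```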
```bibtex
@inproceedings{Hao2024TrainingLL,
    title   = {Training Large Language Models to Reason in a Continuous Latent Space},
    author  = {Shibo Hao and Sainbayar Sukhbaatar and DiJia Su and Xian Li and Zhiting Hu and Jason Weston and Yuandong Tian},
    year    = {2024},
    url     = {https://api.semanticscholar.org/CorpusID:274610816}
}
```

```bibtex
@article{Burtsev2021MultiStreamT,
    title   = {Multi-Stream Transformers},
    author  = {Mikhail S. Burtsev and Anna Rumshisky},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2107.10342},
    url     = {https://api.semanticscholar.org/CorpusID:236171087}
}
```