📄 arXiv • 𝕏 Blog • 🌐 Web • ☁️ Google Drive (Data) • 🤗 HuggingFace (Model) • 🔭 ModelScope (Model) • 🧊 WiseModel (Model)
- Support LoRA training
- Code documentation
- Support vLLM inference
- Support distributed embedding
- Gradio
We introduce a One-pass Generation and retrieval framework (OneGen) for fine-tuning LLMs on generation, retrieval, or hybrid tasks. Our core idea is to integrate generation and retrieval within the same context by allocating the retrieval task to retrieval tokens generated in an autoregressive manner, thus enabling the LLM to perform both tasks in a single forward pass.
The following figure illustrates the training process. We first introduce the concept of *roles of tokens* in LLMs. A token $x_i$ can serve one of three roles:

- Generating the next token, denoted as $role(x_i)=\texttt{GEN}$.
- Providing context information, denoted as $role(x_i)=\texttt{CTX}$.
- Representing a sentence, denoted as $role(x_i)=\texttt{RET}$.

Hence, we apply the cross-entropy loss to tokens with role `GEN` and a contrastive loss to tokens with role `RET`.
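The role-dependent loss routing above can be sketched in a toy example. This is an illustration only, not the repository's implementation: the role names come from the text, while the tokens, vectors, and loss forms (standard cross-entropy and an InfoNCE-style contrastive loss) are invented for the sketch.

```python
import math

# Illustrative sketch only -- not OneGen's actual code. Roles GEN/CTX/RET
# come from the text above; the toy tokens and losses are invented.
tokens = ["Where", "is", "Paris", "[RET]", "?"]
roles = ["GEN", "GEN", "GEN", "RET", "GEN"]

def cross_entropy(probs, target):
    # negative log-likelihood of the target next token (used for GEN tokens)
    return -math.log(probs[target])

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def contrastive(anchor, positive, negatives, tau=0.05):
    # InfoNCE-style loss over cosine similarities (used for RET tokens):
    # pull the anchor embedding toward the positive passage, push it away
    # from the negatives
    logits = [cosine(anchor, positive) / tau]
    logits += [cosine(anchor, n) / tau for n in negatives]
    return -logits[0] + math.log(sum(math.exp(l) for l in logits))

# route each position to the loss matching its role
gen_positions = [i for i, r in enumerate(roles) if r == "GEN"]
ret_positions = [i for i, r in enumerate(roles) if r == "RET"]
```

Because every token already receives a hidden state during the forward pass, routing positions to different losses adds no extra passes; that is the property the single-pass framing relies on.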
The following figure illustrates the inference process of different methods for the RAG task. First, note that both GritLM and OneGen need to deploy only a single model, which lowers the deployment cost. However, GritLM achieves generation and retrieval within a single model by switching back and forth between causal attention and bidirectional attention. Additionally, both GritLM and the pipeline method require explicit queries, which leads to two extra forward passes for the queries. In contrast, OneGen performs retrieval during the generation process, avoiding those extra query passes and allowing direct reuse of the KV cache, significantly reducing inference costs.
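The cost difference can be caricatured with a toy forward-pass counter. This is an illustrative assumption based on the description above (one extra retriever pass per explicit query for the pipeline method), not a measurement of any of these systems.

```python
# Toy forward-pass accounting for one RAG turn -- an illustrative
# assumption derived from the text, not a benchmark.

def pipeline_passes(num_queries: int) -> int:
    # pipeline RAG: one generation pass, plus a separate retriever
    # encoding pass for every explicit query
    return 1 + num_queries

def onegen_passes(num_queries: int) -> int:
    # OneGen: retrieval embeddings are read off the same generation pass,
    # so no extra query passes are needed and the KV cache is reused
    return 1
```

Under this accounting, the saving grows with the number of retrieval calls per turn, which is why the gap matters most for multi-hop settings.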
```bash
git clone https://github.com/zjunlp/OneGen
cd OneGen
conda create -n onegen python=3.9 -y
conda activate onegen
pip install -r requirements.txt
```
The inference section focuses on running model predictions to get output results (Single-hop QA is an exception). The evaluation of these results is discussed in the Evaluation section.
Download `train_data.tar.gz` and `eval_data.tar.gz` from Google Drive. After extracting, you will get two folders: `train_data` and `eval_data`. Move these two folders into the `data` directory. Use the following commands to extract the files:

```bash
tar -xzvf train_data.tar.gz
tar -xzvf eval_data.tar.gz
```
Please note that the training data we use is also available on Hugging Face, so you do not need to download `train_data.tar.gz`; just run the training scripts!
Download the trained model (Optional)
The model weights trained on the three tasks have been made public and are available for download on three platforms: 🤗 HuggingFace, 🔭 ModelScope, and 🧊 WiseModel. For detailed information, please refer to the table below:
| 🎯 Task Name | 🤗 HuggingFace | 🔭 ModelScope | 🧊 WiseModel |
| --- | --- | --- | --- |
| Entity Linking | Llama2-7B | Llama2-7B | Llama2-7B |
| Single-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |
| Multi-hop QA | Llama2-7B | Llama2-7B | Llama2-7B |
**Note:** For the Entity Linking task, we have pre-stored the entity embeddings. Click here to download them.
Training model from scratch (Optional)
We provide the training scripts for three tasks. If you are using a locally downloaded model, you can modify the `info-model` field in the `workflow/{task}/{model}.json` file, updating `model_path` and `tokenizer_path` with the local paths. Note that the hyperparameters in the configuration files are set for 8×A800 GPUs. If you encounter OOM (out-of-memory) issues, please reduce `per_device_train_batch_size`, `n_pos_per_sent`, `n_neg_per_pos`, and `max_length`.
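For example, pointing a workflow file at local weights could look like the following. Only the field names (`info-model`, `model_path`, `tokenizer_path`) come from this README; the exact JSON nesting and the paths are assumptions for illustration.

```python
import json

# Hypothetical workflow-config layout: the README names an `info-model`
# section containing `model_path` and `tokenizer_path`; the nesting and
# example paths below are assumed, not taken from the repository.
workflow = {
    "info-model": {
        "model_path": "meta-llama/Llama-2-7b-hf",
        "tokenizer_path": "meta-llama/Llama-2-7b-hf",
    }
}

# point both fields at a locally downloaded checkpoint
local_dir = "/data/models/llama2-7b"
workflow["info-model"]["model_path"] = local_dir
workflow["info-model"]["tokenizer_path"] = local_dir

print(json.dumps(workflow, indent=2))
```

In practice you would load the real `workflow/{task}/{model}.json`, apply the same two assignments, and write it back.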
```bash
# Entity Linking
deepspeed train.py --workflow workflow/entity_linking/llama2.json
# Single-hop QA
deepspeed train.py --workflow workflow/self_rag/llama2.json
# Multi-hop QA
deepspeed train.py --workflow workflow/multi_hop_qa/llama2.json
```
Here are the inference scripts for the Entity Linking and Multi-hop QA tasks. The inference script for Single-hop QA is introduced in the next section. You can modify the values of fields such as `model_path`, `tokenizer_path`, `file`, and `output_file_path` in `config/eval_config/{task}/{config}.json` as needed.
```bash
# Entity Linking (Need GPU)
python eval.py --config config/eval_config/entity_linking/llama2_wo_pkl.json
# Multi-hop QA (Need GPU)
python eval.py --config config/eval_config/multi_hop_qa/llama2.json
```
Below are the evaluation scripts for the Entity Linking and Multi-hop QA tasks. `/your/path/to/result.jsonl` is the file saved during the inference stage.

```bash
# Entity Linking (CPU)
bash scripts/eval_el.sh el /your/path/to/result.jsonl
# Multi-hop QA for HotpotQA dataset (CPU)
bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl hotpotqa
# Multi-hop QA for 2WIKI dataset (CPU)
bash scripts/eval_multi_hop_qa.sh /your/path/to/result.jsonl 2wiki
```
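QA evaluation of this kind conventionally reports exact match and token-level F1. As a rough sketch of what such scripts typically compute (the repository's actual normalization and metrics may differ), the two metrics look like this:

```python
import re
from collections import Counter

# Standard QA metrics, shown as a generic sketch; the actual eval scripts
# in this repository may normalize answers differently.

def normalize(s: str) -> str:
    # lowercase, strip punctuation, collapse whitespace
    s = re.sub(r"[^a-z0-9 ]", " ", s.lower())
    return " ".join(s.split())

def exact_match(pred: str, gold: str) -> int:
    return int(normalize(pred) == normalize(gold))

def f1(pred: str, gold: str) -> float:
    p, g = normalize(pred).split(), normalize(gold).split()
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)
```

For example, a prediction of "barack obama" against gold "obama" scores 0 on exact match but 2/3 on F1, which is why both numbers are usually reported together.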
Here is the evaluation for the Single-Hop QA task, mainly based on Self-RAG:
```bash
# Single-hop QA using Self-RAG (Need GPU)
# [CUDA_VISIBLE_DEVICES] [MODE] [MODEL_PATH] [SAVE_TAG] [SAVED_DATASET_PATH] [N_DOC] [ENV] [SCORE]
bash scripts/eval_self_rag.sh 0 always_retrieve /your/path/to/model model_tag saved_rank_path 5 true true
```
If this work is helpful, please kindly cite as:
```bibtex
@misc{zhang2024onegen,
      title={OneGen: Efficient One-Pass Unified Generation and Retrieval for LLMs},
      author={Jintian Zhang and Cheng Peng and Mengshu Sun and Xiang Chen and Lei Liang and Zhiqiang Zhang and Jun Zhou and Huajun Chen and Ningyu Zhang},
      year={2024},
      eprint={2409.05152},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2409.05152},
}
```