The Llama 3 Herd of Models, Llama Team, AI@Meta, 2024
Okay, let’s continue reading the Llama 3.1 technical report. What happens after pre-training is post-training, which mainly consists of Reward Modeling, Supervised Finetuning (SFT), and Direct Preference Optimization (DPO). What’s nice nowadays is that these steps can be applied easily with Hugging Face’s libraries (see RewardTrainer, SFTTrainer, and DPOTrainer in TRL), which means we can understand how each technique works by inspecting its training data and inference formats.
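As a rough sketch of what each trainer consumes (the exact column names and arguments vary across TRL versions, so treat this as illustrative rather than canonical):

```python
# Illustrative only: column names follow recent TRL conventions and may differ across versions.
from datasets import Dataset

# RewardTrainer: pairs of preferred vs. dispreferred completions.
reward_data = Dataset.from_list([
    {"chosen": "Q: What is 2+2?\nA: 4.",
     "rejected": "Q: What is 2+2?\nA: 5."},
])

# SFTTrainer: plain text (or chat-formatted) examples, trained with the usual LM loss.
sft_data = Dataset.from_list([
    {"text": "Q: What is 2+2?\nA: 4."},
])

# DPOTrainer: the same preference triples, used to optimize the policy directly.
dpo_data = Dataset.from_list([
    {"prompt": "Q: What is 2+2?\nA:",
     "chosen": " 4.",
     "rejected": " 5."},
])
```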
A. Modeling with Reward Model, SFT, and DPO
- Background
- Reward Model: the training consists of feeding tuples of (prompt, chosen_answer, rejected_answer) into the model. After training, the model can give a numerical score for a (prompt, answer) pair indicating how good the answer is for the prompt (see the scoring sketch after this list).
- Supervised Finetuning (SFT): the training is similar to pre-training in that the objective is still text generation, but the inputs are now prompt-response text pairs rather than raw documents. The trained weights are often a small set of new weights (i.e. an adapter) on top of the pre-trained model. The purpose is usually to add specific capabilities (e.g. following instructions) to the LLM.
- Direct Preference Optimization (DPO): this is a framework to further train the SFT model directly on preference data (i.e. the same kind of tuples used to train the Reward Model), without needing an explicit reward model in the training loss.
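As a minimal sketch of the reward model’s scoring step (assuming the reward model is exposed as a sequence-classification head with a single logit, which is how TRL-style reward models are typically implemented; the checkpoint name is just a placeholder):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint name; any reward model with a single-logit
# sequence-classification head works the same way.
model_name = "my-org/my-reward-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

def reward_score(prompt: str, answer: str) -> float:
    """Return a scalar score for how good `answer` is for `prompt`."""
    inputs = tokenizer(prompt, answer, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, 1)
    return logits[0, 0].item()

# Higher score = more preferred answer, e.g. for ranking rejection-sampled outputs.
print(reward_score("What is 2+2?", "4"))
```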
- Pipeline Overview
The way the different approaches work together is as follows. First, the Reward Model is trained on human-annotated preference data. Then, it is used to select good synthetic data during the SFT training iterations (i.e. rejection sampling). Finally, DPO is applied to the SFT model to further improve it, and the whole loop is repeated over several rounds.
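A rough orchestration sketch of that loop (all helper functions here are hypothetical stand-ins for the actual training and generation jobs, not real library calls):

```python
# Hypothetical stand-ins; each would wrap a real training / generation step.
def train_reward_model(preference_data): ...
def generate_candidates(model, prompts, k): ...
def select_best_by_reward(reward_model, candidates): ...
def train_sft(model, best_responses): ...
def train_dpo(model, preference_data): ...

def post_train(base_model, prompts, human_preferences, num_rounds):
    # 1) Reward model trained on human-annotated preference data.
    reward_model = train_reward_model(human_preferences)
    model = base_model
    for _ in range(num_rounds):
        # 2) Rejection sampling: keep only the highest-reward generation per prompt.
        candidates = generate_candidates(model, prompts, k=30)
        best = select_best_by_reward(reward_model, candidates)
        # 3) SFT on curated + rejection-sampled data.
        model = train_sft(model, best)
        # 4) DPO as the final step of each round (in the report, preference data
        #    is also re-collected from the latest models between rounds).
        model = train_dpo(model, human_preferences)
    return model
```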
B. Post-training Data
The generation and curation of post-training data is perhaps the most important part of the whole pipeline. There are lots of analyses and insights for which you’d need to go over the full report, but I’ll try to summarize them below.
- Preference Data: to train the reward model, human preferences over various question-answer pairs are needed. Human annotators are tasked with ranking the outputs sampled from different model checkpoints. What sets this apart from other annotation processes is that annotators are also encouraged to edit the chosen response and provide the edited output as an even more preferred answer, thus enriching the preference dataset (a sketch of the resulting ranking tiers is below).
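The result is effectively a three-tier ranking (edited > chosen > rejected). One way to flatten such an annotation into pairwise preference examples could look like this (the field names and the expansion scheme are made up for illustration; the report doesn’t prescribe this exact format):

```python
# Hypothetical annotated record; field names are illustrative.
annotation = {
    "prompt": "Explain rejection sampling in one sentence.",
    "rejected": "It is a way to reject users' prompts.",
    "chosen": "It means generating several candidate answers and keeping the best one.",
    "edited": "It means generating several candidate answers and keeping only the one "
              "that a reward model (or a human annotator) scores highest.",
}

def to_preference_pairs(record):
    """Flatten the edited > chosen > rejected ranking into pairwise training examples."""
    pairs = [{"prompt": record["prompt"],
              "chosen": record["chosen"], "rejected": record["rejected"]}]
    if record.get("edited"):
        pairs.append({"prompt": record["prompt"],
                      "chosen": record["edited"], "rejected": record["chosen"]})
        pairs.append({"prompt": record["prompt"],
                      "chosen": record["edited"], "rejected": record["rejected"]})
    return pairs

print(len(to_preference_pairs(annotation)))  # 3 pairwise examples from one annotation
```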
- SFT Data
The SFT data comes from three sources.
- Prompts from human annotation with rejection-sampled responses: they first collect a set of prompts covering different capabilities, then generate 10-30 outputs per prompt from the latest best-performing model checkpoints in each round of finetuning. Among those outputs, the reward model selects the one with the highest score as the ground truth for finetuning in the next round. Those outputs are also cleaned with some heuristics to increase quality (a sketch of this selection step follows this list).
- Synthetic data targeting specific capabilities (see below)
- Human annotated data targeting specific capabilities (see below)
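A minimal sketch of the rejection-sampling step from the first bullet (the quality heuristics are placeholders, and `reward_score` is any scoring function such as the one sketched earlier):

```python
from typing import Callable, List, Optional

def passes_heuristics(text: str) -> bool:
    """Placeholder for the report's quality heuristics (e.g. drop empty or truncated outputs)."""
    return len(text.strip()) > 0

def rejection_sample(prompt: str,
                     candidates: List[str],
                     reward_score: Callable[[str, str], float]) -> Optional[str]:
    """Keep only the highest-reward candidate that survives the quality heuristics."""
    kept = [c for c in candidates if passes_heuristics(c)]
    if not kept:
        return None
    return max(kept, key=lambda c: reward_score(prompt, c))

# Usage: generate 10-30 candidates per prompt from the latest checkpoint,
# then keep the best one as an SFT target for the next round.
# best = rejection_sample(prompt, candidates, reward_score)
```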
- Capability-specific Data
A capable LLM needs to solve many domain-specific problems, so data targeting different areas of expertise is curated separately for finetuning.
- Code: To prepare a high-quality dataset for coding (Copilot-style) capability, they first train a code expert model (i.e. continue pre-training the pre-trained LLM on code-heavy data with a longer context window). Using it together with Llama 3, they generate 2.7M coding-related examples (e.g. asking Llama 3 to generate questions based on a code snippet and asking the code expert to answer them). These examples are filtered by running static analysis and unit tests on the solutions; only those that are correct, or could be self-corrected, are kept (a sketch of this execution-based filtering is below).
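A minimal sketch of that execution-based filter, assuming each generated example ships with its solution code and a runnable snippet of unit tests (sandboxing and the self-correction prompt are omitted):

```python
import ast
import subprocess
import sys
import tempfile

def passes_static_analysis(code: str) -> bool:
    """Cheapest possible static check: does the solution at least parse?"""
    try:
        ast.parse(code)
        return True
    except SyntaxError:
        return False

def passes_unit_tests(code: str, tests: str, timeout: int = 10) -> bool:
    """Run the solution together with its unit tests in a subprocess (no sandboxing here)."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(code + "\n\n" + tests)
        path = f.name
    try:
        result = subprocess.run([sys.executable, path], capture_output=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return False
    return result.returncode == 0

def keep_example(code: str, tests: str) -> bool:
    # Keep an example only if it parses and its unit tests pass
    # (the report additionally keeps examples that Llama 3 can self-correct).
    return passes_static_analysis(code) and passes_unit_tests(code, tests)
```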
- Multilinguality: the multilingual data composition is quite straightforward so I’d just copy the description from the report. “The overall distribution is 2.4% human annotations, 44.2% data from other NLP tasks, 18.8% rejection sampled data, and 34.6% translated reasoning data.”
- Math and Reasoning: the main challenge is the lack of such data in the form of conversations or chain-of-thought traces. Thus, they leverage Llama 3 to add step-by-step reasoning to math problems and filter out incorrect chains by translating them into code (i.e. Python) and discarding those that fail to execute, as well as with a reward model trained on step-level human annotations. The data is also augmented using self-correction from Llama 3.
- Long Context: the long context data comes from Llama 3 generated Question Answering (QA), summarization, and code reasoning examples.
- Tool Use: the tool-use dataset (covering a search engine, a Python interpreter, and Wolfram Alpha) is generated mainly with Llama 3. By providing the tool APIs to Llama 3, they generate examples that use a single tool, multiple tools in sequence, or the text of uploaded files. The dataset is again filtered based on whether the tool calls can actually be executed (a sketch of what such an example might look like is below).
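Just to make the data shape concrete, here is a purely illustrative single-tool example (the real Llama 3 tool-call format uses its own special tokens and schema, which the report describes separately):

```python
# Illustrative only: not the actual Llama 3 tool-call format.
tool_use_example = {
    "messages": [
        {"role": "system", "content": "You may call these tools: search, python, wolfram_alpha."},
        {"role": "user", "content": "What is the square root of 1764?"},
        {"role": "assistant", "tool_call": {"name": "python",
                                            "arguments": {"code": "print(1764 ** 0.5)"}}},
        {"role": "tool", "name": "python", "content": "42.0"},
        {"role": "assistant", "content": "The square root of 1764 is 42."},
    ]
}

# Execution-based filtering: keep the example only if the tool call actually runs
# (execute_tool is a hypothetical helper, not a real API).
# keep = execute_tool(tool_use_example["messages"][2]["tool_call"]) is not None
```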
- Fact Checking: this basically generates questions from snippets of pre-training data and uses Llama 3 as a judge to check whether the sampled answers are factually correct, so that only factual answers are selected.
- Steerability: to make sure the model can easily adapt its behavior based on user input, human annotators are asked to write prompts that steer the model and then evaluate how consistently the model follows them. The resulting preference data is used in DPO to improve the model’s steerability.