The Llama 3 Herd of Models Part 1: Pre-training

23 Jul 2024

The Llama 3 Herd of Models, Llama Team, AI@Meta, 2024

The race to train Large Language Models (LLMs) has become a pipeline challenge rather than an architecture challenge, in my opinion (who doesn’t use a Transformer with tonnes of GPUs?). That is why I find the recent Llama 3.1 paper from Meta so interesting: it outlines in detail the efforts made by different workstreams to improve on other LLMs (as of July 2024). These efforts are not rocket science, but rather a set of carefully thought-out data management strategies. I have tried to summarize those strategies here, skipping the standard model training details and the multimodal parts for a more concise read. Even so, the content is overwhelming, so I decided to focus only on pre-training today.

A. Data for Pre-training

  1. Preprocessing Pre-training Web Data
    Text preprocessing plays a major role in improving the new Llama 3. The main purpose is to clean and categorize data for better training quality and a better data mix. In a nutshell, the recipe consists of:
    • A custom HTML parser (kept secret) that removes boilerplate while retaining the useful content.
    • A set of content de-duplication steps based on URLs, documents (via sketching), and sentences.
    • Heuristics to further remove low-quality text: (i) a duplicated n-gram coverage ratio to remove repeated content such as system logs; (ii) “dirty word” counting to filter out NSFW content; (iii) token-distribution KL divergence to filter out documents with outlier tokens (a sketch of these heuristics follows this list).
    • A model fine-tuned on clean web pages to generate a quality score for each page, which determines whether the page is kept.
    • Categorize web pages with a Llama 2 model fine-tuned on annotated web pages (again, the exact categories are part of the secret sauce). The important categories are “math” and “STEM reasoning”.
    • Categorize multilingual web pages with a fasttext-based model.
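
To make these heuristics concrete, below is a minimal Python sketch of the three filters (duplicated n-gram coverage, dirty-word counting, and token-distribution KL divergence). The thresholds, blocklist, and reference distribution are placeholders I made up; the paper does not disclose its exact values.

```python
import math
from collections import Counter

def duplicated_ngram_ratio(tokens, n=3):
    """Fraction of token positions covered by n-grams that appear more than
    once in the document; high values are typical of repeated boilerplate
    such as system logs."""
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    counts = Counter(ngrams)
    covered = set()
    for i, g in enumerate(ngrams):
        if counts[g] > 1:
            covered.update(range(i, i + n))
    return len(covered) / len(tokens)

def dirty_word_count(tokens, blocklist):
    """Number of tokens matching a (hypothetical) NSFW blocklist."""
    return sum(1 for t in tokens if t.lower() in blocklist)

def token_kl_divergence(tokens, reference_dist, vocab):
    """KL divergence between the document's token distribution and a
    reference corpus distribution; large values flag outlier documents."""
    counts = Counter(tokens)
    total = sum(counts.values())
    kl = 0.0
    for tok in vocab:
        p = counts.get(tok, 0) / total if total else 0.0
        q = reference_dist.get(tok, 1e-9)
        if p > 0:
            kl += p * math.log(p / q)
    return kl

def keep_document(tokens, blocklist, reference_dist, vocab,
                  max_dup=0.3, max_dirty=5, max_kl=2.0):
    """Illustrative thresholds only -- the real ones are not disclosed."""
    return (duplicated_ngram_ratio(tokens) <= max_dup
            and dirty_word_count(tokens, blocklist) <= max_dirty
            and token_kl_divergence(tokens, reference_dist, vocab) <= max_kl)
```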
  2. Determining Data Mix for Pre-training
    One important training aspect is determining the proportion of each category of web data supplied to pre-training. The steps include:
    • Downsample significantly over-represented categories (e.g. arts and entertainment) first.
    • Determine the scale of each category based on scaling laws: select several candidate data mixes, train small models on each, and use scaling laws to predict the performance of larger models (details on scaling laws are discussed below). After several rounds of trial and error on the data mix, the best mix found so far is used to train larger models, which are then evaluated.
    • Empirically, the data mix Llama 3.1 uses is roughly 50% general-knowledge tokens, 25% mathematical and reasoning tokens, 17% code tokens, and 8% multilingual tokens (a toy sampler built on this mix is sketched after this list).
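
As a toy illustration of how a fixed mix like this could be consumed by a data loader, the sketch below samples documents so that categories appear roughly in the stated proportions. The category names, corpora, and the document-level (rather than token-level) weighting are simplifications made up for illustration.

```python
import random

# Target proportions roughly as reported in the paper; a real pipeline
# specifies these in tokens, while this sketch weights whole documents.
DATA_MIX = {
    "general_knowledge": 0.50,
    "math_reasoning":    0.25,
    "code":              0.17,
    "multilingual":      0.08,
}

def sample_category(mix=DATA_MIX, rng=random):
    """Pick a category with probability proportional to its share."""
    categories, weights = zip(*mix.items())
    return rng.choices(categories, weights=weights, k=1)[0]

def build_batch(corpora, batch_size, rng=random):
    """corpora: dict mapping category -> list of documents (placeholder data).
    Draw one document at a time from a randomly chosen category."""
    return [rng.choice(corpora[sample_category(rng=rng)])
            for _ in range(batch_size)]
```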
  3. Data Annealing for Pre-training
    The order in which data are used during pre-training matters because of the annealing nature of the training process. Basically, the harder data (i.e. math, reasoning, and other high-quality data) are introduced later in training, when the learning rate is smaller (a toy schedule is sketched below).
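
As a minimal sketch of what such data ordering could look like in code: the mix switches to a heavier share of math/reasoning and other high-quality data in a late, low-learning-rate phase of training. The breakpoint and the proportions below are invented for illustration; the paper does not publish the exact schedule.

```python
def annealed_mix(progress):
    """progress: fraction of training completed, in [0, 1].
    Returns an illustrative category mix; the real schedule is not published."""
    if progress < 0.9:
        # Bulk of training: mostly general web data.
        return {"general": 0.60, "math_reasoning": 0.15,
                "code": 0.17, "multilingual": 0.08}
    # Final annealing phase: upweight hard, high-quality data while the
    # learning rate is small.
    return {"general": 0.30, "math_reasoning": 0.40,
            "code": 0.22, "multilingual": 0.08}
```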

B. Architecture for Pre-training
The architecture is pretty standard (basically a Transformer), and I skipped the parts related to hardware. But even so, there are some interesting things to learn.

  1. Determining the Right Scale
    One interesting point is that, given a compute budget, scaling laws can be used to figure out the right number of training tokens and, from that, the right number of model parameters. In a nutshell:
    • It first establishes the correlation between a model’s negative log-likelihood (NLL) on downstream tasks and training FLOPs. Concretely, for a range of small compute budgets, small models are pre-trained with different numbers of tokens and the NLL is measured for each combination. For each compute budget, the optimal number of training tokens is found by fitting a second-degree polynomial to these results and taking its minimum. With these optima, a straight line (in log-log space) can be fitted to predict how many tokens should be fed to pre-training at a much larger compute budget (a toy fit is sketched after this list). Given their compute budget of \(3.8 \times 10^{25}\) FLOPs, the derived number of tokens is 16.55T. I guess they assume each parameter can store about 40 tokens (the literature suggests 1.7-20), thus the final model size is 405B parameters.
    • The scaling law can also be used to justify the compute budget in terms of actual task accuracies. Since we can collect observations pairing NLL with each task’s accuracy from the small models and from previous Llama 2 models, we can fit a simple curve between NLL and task accuracy to argue that, e.g., (i) given \(3.8 \times 10^{25}\) FLOPs we can reach an NLL of 1.2; (ii) based on previous observations, an NLL of 1.2 implies an accuracy of 0.9 on task A.
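
To make the two-step fit above concrete, here is a minimal sketch (assuming NumPy) with entirely invented IsoFLOPs-style measurements: for each small compute budget, a second-degree polynomial over log(tokens) gives the NLL-minimizing token count; those optima are then fitted with a straight line in log-log space and extrapolated to \(3.8 \times 10^{25}\) FLOPs.

```python
import numpy as np

# Invented measurements: for each small compute budget (FLOPs), the NLL
# obtained when training with different token counts. Real values are not
# published at this granularity; these are for illustration only.
isoflops = {
    1e20: ([6e9,    1.8e10, 5.4e10], [2.38, 2.34, 2.38]),
    1e21: ([2e10,   6e10,   1.8e11], [2.12, 2.08, 2.12]),
    1e22: ([6.7e10, 2e11,   6e11],   [1.90, 1.86, 1.90]),
}

log_budgets, log_opt_tokens = [], []
for budget, (tokens, nll) in isoflops.items():
    # Fit a parabola NLL(log10 tokens) and take its vertex as the optimum.
    a, b, c = np.polyfit(np.log10(tokens), nll, deg=2)
    log_budgets.append(np.log10(budget))
    log_opt_tokens.append(-b / (2 * a))

# A power law D*(C) = A * C^alpha is a straight line in log-log space.
alpha, logA = np.polyfit(log_budgets, log_opt_tokens, deg=1)

target_flops = 3.8e25
predicted_tokens = 10 ** (alpha * np.log10(target_flops) + logA)
# With these made-up points the extrapolation lands around 1.5e13 tokens,
# the same ballpark as the paper's 16.55T figure.
print(f"Predicted optimal token count at 3.8e25 FLOPs: {predicted_tokens:.2e}")
```

The same idea of fitting on small models and extrapolating is what the second bullet applies to the relationship between NLL and task accuracy.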
  2. Training Recipe
    The training recipe is pretty standard, but a few points are noteworthy:
    • Initial pre-training: first train with a small batch size and a high learning rate to stabilize the model, then a larger batch size and a smaller learning rate for efficiency. The data mix also changes across these stages (an illustrative schedule is sketched after this list).
    • Long-context pre-training: long-context data are trained on at a later stage because the compute requirement is higher. The context length is increased across six stages, and at each stage they carefully check whether performance on short-context inputs has recovered and whether tasks at the current context length are solved.
    • Annealing: just like the previous discussion, harder data are fed later in training. Also, the hard, high-quality data are upsampled in this final phase so that the model learns them better.
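
To tie the three bullets together, here is a minimal sketch of a warm-up plus cosine-decay learning-rate schedule with a short final annealing segment at a very small learning rate. All constants (total steps, peak and minimum learning rates, where annealing starts) are placeholders, not the values used for Llama 3.

```python
import math

def lr_schedule(step, total_steps=1_200_000, warmup_steps=8_000,
                peak_lr=8e-4, min_lr=8e-5, anneal_start_frac=0.99):
    """Illustrative warm-up + cosine decay + final annealing schedule.
    The constants are placeholders, not the exact values used for Llama 3."""
    if step < warmup_steps:
        # Linear warm-up from 0 to the peak learning rate.
        return peak_lr * step / warmup_steps
    anneal_start = anneal_start_frac * total_steps
    if step >= anneal_start:
        # Final annealing phase: learning rate decays linearly towards 0
        # while the high-quality (math/code) data are upsampled.
        return min_lr * (1 - (step - anneal_start)
                         / (total_steps - anneal_start))
    # Cosine decay from peak_lr down to min_lr for the bulk of training.
    progress = (step - warmup_steps) / (anneal_start - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))
```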