male-1: Hello and welcome to Byte-Sized Breakthroughs, the podcast that dives deep into the latest research papers, serving you in-depth understanding without the need for an advanced degree...or even reading the paper yourself! I'm your host, Alex Askwell. Today, we're tackling "Distillation Scaling Laws," a fascinating study that aims to make the training of smaller, more efficient language models more predictable and cost-effective. With me today, we have Dr. Paige Turner, the lead researcher on this work, and Prof. Wyd Spectrum, an expert in the field of large language models. Welcome to both of you! female-1: Thanks for having me, Alex. Happy to be here. female-2: It's a pleasure to join you, Alex. male-1: Dr. Turner, could you give us a brief overview of your paper? What problem were you trying to solve, and what was your general approach? female-1: Certainly. Our primary goal was to address the challenge of efficiently creating smaller, more manageable language models through a process called knowledge distillation. These smaller models are easier to deploy, consume less energy, and are cheaper to run. Distillation involves training a smaller “student” model to mimic the behavior of a larger, more capable “teacher” model. However, figuring out how best to allocate compute resources to both the teacher and the student during this process has always been something of a black art. We wanted to make the process more predictable and effective. male-1: And what was your approach? female-1: We conducted a large-scale, controlled study, varying the sizes of both student and teacher models, as well as the amount of data used for training and distillation. This generated a lot of empirical data, from which we then derived a 'distillation scaling law.' This law basically allows us to estimate the student's performance based on factors like the teacher's performance, the student's size, and the amount of data used for distillation. male-1: Prof. 
Spectrum, could you provide some context for our listeners? Why is this research important, and how does it fit into the broader landscape of large language model development? female-2: Absolutely, Alex. Dr. Turner's work is situated at the intersection of several important trends in the field. First, there's the ongoing push for scaling laws, which aim to establish predictable relationships between model performance and factors like model size, training data, and compute. Landmark works from Kaplan et al. and Hoffmann et al. demonstrated these scaling laws in supervised learning. However, as models get larger, inference costs become a major concern. Over a model's lifetime, the cost of using it can end up being much larger than the cost of training it. This has led to interest in overtraining and distillation, with both approaches showing benefits. This paper provides a pathway to understanding where distillation helps most, as well as a framework for optimizing model design when you're working with limited compute. male-1: So, it's about striking a balance between model capability and practicality? female-2: Precisely. It's about finding the most efficient way to achieve a desired level of performance, while also considering factors like inference costs and accessibility. Distillation offers one route, but it needs to be done right, which is where Dr. Turner's research comes in. male-1: Dr. Turner, what would you say are the key contributions and innovations of your paper? female-1: I would point to a few things. The main innovation is the distillation scaling law itself, which is expressed in Equation 8 of the paper. This equation allows you to plug in numbers for the teacher and student, and get a relatively accurate prediction of the performance of the student model, in terms of its cross-entropy loss. This is something that really hasn't existed before in a rigorous, quantifiable way. male-1: And what’s so important about the cross-entropy loss? female-1: That’s a great question. 
The cross-entropy loss, in our context, is a measure of how well a language model predicts the next word in a sequence. Lower cross-entropy loss means the model makes better predictions, and in practice, this often translates to better performance on downstream tasks. It is a natural evaluation metric because it is the same differentiable objective that the models optimize during training. male-1: Excellent, please proceed with the other key contributions. female-1: The second key contribution is identifying compute-optimal distillation recipes. We provide concrete guidance on how to allocate compute resources between the teacher and student models to achieve the best possible student performance. This includes recipes for scenarios where a teacher model already exists and for those where you need to train one from scratch. male-1: What do you mean, "recipes?" female-1: In this context, we mean that we can provide direction such as “If you want to produce a model of a given size and you only have this amount of compute, then allocate this much compute to training the teacher, allocate this many tokens for the student distillation, and pick this size of model.” If you follow the recipe, you should get the best student performance that the scaling law predicts for that budget. male-1: And the final key contribution? female-1: Our refined understanding of the 'capacity gap.' The capacity gap is the phenomenon where, somewhat counter-intuitively, simply making the teacher model bigger *doesn't* necessarily lead to a better student model. In fact, it can sometimes make the student *worse*. Our work shows that this isn't just about the size of the teacher, but also its 'learning capacity,' which is determined by how well trained the teacher is and what data it has had access to, and which shows up directly in its cross-entropy loss. male-1: That's fascinating. Let's dive deeper into the methodology. 
The paper mentions three experimental protocols: 'Fixed M Teachers/Student IsoFLOPs,' 'IsoFLOP Teachers/Fixed M Students,' and 'Fixed M Teachers/Fixed M Students.' Could you break down what each of those means? female-1: Sure. These are basically different ways we structured our experiments to collect data for fitting the distillation scaling law. male-1: What do you mean by "isoFLOP"? female-1: “IsoFLOP” refers to experiments where we kept the total computational cost (measured in FLOPs, or floating-point operations) constant, while varying other factors like model size and training data. So rather than picking a fixed value for model size and running the experiment, we fixed the number of FLOPs that we would spend, and let model size and token count trade off against each other within that budget. It is a bit like saying “I only want to spend one thousand dollars on hardware and electricity, how can I best allocate that across training?”, then running the experiment to observe. male-1: What does each of those Fixed M experiments mean? female-1: Okay, in the "Fixed M Teachers/Student IsoFLOPs" experiments, we started with several 'teacher' models, each trained to a compute-optimal 'M-ratio' (tokens per parameter). Remember that M is the amount of training data divided by the number of parameters in the model. For example, if M is 20, it means the model trains on 20 tokens for each parameter that it has. We then distilled those teachers into student models at a couple of IsoFLOP profiles. male-1: Why a compute-optimal “M-ratio?” female-1: Previous research has shown there is typically an ideal number of tokens to train a model of a particular size on in the supervised setting. We wanted the teachers to be high performing, and so we chose that as a reasonable starting point. male-1: And "IsoFLOP Teachers/Fixed M Students?" 
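As a rough illustration of the M-ratio and IsoFLOP ideas described above, here is a minimal Python sketch. It assumes the common rule of thumb that training cost is about 6 × parameters × tokens; that approximation, the function names, and the example numbers are assumptions for illustration and are not taken from the paper.

```python
import math

def tokens_for_m_ratio(n_params: float, m_ratio: float = 20.0) -> float:
    """Tokens D such that D / N equals the M-ratio (tokens per parameter)."""
    return m_ratio * n_params

def training_flops(n_params: float, n_tokens: float) -> float:
    """Rough training cost, ASSUMING the C ~ 6 * N * D rule of thumb."""
    return 6.0 * n_params * n_tokens

def isoflop_profile(flop_budget: float, model_sizes: list) -> list:
    """For one fixed FLOP budget, return (N, D, M) for each candidate size N.

    Larger models get fewer tokens, so every row spends the same compute.
    """
    profile = []
    for n_params in model_sizes:
        n_tokens = flop_budget / (6.0 * n_params)
        profile.append((n_params, n_tokens, n_tokens / n_params))
    return profile

if __name__ == "__main__":
    # Hypothetical budget: what a 1B-parameter model at M = 20 would cost.
    budget = training_flops(1e9, tokens_for_m_ratio(1e9, 20.0))
    for n, d, m in isoflop_profile(budget, [0.5e9, 1e9, 2e9]):
        print(f"N={n:.2e}  D={d:.2e}  M={m:.1f}")
```

Each row of the profile costs the same number of FLOPs under the stated approximation; only the split between model size and tokens changes, which is the sense in which an IsoFLOP experiment holds compute constant.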
female-1: In that setting we fixed the student sizes and tokens per parameter and then saw how changing the compute profile and training size for the teacher model affected the student model's performance. We are distilling a bunch of teachers into the *same* students. male-1: And finally, Fixed M Teachers/Fixed M Students? female-1: That was simply a set of experiments that varied each side independently to show a broader set of combinations. It is what allowed us to notice how the capacity gap shifted between the teacher and student. This type of experiment was most useful for a human drawing conclusions, not for fitting the scaling law. male-1: Thanks for clarifying! Once you had this experimental data, how did you actually fit the Distillation Scaling Law? What variables did you need to find values for? female-1: Essentially, we had a lot of data points, each representing a different distillation experiment with specific teacher and student configurations. Each experiment gives a data point that is used to fit the various coefficients of the Distillation Scaling Law equation. For each one, we had experimental observations of the student and teacher cross-entropies, the sizes of the teacher and student, and the number of distillation tokens for the student. We then used an iterative process, implemented with a numerical optimizer, that tried combinations of coefficients (e.g. A’, B’, alpha, beta, etc.) from our equations and searched for the set that best fit the different experiments that we ran. male-1: Can you describe the Scaling Law functional form more specifically? Break down the different terms of the key equation so our listeners can understand? female-1: Our scaling law, which is Equation 8, describes the student model cross-entropy L_S as L_S = L_T + (1 / (L_T^(c_0))) * (1 + (L_T / (L~(S) * d_1))^(1/f_1))^(-c_1 * f_1) * (A' * N_S^(α') + B' * D_S^(β'))^(γ'). Let's unpack it term by term. male-1: Please do! 
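To make the spoken equation easier to follow, here is a minimal Python sketch that evaluates it exactly as read aloud above. The function name, the dictionary keys, and every coefficient value are hypothetical placeholders chosen for illustration; they are not the fitted values from the paper.

```python
def distillation_student_loss(l_t, l_s_tilde, n_s, d_s, coef):
    """Student cross-entropy L_S, following the equation as read in the episode.

    l_t        teacher cross-entropy L_T
    l_s_tilde  the reference student cross-entropy written L~(S) in the transcript
    n_s, d_s   student parameter count N_S and distillation tokens D_S
    coef       dict of coefficients (c0, c1, d1, f1, A, B, alpha, beta, gamma)
    """
    # 1 / L_T^(c_0): scales the gap according to how good the teacher is.
    teacher_term = 1.0 / (l_t ** coef["c0"])
    # (1 + (L_T / (L~(S) * d_1))^(1/f_1))^(-c_1 * f_1): the capacity-gap term.
    ratio = l_t / (l_s_tilde * coef["d1"])
    gap_term = (1.0 + ratio ** (1.0 / coef["f1"])) ** (-coef["c1"] * coef["f1"])
    # (A' * N_S^(alpha') + B' * D_S^(beta'))^(gamma'): power law in size and tokens.
    # With the form as transcribed, alpha' and beta' must be negative for this
    # term to shrink as the student gets bigger or sees more tokens.
    power_term = (coef["A"] * n_s ** coef["alpha"]
                  + coef["B"] * d_s ** coef["beta"]) ** coef["gamma"]
    return l_t + teacher_term * gap_term * power_term

# PLACEHOLDER coefficients, for illustration only (not from the paper).
EXAMPLE_COEF = {"c0": 1.0, "c1": 1.0, "d1": 1.0, "f1": 1.0,
                "A": 1.0, "B": 1.0, "alpha": -1.0, "beta": -1.0, "gamma": 1.0}

if __name__ == "__main__":
    print(distillation_student_loss(2.0, 2.0, 4.0, 4.0, EXAMPLE_COEF))
```

Because the last three factors are all positive, the predicted student loss in this form always sits above the teacher loss L_T, and the product of the three factors is the size of that gap.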
female-1: The first term, L_T, is the teacher’s cross-entropy loss. The second large term adds a correction on top of that teacher cross-entropy that depends on the fitted coefficients and on the size and data of the student. male-1: Could you go into more detail on this modifier? female-1: The modifier is made up of three parts. The first is the 1 / (L_T^(c_0)) term, where c_0 is a fitted coefficient. This term says that in general a higher performing teacher has a beneficial effect on the student that decays at a polynomial rate, where that rate of decay is governed by our value for c_0. The second part, (1 + (L_T / (L~(S) * d_1))^(1/f_1))^(-c_1 * f_1), accounts for the fact that the student can underperform when the teacher is too strong. The L~(S) represents the cross-entropy the student would reach if it were simply trained, without a teacher, on the same data that it used for distillation. This accounts for the cases where, due to the capacity gap, the student may not be able to model the entire capability of the teacher and ends up overfitting. f_1 and c_1 are other fitted coefficients, and d_1 scales the ratio between L_T and L~(S). The last part, (A' * N_S^(α') + B' * D_S^(β'))^(γ'), represents the ability of a student to reduce its cross-entropy based on its size and tokens, and it matches the standard intuition of models following a power law in those factors. male-1: Great, thank you! I would like to understand how your team's results compare against other key findings in this space. For example, how did your results compare to Beyer's 'patient teacher', Zhang's Capacity Gap and other work from the field? female-1: One of the interesting findings was that, in our results, distillation became less effective than supervised learning in the ‘patient’ setting. Our setup is slightly different from that of Beyer et al., since they used a vision dataset, and we are primarily working with a language dataset. 
In particular, they find that patience is ideal when there is consistency between the teacher and the student. Our setup also holds that invariant by always providing the same data to both student and teacher. The different result may be due to differences between computer vision architectures and ours. male-1: Interesting! What about other differences between your approach and the other existing models? female-1: With regard to Zhang et al.'s work on the Capacity Gap, those results were observed with models under specific conditions: the teachers have to be compute-optimal from a supervised perspective, and the experiments cover a particular range of scales. That may not always be the case for smaller models with a weaker performance profile. Our method is more general in that it does not make any assumptions about the data profile, and is not limited to a specific range. male-1: Let's talk about the nuts and bolts of your experiments. What datasets, model architectures, and hardware did you use? female-1: For all our experiments, we used the English-only subset of the C4 dataset. It's a commonly used dataset for pretraining language models, which facilitates reproducibility and comparison. Our models were based on the architecture described in Gunter et al. We use a simplified version of µP (Maximal Update Parameterization), which greatly simplified hyperparameter tuning. Finally, we always use a fixed aspect ratio to simplify our calculations. Since our models have a fixed aspect ratio, we know that the hidden size is always equal to the layer count times 128. male-1: What does that mean, "fixed aspect ratio?" female-1: The aspect ratio refers to the relationship between the model dimension and the number of layers. If one doubles, then the other doubles. As the total model size increases, each layer also gets wider. It helps us approximate the number of non-embedding parameters and the total compute that go into the scaling laws. 
Using this fixed ratio ends up improving performance, making the model size relate more directly to the work being done by the model. male-1: The paper also mentions "weak-to-strong generalization." What does that mean, and how did it manifest in your results? female-1: In general, weak-to-strong generalization means that the student ends up with better performance than the teacher it learned from. Specifically, weak-to-strong generalization occurred only in the finite distillation data regime, meaning that this phenomenon was only active so long as we had not yet saturated training, and disappeared at infinitely large training runs. In our setting, once the student was given enough tokens, we could observe that the student cross-entropy increases again and eventually matches the teacher cross-entropy. male-1: Let’s dive into the compute optimal setting. What does the setting represent, and what were the key insights? female-1: That setting represents us taking the scaling laws and trying to actually optimize the whole design for specific circumstances. First, we laid out four cases, depending on which teacher costs are counted, and looked for the lowest student cross-entropy in each: the best case, teacher inference, teacher pre-training, and teacher pre-training plus inference. male-1: What does “best case” mean in this context? female-1: The best-case scenario assumes the teacher adds no FLOPs to the budget, which gives us freedom to choose the teacher. It’s like the teacher is already there and fully amortized, and we are completely unconstrained. However, it is not fully realistic because, in the real world, there are additional computing costs associated with running the teacher to produce logits. What ends up happening in this case is that all of the compute goes to the student. male-1: And in the case where we take teacher inference into account? female-1: In that case, we start to consider the number of teacher FLOPs that are needed as we infer. 
We assume here that the teacher is always available, but we are paying to use the teacher and so we account for the total amount of compute. What we found here was that the budget can no longer go to the student alone, because it also has to cover computing the teacher signals that the student learns from. As more compute becomes available, the optimal teacher model gets larger. With still more compute, the optimum shifts toward an overtrained teacher model. male-1: Finally, can you address the cases where teacher pre-training and teacher pre-training with inference are included? female-1: The best student cross-entropy is always higher in this condition than it is in the supervised setting. One thing to consider here is the purpose of the teacher. Does it have uses beyond simply assisting in the production of a student model? If the only intent is to distill into a single student model, it is better to go straight to supervised learning than to take on the burden of training a teacher model. male-1: Let’s switch gears a bit. What are the limitations of your work, and what future research directions do you see as most promising? female-1: One of the key limitations is our focus on the language modeling setting using a specific dataset. While we believe our findings are generalizable, we can't be certain that distillation behaves the same way in all domains. There is also the potential for bias, with the student inheriting bias from the teacher model, even if the student has been distilled on data that is unbiased. male-1: What should be researched to follow up on this work? female-1: There are a number of different pieces that can be considered in the future. The first is to verify the scaling properties of sequence-level knowledge distillation in a controlled manner, similar to our study. We are also interested in combining distillation with supervised training, especially in discovering where we should transition from one to the other. 
Furthermore, there are many other languages and domains that have not been assessed yet and that should also be added. Finally, more efficient methods for distribution truncation need to be tested to reduce the inference cost without harming performance. male-1: Professor Spectrum, zooming out a bit, what do you see as the broader impact of this research, and what potential applications might arise from it? female-2: This research has the potential to significantly impact the field of AI in several ways. First, by providing compute-optimal recipes for distillation, it enables the creation of smaller, more efficient language models. This, in turn, reduces inference costs, making these models more accessible and deployable in resource-constrained environments. Second, smaller models also have a smaller carbon footprint, which is an increasingly important consideration. Third, making models smaller and more efficient will contribute to the democratization of AI, allowing people with a greater range of perspectives to study these models. However, there are also potential downsides. Smaller, more powerful language models could be used by bad actors to generate targeted misinformation more efficiently. So, there’s a dual-use aspect to consider. male-1: What kind of applications might we see? female-2: I'm really excited about the potential for these techniques to let smaller, local, and more democratically driven organizations run powerful inference. It is going to help edge devices run state-of-the-art models directly on phones, or at individual hospitals or clinics, for example. Further, while a bad actor could generate misinformation at a lower cost thanks to small, efficient models, the same efficiency would also decrease the amount of resources required to build a small, powerful, and safe model. male-1: Well, we are just about out of time. Dr. Turner, if you had to summarize the key takeaways for our listeners, what would they be? 
female-1: The distillation scaling law is a valuable tool for resource allocation, enabling practitioners to make better-informed decisions about which models to use and how to use them. However, it is important to account for both the compute and the data needed. If you do not have a well-designed plan for using the teacher model beyond a single distillation, rely on supervised learning instead; otherwise the teacher adds costs that require more computing power than if you had just used supervised learning to begin with. male-1: Dr. Turner, Prof. Spectrum, thank you both for your time and insights! And thank you, listeners, for joining us on this episode of Byte-Sized Breakthroughs. We'll be back soon with another deep dive into the world of AI research.