How to fine-tune your embeddings for better similarity search
This blog post will share our experience with fine-tuning sentence embeddings on a commonly available dataset using similarity learning. We additionally explore how this could benefit the labeling workflow in the Kern AI refinery. To understand this post, you should know what embeddings are and how they are generated. A rough idea of what fine-tuning is also helps. All the code and data referenced in this post is available on GitHub.
We are constantly looking to improve our kern refinery, in which labeling plays a central role. There are several ways we can leverage embeddings to enhance the labeling process. One tool we have already implemented is similarity search: you select any record and retrieve similar records based on the cosine similarity of their embeddings.
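To make the similarity search concrete, here is a minimal sketch in Python with NumPy of a cosine-similarity lookup. It assumes the embeddings are already stacked into one array with one row per record. The function name `most_similar` is hypothetical and not part of kern refinery.

```python
import numpy as np


def most_similar(embeddings: np.ndarray, query_idx: int, n: int = 1000) -> np.ndarray:
    """Return the indices of the n records most similar to record `query_idx`.

    Records are ranked by cosine similarity (most similar first); the query
    record itself is excluded. Assumes n >= 1.
    """
    # L2-normalize every embedding so that a dot product equals cosine similarity
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf  # never return the query itself
    n = min(n, len(embeddings) - 1)
    # argpartition finds the top n without a full sort; only those n get sorted
    top = np.argpartition(-sims, n - 1)[:n]
    return top[np.argsort(-sims[top])]
```

Using `argpartition` keeps the lookup cheap when only the top 1000 out of many thousands of records are needed.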
This can be combined with a custom “labeling session”, which is simply the selection of records you are presented with during manual labeling. For example, you can gather the 1000 most similar records with similarity search and start labeling them manually. We found that the labeling experience gets much smoother if there are fewer context switches within one labeling session. Therefore, the goal of fine-tuning our embeddings is to get more records of the same class into a similarity-based labeling session.
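One simple way to quantify what a similarity-based labeling session feels like is sketched below. It assumes (our own interpretation, not a definition from kern refinery) that a context switch happens whenever two consecutive records carry different labels. It also computes the share of session records that match the query's class, which is what the fine-tuning aims to increase. Both helper names are hypothetical.

```python
def context_switches(session_labels: list[str]) -> int:
    """Count how often the label changes between consecutive records in a session."""
    return sum(1 for prev, cur in zip(session_labels, session_labels[1:]) if prev != cur)


def same_class_share(query_label: str, session_labels: list[str]) -> float:
    """Fraction of session records that carry the same label as the query record."""
    if not session_labels:
        return 0.0
    return sum(label == query_label for label in session_labels) / len(session_labels)


session = ["Sports", "Sports", "World", "Sports", "Business"]
print(context_switches(session))            # 3
print(same_class_share("Sports", session))  # 0.6
```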
Large language models (LLMs) solve a wide variety of tasks such as question answering, information extraction, and sentiment analysis. What makes them so good at these tasks is a combination of the right architecture, a well-designed training procedure, and the availability of the whole internet as training data. For example, “LaMDA”, a more recent LLM from Google, was trained on 1.56 trillion words from public forums, tutorials, Wikipedia, web documents, and other sources. With these vast amounts of data, an LLM learns to generalize across many domains, which results in a model that is generally very good but lacks domain-specific expertise. This is where fine-tuning comes into play. Fine-tuning is the process of adjusting your language model to better fit the domain of your data. If, for example, you want to process a lot of legal documents about the construction of an offshore wind farm, you might want to specialize your LLM on these kinds of texts. Before fine-tuning a model yourself, though, always check the Hugging Face model hub to see whether someone has already fine-tuned a model on data similar to yours.
The last prerequisite we want to look at before diving into the experiment is “similarity learning”. To fine-tune embeddings, we need a task to solve. This task could be anything from supervised classification to unsupervised masked token prediction. Since we want better similarity search for our labeling sessions, we opt for a task that incorporates class information. After some internal discussion we settled on similarity learning, because it is easy to set up, very fast to train, and simply something new we wanted to try out. In our case, similarity is defined by the class labels: two records are similar if they carry the same class label and dissimilar if they do not.
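Since similarity is defined purely by class labels, every labeled record implicitly takes part in many similar and dissimilar pairs. The sketch below, with hypothetical helper names, maps class labels to integer group ids (one group per class) and counts the pairs this definition yields. It illustrates the idea only and is not Quaterion code.

```python
def to_groups(labels: list[str]) -> tuple[list[int], dict[str, int]]:
    """Map every class label to an integer group id (one group per class)."""
    group_of = {label: i for i, label in enumerate(sorted(set(labels)))}
    return [group_of[label] for label in labels], group_of


def is_similar(group_a: int, group_b: int) -> bool:
    """Two records are similar exactly when they carry the same class label."""
    return group_a == group_b


def count_pairs(groups: list[int]) -> tuple[int, int]:
    """Count the (similar, dissimilar) unordered record pairs implied by the groups."""
    similar = dissimilar = 0
    for i in range(len(groups)):
        for j in range(i + 1, len(groups)):
            if is_similar(groups[i], groups[j]):
                similar += 1
            else:
                dissimilar += 1
    return similar, dissimilar


groups, mapping = to_groups(["World", "Sports", "Sports", "Business"])
print(groups)               # [2, 1, 1, 0]
print(count_pairs(groups))  # (1, 5)
```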
We wanted easy-to-understand, widely available data for this use case, so we settled on the “AG News” classification dataset, which has four classes: World, Sports, Business and Sci/Tech. Although it is already labeled, which helps us in the evaluation later on, we treat it as an unlabeled dataset in order to show the full process. Every record has a title, a description and the associated label. We selected 20,000 records at random, loaded them into kern refinery and labeled 261 of them manually. After creating some labeling functions and active learners, we ran weak supervision and ended up with 17,361 weakly supervised records. We kept only the records with a confidence score above 0.7, added the manually labeled data and ended up with 10,854 usable records for our fine-tuning pipeline. The remaining 9,156 records (with their original labels) will be used as a test set in the evaluation later. If you want to take a closer look at the labeling process or the data itself, visit the GitHub repository, where we documented everything.
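The data preparation steps above roughly translate to the following sketch. It assumes the weak supervision results are available as `(record_id, label, confidence)` tuples and the manual labels as `(record_id, label)` tuples. That shape, the helper names, and the choice to let manual labels override weak ones are our assumptions, not the format kern refinery exports.

```python
import random


def sample_records(record_ids: list, k: int, seed: int = 42) -> list:
    """Draw k records at random (seeded, so the selection is reproducible)."""
    return random.Random(seed).sample(record_ids, k)


def build_training_set(weak, manual, threshold: float = 0.7) -> dict:
    """Keep weakly supervised records above the confidence threshold, then add manual labels.

    weak:   list of (record_id, label, confidence) tuples (assumed format)
    manual: list of (record_id, label) tuples (assumed format)
    Manual labels override weak labels for the same record id (our assumption).
    """
    dataset = {rid: label for rid, label, confidence in weak if confidence > threshold}
    dataset.update(dict(manual))
    return dataset


def held_out_ids(all_ids: list, training_set: dict) -> list:
    """Every sampled record not used for fine-tuning ends up in the test set."""
    return [rid for rid in all_ids if rid not in training_set]
```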
Quaterion, the similarity learning framework we used, can work with different kinds of similarity information to fine-tune the embeddings: a similarity score, pre-formed triplets, or similarity groups (where the group is defined by the class). Because the class information is the only similarity measure we have, we use SimilarityGroupSamples. Now that the data is ready, we need a model to train. Remember, the goal is to learn a mapping from one embedding to another. For that we use a pre-trained LLM as the encoder and add a SkipConnectionHead on top of it, which is preferred over a plain linear layer. The linear layer inside this head has as many in-features as out-features: 384 in our case, because our base model “all-MiniLM-L6-v2” produces 384-dimensional embeddings. Normally, for example in classification, you would use a classification head with as many out-features as there are classes and obtain the gradients for training from a cross-entropy loss. Because we want to learn similarity in the embedding space, we have to use a different loss function: a triplet loss with cosine distance as the distance metric.
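To illustrate the two model-side ingredients, here is a NumPy sketch of a gated skip-connection head and of a triplet loss with cosine distance computed over group labels. Both are simplified re-implementations of the ideas for illustration, not Quaterion's classes: the gate formulation, its initialization, the batch-hard mining strategy and the margin of 0.5 are our assumptions. In the actual pipeline the head sits on top of “all-MiniLM-L6-v2” and is trained with gradients computed by PyTorch.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class SkipConnectionHeadSketch:
    """Gated residual head: y = x + sigmoid(gate) * (x @ W + b).

    W is square (dim x dim), so the output keeps the dimensionality of the
    base model (384 for all-MiniLM-L6-v2). Illustrative, not Quaterion's code.
    """

    def __init__(self, dim: int = 384, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.W = rng.normal(scale=dim ** -0.5, size=(dim, dim))
        self.b = np.zeros(dim)
        # hypothetical init: a small gate keeps the head close to the identity mapping
        self.gate = np.full(dim, -4.0)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        return x + sigmoid(self.gate) * (x @ self.W + self.b)


def cosine_distance_matrix(x: np.ndarray) -> np.ndarray:
    unit = x / np.clip(np.linalg.norm(x, axis=1, keepdims=True), 1e-12, None)
    return 1.0 - unit @ unit.T


def batch_hard_triplet_loss(embeddings: np.ndarray, groups, margin: float = 0.5) -> float:
    """For each anchor, pair its farthest positive with its closest negative."""
    groups = np.asarray(groups)
    dist = cosine_distance_matrix(embeddings)
    same = groups[:, None] == groups[None, :]
    pos_mask = same & ~np.eye(len(groups), dtype=bool)  # same group, not the anchor itself
    neg_mask = ~same
    # only anchors with at least one positive and one negative form a triplet
    valid = pos_mask.any(axis=1) & neg_mask.any(axis=1)
    if not valid.any():
        return 0.0
    hardest_pos = np.where(pos_mask, dist, -np.inf).max(axis=1)
    hardest_neg = np.where(neg_mask, dist, np.inf).min(axis=1)
    losses = np.maximum(hardest_pos - hardest_neg + margin, 0.0)
    return float(losses[valid].mean())
```

The loss is zero once every record is closer to its farthest same-class neighbor than to its closest other-class neighbor by at least the margin, which is exactly the property that should put more same-class records into a similarity-based labeling session.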
Quaterion handles most of the training details for us and uses PyTorch Lightning under the hood. The optimizer (we chose Adam) is specified in the model itself, so we just need to call Quaterion's fit method and pass the data loaders for training and validation.
At the beginning of this blog post we mentioned that we wanted to improve similarity search in kern refinery. Because this is very difficult to measure directly, we came up with a metric that captures what we are trying to achieve: more records of the same class among the 1000 most similar records. We measure the share of such records and refer to it as the “top_1k” metric. Because not everyone is going to label a thousand records in a single session, we also look at how many records need to be labeled before the fine-tuning pays off. Additionally, we check whether our fine-tuning improved classification accuracy as a side effect. The test data consists of the 9,156 (already labeled) records that were not used in the training or validation steps.
For the top_1k metric we take 250 random samples from the test data, find the 1000 nearest neighbors of each of them and compute what percentage of those neighbors share the label of the original sample. Averaging this over the 250 samples gives the top_1k metric. The “raw” embeddings are the embeddings generated by the same base language model (“all-MiniLM-L6-v2”) without applying the learned transformation. The distribution of this metric shows that fine-tuning helped a lot. The violin plots show that with the fine-tuned embeddings you are more likely to get records of the same class in a labeling session guided by similarity search, which means less context switching and therefore a smoother experience. Averaging these 250 values gives 49.96% same-class records for the raw embeddings and 58.60% for the fine-tuned ones, an improvement of 8.64 percentage points (about 17% relative).
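The top_k metric can be sketched in a few lines of NumPy. This assumes the neighbors are searched within the test set itself and ranked by cosine similarity, with the query record excluded. The function name and the returned per-query shares (the kind of values behind the violin plots) are our own choices; the defaults mirror the 1000 neighbors and 250 queries described above.

```python
import numpy as np


def top_k_same_class(embeddings: np.ndarray, labels, k: int = 1000,
                     n_queries: int = 250, seed: int = 0):
    """Average share of same-label records among each query's k nearest neighbors.

    Returns (average share, list of per-query shares).
    """
    labels = np.asarray(labels)
    unit = embeddings / np.clip(np.linalg.norm(embeddings, axis=1, keepdims=True), 1e-12, None)
    rng = np.random.default_rng(seed)
    queries = rng.choice(len(labels), size=min(n_queries, len(labels)), replace=False)
    k = min(k, len(labels) - 1)
    shares = []
    for q in queries:
        sims = unit @ unit[q]  # cosine similarity of every record to the query
        sims[q] = -np.inf      # exclude the query record itself
        neighbors = np.argpartition(-sims, k - 1)[:k]
        shares.append(float(np.mean(labels[neighbors] == labels[q])))
    return float(np.mean(shares)), shares
```

Running it once on the raw and once on the fine-tuned test embeddings, with the same seed so both use the same 250 queries, gives the two numbers to compare.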
Because labeling sessions do not always run to 1000 records, we were curious how the top_k metric behaves for different values of k (all other parameters are the same as in the previous experiment). The benefits of fine-tuning your embeddings already seem to show in labeling sessions of only 25 records, which is good news because that is not a lot. From there on, the fine-tuned embeddings consistently perform better than the raw embeddings.
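Given per-k averages for both kinds of embeddings, the point from which fine-tuning pays off can be found with a small helper. This is a hypothetical sketch: it returns the smallest evaluated k from which the fine-tuned embeddings score higher than the raw ones for every larger k that was evaluated, or `None` if they are not ahead at the largest k.

```python
def break_even_k(ks, raw_scores, tuned_scores):
    """Smallest k from which the fine-tuned score beats the raw score for all larger k."""
    result = None
    # walk from the largest k downwards and stop at the first k where raw is not beaten
    for k, raw, tuned in sorted(zip(ks, raw_scores, tuned_scores), reverse=True):
        if tuned > raw:
            result = k
        else:
            break
    return result
```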
Embeddings fine-tuned with class information could also benefit a classifier trained on that data. So we trained a LogisticRegression on the embeddings of our training data and evaluated its performance on the test data with the classification report from sklearn. Interestingly, the fine-tuning did not make much of a difference. We even lost a tiny amount of performance compared to the raw embeddings, although the difference is not significant. So while the neighbor-based similarity improved with respect to the class labels, this linear classification model did not find it any easier to separate the classes from one another. We will look into this in more detail in the near future.
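The classification check can be reproduced with scikit-learn along these lines. The sketch assumes the embeddings are NumPy arrays and the labels are class names; the function name is hypothetical, and `max_iter=1000` is our choice to avoid convergence warnings, not a setting reported above.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report


def evaluate_embeddings(train_x: np.ndarray, train_y, test_x: np.ndarray, test_y):
    """Fit a linear classifier on embeddings and report its test-set performance."""
    clf = LogisticRegression(max_iter=1000)
    clf.fit(train_x, train_y)
    pred = clf.predict(test_x)
    return accuracy_score(test_y, pred), classification_report(test_y, pred)
```

Calling it once with the raw and once with the fine-tuned embeddings of the same records gives the two reports to compare.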
By sharing our experience with using similarity learning to fine-tune embeddings, we want to encourage you to try it yourself! Quaterion made it really easy to get started, and they also offer lots of support if you run into any difficulties. Apply this pipeline to your projects that require a well-tuned similarity! We used a simple classification dataset, but there are many domains where similarity learning shines, for example e-commerce, where products are mapped into a vector space. There, a fine-tuned similarity could drastically enhance the user experience! Everything we presented is open source: you can start from your raw data, load it into the open-source kern refinery, label and export it, and then process it in the Quaterion pipeline.
We are constantly looking for better ways to visualize and label data. Currently, we are looking into annotation methods that include a two-dimensional plot of the embeddings, where the user can label the data by drawing shapes around the points that should be labeled. When using basic PCA, we found that the embeddings are often not separated well in only two dimensions, which makes this kind of annotation process difficult. Therefore, we are currently working on methods to fine-tune embeddings so that the classes are better separated in the 2D space.
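A basic PCA projection to two dimensions can be computed with a singular value decomposition in NumPy. The sketch also returns the explained variance ratio of the two components; reading a low value as a hint that much of the structure is lost in 2D is our own rule of thumb, not a result from this post. The function name is hypothetical.

```python
import numpy as np


def pca_2d(embeddings: np.ndarray):
    """Project embeddings onto their first two principal components.

    Returns (2D points, explained variance ratio of the two components).
    """
    centered = embeddings - embeddings.mean(axis=0)
    # rows of vt are the principal directions, sorted by singular value
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    explained = (s ** 2) / np.sum(s ** 2)
    return centered @ vt[:2].T, explained[:2]
```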
Keep an eye out for future blog posts, because we will share our experiences with you! You can also join our Discord for discussions or questions about any of these topics, be it NLP, embeddings, LLMs, labeling, or data-centric AI in general.