Machine learning promises to be transformational, providing much-needed innovation across industries ranging from financial services to healthcare to agriculture. The resurgence of machine learning, owing to the astonishing performance of deep learning in computer vision and natural language processing, has prompted industry to jump on the bandwagon and utilise machine learning in their businesses. In fact, entire startups are being built around machine learning alone, promising that their newfound algorithms will vastly outperform existing products and services.
Whilst new state-of-the-art algorithms are being published week-on-week, using them in a production setting is non-trivial. The additional complexity of using machine learning to power products cannot be overstated. It may come as a surprise to new practitioners that most of this complexity has very little to do with the model itself and far more to do with the infrastructure supporting it. In fact, a common trope is that developing the machine learning model is only a small piece of the puzzle. So what makes machine learning complex? Data cleaning and warehousing, data versioning, model versioning, distributed training (in the event you are training large models), reproducibility, concept drift, inference speed and optimizing the right thing, to name a few.
Fortunately, this makes for an extremely rewarding engineering experience. As engineers, we are drawn to technically challenging problems - at least, I'd like to think so! Machine learning in the wild does not fall short in any regard. It is truly stimulating, though it might cause you to lose a few hairs, destroy your laptop and move to Motuo - consider yourself warned. In this post, I illuminate some of the core challenges faced in going the distance with machine learning.
Data Quality
What is machine learning without data? The answer is nothing. Absolutely nothing. As the common saying goes, "garbage in, garbage out". Perhaps the most important aspect of machine learning is the data used to train models. Without quality data, no amount of sophistication and creativity will help your machine learning model perform accurately. Whether you agree with it or not, one popular tweet offers an interesting angle on what a model is - a view of the data it was trained on.
Data is the source of truth for your model. Every phenomenon it learns is encompassed by that data. Given this stance, you'd be crazy not to ensure that your model is trained on the highest quality data you can create or get your hands on. Unfortunately, ensuring data quality is difficult, expensive and laborious. Rather than trying to circumvent this, data quality work should be incorporated into the company or team strategy. Making the necessary capacity and budget provisions will make it manageable. Ideally, ensuring data quality should be reduced to a mere nuisance as opposed to a blocking point for deploying machine learning models.
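To give a flavour of what "ensuring data quality" can look like in practice, here is a minimal sketch of a few automated checks run before training. The columns, thresholds and example values are hypothetical; in a real pipeline you would tailor them to your own schema or reach for a dedicated validation library.

```python
import pandas as pd

def basic_quality_checks(df: pd.DataFrame) -> list[str]:
    """Return human-readable descriptions of data quality issues found in df."""
    issues = []

    # Flag any column with more than 5% missing values (threshold is arbitrary).
    null_fraction = df.isna().mean()
    for column, fraction in null_fraction[null_fraction > 0.05].items():
        issues.append(f"{column}: {fraction:.1%} missing values")

    # Exact duplicate rows often point to an ingestion bug.
    duplicates = int(df.duplicated().sum())
    if duplicates:
        issues.append(f"{duplicates} duplicate rows")

    # Simple range check on a hypothetical 'age' column.
    if "age" in df.columns and not df["age"].between(0, 120).all():
        issues.append("'age' contains out-of-range values")

    return issues

# Hypothetical example frame with a missing value and an impossible age.
df = pd.DataFrame({
    "age": [34, 29, 180, 41],
    "income": [52_000, None, 61_000, 48_000],
})
for issue in basic_quality_checks(df):
    print("WARNING:", issue)
```

Checks like these are cheap to run on every new batch of data and turn "the data looks off" into something you can alert on.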
Concept Drift
Real-life data is akin to a living organism: it is dynamic and changes over time, drifting further and further away from the data that came before it. This is called concept drift and it is a common challenge in machine learning. In its simplest form it is not difficult to deal with, since it only requires retraining on a newer batch of data. The complexity arises at the infrastructural level. Handling drift effectively requires models to be continuously retrained as their performance starts deteriorating in production, and setting up infrastructure that makes the path from training to deployment to monitoring painless can be difficult, depending on the velocity at which this needs to happen.

In a blog post on their recommendation system architecture, Netflix discuss their separation of offline, online and nearline computation. Offline is the most common paradigm in machine learning today: data is stored and models are periodically retrained on the new data. However, Netflix reached the limits of this approach because the time between model updates hurt the recency of their recommendations. They solve this with online computation, which reacts to the latest changes in the data in real time; nearline computation is a hybrid of the two approaches. Considering Netflix's scale, most of us will not need such a comprehensive architecture to run machine learning in production, but it certainly highlights the complexity that can arise.
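To make this concrete, a common way to flag drift is to compare the distribution of a feature in recent production data against the distribution it had at training time. Below is a minimal sketch using a two-sample Kolmogorov-Smirnov test; the feature (hours watched per user) and the threshold are hypothetical, and a real system would monitor many features alongside the model's own performance metrics.

```python
import numpy as np
from scipy.stats import ks_2samp

def has_drifted(train_values: np.ndarray, live_values: np.ndarray,
                p_threshold: float = 0.01) -> bool:
    """Flag drift when the live distribution differs significantly from training."""
    _statistic, p_value = ks_2samp(train_values, live_values)
    return p_value < p_threshold

# Hypothetical example: hours watched per user at training time vs. today.
rng = np.random.default_rng(seed=42)
train_hours = rng.normal(loc=2.0, scale=0.5, size=10_000)
live_hours = rng.normal(loc=3.1, scale=0.7, size=10_000)

if has_drifted(train_hours, live_hours):
    print("Feature distribution has drifted - consider retraining.")
```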
To drive concept drift home, I'll walk through an example. Keeping with the theme of Netflix, imagine you worked there and trained a recommendation model in January. At the time, life was good and relaxed. The majority of customers were watching comedy shows. The model was able to learn this, its performance was admirable, and it was subsequently deployed into production. It is now June and you realise the model is underperforming. The world is going through a pandemic and, in response, people have shifted away from comedy shows toward thrillers and horrors. This behavioural change means the dataset used to train the model in January is outdated. To bring the model up to date, you retrain it on the latest data and, as expected, your magic algorithm performs well once again. Hooray!
Versioning
Versioning is an essential part of software engineering. Git has largely solved this problem for traditional software development. In machine learning, we not only have to version our code but also the parameters of our model and the data that was used to train the model.
Data versioning is a new challenge that machine learning brings to the table. As discussed earlier, data is the source of truth for a model. When models are trained on different data, we naturally expect differences in their performance. Without keeping track of which data a model was trained on, we cannot make any significant claims about its performance relative to another model. Likewise, without knowing which dataset a model was evaluated on, we cannot fairly compare competing models. Finally, we cannot reproduce experiments if we do not version our data. Reproducibility is a growing concern in the machine learning community, and rightfully so - ensuring we create the best environment for it is important. From this, it is clear that data provenance is of the utmost importance in machine learning efforts. At present there is no industry standard for data versioning, but there are many options on the market, including dvc, dolt and pachyderm. They do not all target the same use cases, but there is extensive overlap between them.
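Tools like dvc handle this for you, but the core idea can be sketched in a few lines of Python: fingerprint the exact dataset a model was trained on and store that fingerprint alongside the model's metadata, so any two runs can be compared and reproduced. The file paths and metadata fields below are hypothetical.

```python
import hashlib
import json
from pathlib import Path

def dataset_fingerprint(path: str) -> str:
    """Hash the raw bytes of a dataset file so any change yields a new version id."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

def record_training_run(model_name: str, data_path: str, metrics: dict) -> None:
    """Append a record tying a model to the exact data it was trained on."""
    record = {
        "model": model_name,
        "data_version": dataset_fingerprint(data_path),
        "data_path": data_path,
        "metrics": metrics,
    }
    with Path("training_runs.jsonl").open("a") as f:  # hypothetical experiment log
        f.write(json.dumps(record) + "\n")

# Hypothetical usage: create a toy dataset, then record a run against it.
Path("data").mkdir(exist_ok=True)
Path("data/objects_train.csv").write_text("image,label\nimg1.png,cat\n")
record_training_run("detector_v3", "data/objects_train.csv", {"mAP": 0.71})
```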
One scenario in which poor data versioning can play against you is the experimentation phase. Let's suppose you decide to create an object detection system. You train a model with dataset A. You then change the data slightly and train a model on the altered data, dataset B. You repeat this alteration process twice more, creating datasets C and D. After these experiments, you find that the models trained on datasets A and C perform the best. Unfortunately, you did not record the differences between these datasets and only possess the data with the most recent changes (i.e. dataset D). It is now impossible for you to reproduce the best performing model. Worse, you cannot understand why datasets A and C result in the best models, reducing your understanding of and confidence in the model's performance. This highlights why data versioning is critical to the machine learning process.
Experimentation Platforms
Experimentation platforms are the tooling and infrastructure needed to run, track and compare machine learning experiments effectively. They are not necessary at the beginning of the machine learning journey, but as a company scales up its machine learning efforts they become critical. Thousands of experiments need to be run and evaluated; having the correct infrastructure and tools keeps iteration speed high and allows models to actually make it into production.
Let's walk through an example to highlight their importance. You've decided that you need to create and train a new model for fraud detection. You whip out your trusty Jupyter notebook and set off on the adventure of a lifetime. Once the model and dataset are ready, you proceed to train the model on a virtual server provided by the company on their preferred cloud provider. The virtual server is unconfigured, has no fault tolerance and requires many manual steps (SSH into it, install all the dependencies, monitor it, etc.). Naturally, these steps present many points of potential failure as well as general headaches, slowing down development. The company realises that this is inefficient and invests in tooling to ease the machine learning development process. There is a whole host of tools available - Kubeflow, Comet, Weights & Biases and MLflow, to name a few - which can be combined into a machine learning platform that makes the machine learning lifecycle easier to manage. Now you have the ability to monitor progress, visualise training and validation loss and view results on a user-friendly dashboard. Clearly, this improves iteration speed, removes the need to wrestle with infrastructure and provides the necessary visibility. By using the correct platform, the time taken between iterations decreases substantially.
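As a small taste of what this tooling provides, the sketch below logs the parameters and metrics of a hypothetical fraud-detection run with MLflow's tracking API. The experiment name, parameters and numbers are made up; in a real run they would come from your actual training loop.

```python
import mlflow

mlflow.set_experiment("fraud-detection")  # hypothetical experiment name

with mlflow.start_run(run_name="baseline-logreg"):
    # Record the configuration of this run so it can be compared to others later.
    mlflow.log_param("model_type", "logistic_regression")
    mlflow.log_param("learning_rate", 0.01)

    # Hard-coded losses purely for illustration; a real loop would log live values.
    for epoch, (train_loss, val_loss) in enumerate([(0.62, 0.58), (0.48, 0.53), (0.41, 0.50)]):
        mlflow.log_metric("train_loss", train_loss, step=epoch)
        mlflow.log_metric("val_loss", val_loss, step=epoch)

    mlflow.log_metric("val_auc", 0.91)

# Every run is now browsable and comparable in the MLflow UI (run `mlflow ui`).
```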
Inevitably, there is an upfront cost to be paid when scaling out machine learning efforts: providing the necessary platforms. Data scientists and machine learning engineers need to be empowered to do their work effectively, and this is one step in that direction. It also allows them to add more value by improving their models in a shorter period of time.
Optimizing the Right Thing
In many product offerings, the core goal of machine learning is to improve the customer experience. At a music company such as Spotify, examples include song recommendations and personalisation based on music taste. At a medical imaging company, it could help doctors make more accurate diagnoses or provide diagnoses itself. Metrics provide an objective way of ranking the performance of models and ensuring you have the latest and greatest model in production. Unfortunately, a common pitfall is optimizing the wrong metric - one whose improvement does not translate into a better customer experience. A simple question to ask is: what does a 10% improvement in this metric mean for the customer? It is often quite simple to evaluate your model against the standard metrics of your field; however, those metrics do not always translate into a tangible improvement for the purpose the model was designed to serve.
A simple case to illuminate this is medical diagnosis. Imagine you are the patient and this revolutionary AI is used to provide your diagnosis. If the model predicts that you have a disease and you do not, the overall harm inflicted by the incorrect diagnosis is minimal. Despite the error, you will survive and hopefully live for many years. Now imagine the model predicts that you do not have a disease that you do in fact have. In response to this prediction, the doctor sends you off into the world to fend for yourself without medication or treatment. Evidently, this is a harmful situation which will likely have severe consequences. How does this relate to optimizing the right thing? Imagine being responsible for creating the model that performs this task, and your chosen metric is accuracy. Accuracy measures the proportion of correct predictions out of all predictions. Your new model is 5% better than your old model, so you release it into production - happy days! Unfortunately, that metric does not communicate the most important thing: how often your model predicts that an individual does not have a disease when they actually do. The appropriate metric here is recall. In this context, perfect recall means the model catches every individual who actually has the disease, even if it still makes some mistakes by predicting that healthy individuals have a disease they do not.
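A toy calculation makes the gap between the two metrics obvious. The labels below are fabricated: out of 100 patients only 10 actually have the disease, so a model that catches just 2 of them still reports high accuracy while its recall collapses.

```python
from sklearn.metrics import accuracy_score, recall_score

# Fabricated ground truth: 1 = has the disease, 0 = healthy.
y_true = [1] * 10 + [0] * 90

# A model that catches only 2 of the 10 sick patients and never raises a false alarm.
y_pred = [1] * 2 + [0] * 8 + [0] * 90

print("accuracy:", accuracy_score(y_true, y_pred))  # 0.92 - looks impressive
print("recall:  ", recall_score(y_true, y_pred))    # 0.20 - misses 8 of 10 sick patients
```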
If only it were as easy as the example makes it seem! In most cases, metrics have to be designed so that improving them translates as directly as possible into an improvement in the customer experience. This talk by Jade Abbott walks through an example of this - start at the 15 minute mark.
The main takeaway is that metrics should be carefully designed and should give an indication of the model's impact on the customer. Well-chosen metrics also allow the rest of the business to understand the value of the work you are doing. Then they will also celebrate when your model improves by 0.01%!
Greener Pastures
Machine learning in the wild is difficult! This should not discourage us by any means. While difficult, it is extremely rewarding and if you're the kind of engineer I hope you are, it is exciting! With all difficult things comes the joy of working on technically challenging problems.