Balancing the Imbalanced

As I continue to build up my data science toolkit, I've begun learning about the types of classification techniques that are used to solve everyday problems. These tools are really cool! Want to know whether an email you received is spam or not spam? Use a classification technique! Want to know if a new transaction is fraudulent or not? Use a classification technique! Et cetera, et cetera.

One thing I've seen again and again is the importance of class balance when feeding data into these models. Think about it - you're asking a computer, which has NO idea what you're talking about or how to identify anything in any way other than how you tell it to identify things, to look at something completely new and categorize it. If you feed it 1000 emails, 950 of which are 'not spam' and 50 of which are 'spam,' and ask it to identify which are 'not spam,' it can just label everything as 'not spam' and be 95% correct! Not bad.
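Just to make that arithmetic concrete, here's a tiny sketch of that lazy "everything is not spam" baseline, using made-up labels:

# Made-up labels: 950 'not spam' emails and 50 'spam' emails
y_true = ['not spam'] * 950 + ['spam'] * 50

# A lazy "model" that predicts 'not spam' for every single email
y_pred = ['not spam'] * len(y_true)

# Accuracy is just the fraction of labels it got right: 950/1000
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(accuracy)  # 0.95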

And yet... that doesn't do what you want at all. You want your model to learn the characteristics of 'spam' emails and identify which of those characteristics are reliable predictors of 'spam' in general - something the computer has less and less incentive to do as the majority class in your dataset grows larger.

So! Time to practice how to balance the classes within your dataset. I'll be giving examples of how to code a few methods I've encountered in Python 3, using pandas, scikit-learn, and a bit of imblearn to make all of our lives easier.



The Simple Ways to Balance

Perhaps the simplest way to balance your under-represented category against the rest of your data is to under-sample the rest of your data. To stick with our 950 'not spam' versus 50 'spam' example, we'd simply take a sample of 50 'not spam' and use that sample with our full 'spam' data to have a balanced dataset to use to train our model! Easy-peasy.

# Using a pandas dataframe, 'data,' where the "category" column holds
# either the "majority" label or the "minority" label for each row
import pandas as pd

minority = data[data["category"] == "minority"]
majority = data[data["category"] == "majority"].sample(n=len(minority))

# Combine into one balanced dataframe
balanced = pd.concat([minority, majority])

Alas, you can probably see some problems with this simple approach. For one, we lose a lot of data (900 observations) by going down this route. The key to differentiating between 'spam' and 'not spam' could be hidden within that lost data!

-

So, a different (but still simple) way to balance your under-represented category is to over-sample from that minority, with replacement, over and over, until it's the same size as your majority.

# Same example pandas dataframe as before

majority = data[data["category"] == "majority"]
minority = data[data["category"] == "minority"].sample(n=len(majority), replace=True)

# Combine into one balanced dataframe
balanced = pd.concat([minority, majority])

This... also has problems. With a case like our 950-50 split, you're re-using each of those 50 'spam' observations roughly 19 times, over and over again, to get even with the 950 observations of 'not spam.' This will very likely result in overfitting - your model becomes so used to the content of your minority, the 'spam' category, that it only works on those specific emails and can't generalize to recognize 'spam' out in the real world.

Sure, you can combine these two - over-sampling your minority and under-sampling your majority (a quick sketch of that below) - and maybe that will work fine for some of what you do! But for the cases when you need a more nuanced way to balance your data, there are more complicated methods.
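Here's that combined sketch, using the same hypothetical 'data' dataframe - the target size of 300 is completely arbitrary:

# Meet in the middle: under-sample the majority and over-sample the
# minority toward an arbitrary common target size
import pandas as pd

target = 300
majority = data[data["category"] == "majority"].sample(n=target)
minority = data[data["category"] == "minority"].sample(n=target, replace=True)

balanced = pd.concat([majority, minority])
print(balanced["category"].value_counts())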


A Little More Complicated


Alright, so if we can't simply sample the data we already have, what can we do? One idea is to add weight to our minority category, so our model knows that the frequency with which it encounters each class does not translate into the importance of each class - the less frequent category should be considered more important, even though it's rare!

scikit-learn's LogisticRegression model for classifying data has a built-in class_weight option which allows you to tell your model that some classes should be considered more strongly than others. The easiest way to balance from there is to just apply class_weight='balanced' - the logistic regression model will then automatically assign each class a weight inversely proportional to its frequency. In our spam example, the logistic regression model knows 'spam' and 'not spam' should be balanced, and will automatically weight those 50 examples of 'spam' so they're considered more important than the 950 examples of 'not spam.'

# Import the logistic regression package from sci-kit learn
from sklearn.linear_model import LogisticRegression

# Start the instance of the Logistic Regression, but balanced
# Default for class_weight is None, which gives all classes a weight of 1
logreg = LogisticRegression(class_weight='balanced') 

So what is this actually doing? You're telling your model that all classes should contribute equally when it calculates its loss function. In other words, when the model is deciding which way to best fit the data, you're being really explicit in telling it that it needs to consider the percentage of errors in the minority as just as important as the percentage of errors in the majority.

With our example, we discussed a model that always predicts our emails are 'not spam,' since 950 out of 1000 are 'not spam' and only predicting 'not spam' results in a model that's 95% accurate. But that's 0% accuracy for emails that are actually 'spam,' since it never predicts that an email is 'spam.' By telling our model that the classes should be balanced, our model knows that the accuracy for predicting 'spam' is just as important as the accuracy for predicting 'not spam,' and thus it can't consider an overall 95% accuracy as acceptable.
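If you're curious what 'balanced' actually works out to, scikit-learn exposes the same calculation through compute_class_weight - here's a quick sketch with our hypothetical 950/50 labels:

# See what weights class_weight='balanced' assigns to each class
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Our hypothetical 950 'not spam' / 50 'spam' labels
y = np.array(['not spam'] * 950 + ['spam'] * 50)

# 'balanced' weights are n_samples / (n_classes * count_of_that_class)
classes = np.unique(y)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)
print(dict(zip(classes, weights)))  # {'not spam': 0.526..., 'spam': 10.0}

The rare 'spam' class ends up weighted about 19 times more heavily than 'not spam' - exactly the imbalance we started with.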

This works! I can only speak to the logistic regression model at the moment, but many other scikit-learn classifiers accept a class_weight parameter as well. This may be enough to get you better results with your data, and if so, that's great. But what if it's not? Can we get more complicated?


Of Course We Can Get More Complicated


Another idea - what if we could train our model to make synthetic data, that's similar to the data in our 'spam' minority category but is a little bit different, thus avoiding some of the over-fitting that we were worried about before?

Yes, this is a thing, and no, you don't have to code it from scratch. The Imbalanced Learn library, imblearn, is full of fun ways to apply more complicated balancing techniques - including under- and over-sampling through clusters! These techniques work by identifying clusters in your dataset. To under-sample, you use those clusters to decide which majority observations to remove, preserving more of the majority's diversity than a random under-sample would. To over-sample, you generate new, synthetic observations within the minority, which helps avoid overfitting because the minority ends up more varied than repeated copies of the same rows.

Okay, but how in the world does any of that work? Let's dig in.

The Synthetic Minority Oversampling Technique (SMOTE) is the most common method I've run into for cluster-based over-sampling. SMOTE works by taking each instance of the minority category, drawing lines between it and its nearest minority-class neighbors in feature space, and then creating new, synthetic observations at points along those lines.
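Under the hood, each synthetic observation is just an interpolation between a minority point and one of its minority-class neighbors. Here's a minimal sketch of that single step (the points and names are made up for illustration, not imblearn internals):

import numpy as np

# Two hypothetical minority-class observations that are near neighbors
x_i = np.array([2.0, 3.0])
x_neighbor = np.array([4.0, 5.0])

# Pick a random spot along the line segment between them
lam = np.random.uniform(0, 1)
x_synthetic = x_i + lam * (x_neighbor - x_i)
print(x_synthetic)  # somewhere between [2.0, 3.0] and [4.0, 5.0]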

I found a great explainer of how SMOTE works on Rich Data, although his examples are created in R (aka less helpful for us Python-only people). But the below image shows exactly how those lines are drawn and where the resulting new, synthetic observations are created.

[Image: SMOTE visualization from Rich Data, showing synthetic observations generated along the lines between minority-class points]

So how do we do this in Python?

# Import the SMOTE package from the imblearn library
from imblearn.over_sampling import SMOTE
import pandas as pd

# First, look at your initial value counts (assuming y is a pandas Series)
print(y.value_counts())

# Start your SMOTE instance
smote = SMOTE()

# Apply SMOTE to your data, some previously defined X and y
X_resampled, y_resampled = smote.fit_resample(X, y)

# Look at your new, resampled value counts - should be equal!
print(pd.Series(y_resampled).value_counts())

Now, can you guess why this isn't perfect? It's better than a simple random over-sample, but these synthetic samples aren't real data, and they're still built entirely from your existing minority. So those new, synthetic samples can still lead to over-fitting, since they're derived from our original minority category. An additional pitfall: if one of your minority observations is an outlier, SMOTE will happily draw lines between that outlier and other minority points and create synthetic data along them - and those new synthetic points may well be outliers too.

I'll note that SMOTE has a bunch of variants that people have invented to account for some of these overfitting and outlier problems, but they get increasingly complex. Do your best.

-

Another way to create synthetic data to over-sample our minority category is the Adaptive Synthetic approach, ADASYN. ADASYN works similarly to SMOTE, but it concentrates on the minority points that sit closest to the majority - aka the ones most likely to be confused. It tries to help out your model by generating more synthetic data in your 'spam' minority category exactly where 'spam' and 'not spam' are hardest to tell apart.

# Import the ADASYN package from the imblearn library
from imblearn.over_sampling import ADASYN

# Start your ADASYN instance
adasyn = ADASYN()

# Apply ADASYN to your data, some previously defined X and y
X_resampled, y_resampled = adasyn.fit_resample(X, y) 

-

Let's try the opposite: synthetic under-sampling. Cluster Centroids finds clusters too, but instead of using those clusters to create new minority points, it uses k-means clustering to find the 'central' points (centroids) of your majority category. Your model then uses those centroids in place of the actual majority instances, shrinking the majority down to a representative core.

# Import the ClusterCentroids package from the imblearn library
from imblearn.under_sampling import ClusterCentroids

# Start your ClusterCentroids instance
cc = ClusterCentroids()

# Apply ClusterCentroids to your data, some previously defined X and y
X_cc, y_cc = cc.fit_resample(X, y)


Of course, any under-sampling technique will eliminate some of the data you have, thus reducing the nuance that could be found if you looked at all of your data in your majority category. But this way, at least, those centroids will typically be more representative than a random sample of your majority.

-

If your model is having trouble differentiating between your classes, another alternative to ADASYN is to have your model ignore instances of your majority that are close to your minority. Uh, what? So, say you have some instances of 'not spam' that look really similar to 'spam.' You can find pairs of opposite-class points that are each other's nearest neighbors, then drop the majority member of each pair - the 'not spam' one - thus widening the gap in your data between 'spam' and 'not spam.'

These are called Tomek links, and I found a great example in a Kaggle page on Resampling Strategies for Imbalanced Datasets:

[Image: Tomek links visualization from the Kaggle notebook, showing the majority-class point in each link being removed to widen the boundary between classes]

# Import the TomekLinks package from the imblearn library
from imblearn.under_sampling import TomekLinks

# Start your TomekLinks instance
tomek = TomekLinks()

# Apply TomekLinks to your data, some previously defined X and y
X_tl, y_tl = tomek.fit_resample(X, y)

Does this also have problems? Of course! You're ignoring the data that's right on the cusp between your majority and minority categories, perhaps where you need to dig into that data the most! But it is an option.

There are dozens of increasingly complicated ways to balance your classes, as you mix and match and try to get the best set of observations before you build a classification model. See the resources below, and dig into the imblearn documentation, if you'd like to find plenty of other ways to balance your imbalanced categories!
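As one example of mixing and matching, imblearn ships a combined sampler that over-samples with SMOTE and then cleans up the boundary with Tomek links - a sketch, again assuming some previously defined X and y:

# Import the combined SMOTE + Tomek links sampler from the imblearn library
from imblearn.combine import SMOTETomek

# Over-sample the minority with SMOTE, then remove Tomek links
smote_tomek = SMOTETomek()
X_combined, y_combined = smote_tomek.fit_resample(X, y)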


Caveats

There are a lot of considerations to keep in mind when doing any part of data science, and of course balancing your imbalanced classes is no exception. One thing I absolutely want to reiterate is how important it is to do a train-test split before creating your model, so you reserve a percentage of your data to test your model.

Create your train-test split BEFORE you balance your classes! Otherwise, especially if you use an over-sampling technique, your 'balanced' classes will have overlap between your training data and your testing data - after all, your over-sampling is basically using data you already have to make more data in your minority class, so your testing data will just be your training data either exactly or slightly modified by SMOTE. This tutorial walks through how that can trip you up in practice in quite a lot of detail.
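The safe order of operations looks something like this - a sketch assuming some feature matrix X and labels y, where only the training data gets resampled:

from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

# Split FIRST, so the test set never sees any resampled data
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Balance only the training data, then fit your model on it
X_train_resampled, y_train_resampled = SMOTE().fit_resample(X_train, y_train)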

In general, the best advice is to look at metrics beyond accuracy. Accuracy is important, but if we only looked at accuracy in our 'spam' or 'not spam' example we'd have a 95% accurate but otherwise completely useless model. Look at recall and precision as well, and try, as always, to find that magical Goldilocks zone that achieves what you want your model to achieve. Run a confusion matrix - confusion matrix is friend!
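Here's a quick sketch of pulling those metrics out of scikit-learn, assuming the logreg model from earlier has been fit and X_test/y_test were held out as above:

from sklearn.metrics import classification_report, confusion_matrix

# Predict on the untouched test set
y_pred = logreg.predict(X_test)

# Precision, recall, and f1-score per class - much more revealing than accuracy alone
print(classification_report(y_test, y_pred))

# And the confusion matrix, our friend
print(confusion_matrix(y_test, y_pred))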

-

Soon, I'll edit this post to add an example GitHub repository using actual data, not just spam. In the meantime, any suggestions for more robust ways to balance your datasets? Run into any pitfalls when applying these techniques, or have a technique you find yourself turning to again and again? Let me know!


Some Resources

I used many of the below to learn more about each of these techniques:

Cover image sourced from this Medium post. SMOTE visualization sourced from Rich Data. Tomek link visualization sourced from this Kaggle page. GIFs, as always, from GIPHY
