Introduction
Multiclass classification is a classification problem with more than two classes. It is a fundamental machine learning task that aims to assign each instance to one of a predefined set of classes, for example classifying images of animals as dogs, cats or rabbits. Each sample is assigned exactly one label: an image can be classified as a dog or a rabbit, but not both at the same time.
In many classification problems, some class labels have far fewer observations than others. Such a dataset is called imbalanced. The problem is common and can bias a model toward the majority classes, because the model can reach high accuracy while rarely predicting the minority classes.
This article covers how to handle an imbalanced dataset and gives a step-by-step approach to multiclass classification using machine learning algorithms.
Dataset
The dataset is entirely numeric and has 11k+ rows and 10 columns. Columns 1-9 are the features and column 10 is the target. There are 20 classes in total, labelled with the integers 0-19.
The head() method returns the first n rows of a DataFrame (5 by default).
import pandas as pd

df = pd.read_csv("./training.csv")
df.head()
The dataset is highly imbalanced. A count plot of the target can be drawn as follows:
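import seaborn as sns
import matplotlib.pyplot as plt

# one bar per class, showing how many rows carry each label
sns.countplot(x='target', data=df)
plt.show()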
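Alongside the plot, the imbalance can be quantified numerically. The sketch below is an illustrative addition rather than part of the original workflow: `class_balance` and `imbalance_ratio` are hypothetical helper names, and the toy labels stand in for `df['target']`.

```python
import pandas as pd


def class_balance(target: pd.Series) -> pd.DataFrame:
    """Count and share of each class, largest class first."""
    counts = target.value_counts()  # sorted in descending order by default
    return pd.DataFrame({"count": counts, "share": counts / counts.sum()})


def imbalance_ratio(target: pd.Series) -> float:
    """Size of the largest class divided by the size of the smallest class."""
    counts = target.value_counts()
    return float(counts.max() / counts.min())


# Toy labels for illustration only; in the article this would be df['target']
toy = pd.Series([0] * 50 + [1] * 30 + [2] * 5, name="target")
print(class_balance(toy))
print(imbalance_ratio(toy))  # 10.0
```

A ratio far above 1 signals that accuracy alone will be a misleading metric for this target.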
Removal of outliers
Boxplot
In this part, we look at outliers: values that lie far away from the majority of the data points. A boxplot shows the median, the first and third quartiles (the edges of the box) and whiskers. With the default settings in seaborn and matplotlib, the whiskers extend to the most extreme points within 1.5 × IQR of the box, and any point beyond them is drawn individually. Boxplots therefore give a quick visual check for outliers in the dataset.
The first quartile (Q1) is the median of the lower half of the data, so about 25% of the values are less than Q1.
The third quartile (Q3) is the median of the upper half of the data, so about 75% of the values are less than Q3.
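The "median of each half" definition is only one of several quartile conventions. `np.quantile`, which the outlier code in this section uses, defaults to linear interpolation, so on small samples the two can give different values (on 11k+ rows the difference is usually negligible). A minimal comparison; `quartiles_median_of_halves` is a hypothetical helper, and it assumes the overall median is excluded from both halves when the count is odd:

```python
import numpy as np


def quartiles_median_of_halves(values):
    """Q1 and Q3 as the medians of the lower and upper halves of the data.

    Assumption: for an odd number of values the overall median is left out of
    both halves (other conventions include it). Needs at least 2 values.
    """
    x = np.sort(np.asarray(values, dtype=float))
    half = len(x) // 2
    lower = x[:half]
    upper = x[half + 1:] if len(x) % 2 else x[half:]
    return float(np.median(lower)), float(np.median(upper))


data = [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(quartiles_median_of_halves(data))                  # (2.5, 7.5)
print(np.quantile(data, 0.25), np.quantile(data, 0.75))  # 3.0 7.0
```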
IQR
IQR (interquartile range) is the difference Q3 - Q1. The interval used to flag outliers is:
$$\left[\, Q_1 - 1.5 \times \mathrm{IQR},\; Q_3 + 1.5 \times \mathrm{IQR} \,\right]$$
An observation that lies outside this interval is considered an outlier.
import numpy as np

# finding outliers for feature3
q1 = np.quantile(df.feature3, 0.25)   # first quartile
q3 = np.quantile(df.feature3, 0.75)   # third quartile
med = np.median(df.feature3)
iqr = q3 - q1                         # interquartile range
# upper and lower fences (the whisker limits of the boxplot)
upper_bound = q3 + 1.5 * iqr
lower_bound = q1 - 1.5 * iqr
# values strictly outside [lower_bound, upper_bound] are outliers
outliers = df.feature3[(df.feature3 < lower_bound) | (df.feature3 > upper_bound)]
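The snippet above only finds the outliers in feature3. A sketch of actually removing them from several columns at once might look like the following. `remove_iqr_outliers` is a hypothetical helper, and computing every column's fences on the unfiltered data (rather than after each column's rows are dropped) is a design choice, not something the article specifies.

```python
import pandas as pd


def remove_iqr_outliers(df: pd.DataFrame, columns, k: float = 1.5) -> pd.DataFrame:
    """Drop rows with a value outside [Q1 - k*IQR, Q3 + k*IQR] in any of `columns`.

    The fences for every column are computed on the unfiltered DataFrame.
    """
    keep = pd.Series(True, index=df.index)
    for col in columns:
        q1 = df[col].quantile(0.25)   # linear interpolation, like np.quantile
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        # between() is inclusive, so values exactly on a fence are kept
        keep &= df[col].between(q1 - k * iqr, q3 + k * iqr)
    return df[keep]


# Toy data for illustration only
toy = pd.DataFrame({
    "feature2": [10, 11, 12, 13, 14],
    "feature3": [1, 2, 3, 4, 100],
    "target": [0, 0, 1, 1, 2],
})
cleaned = remove_iqr_outliers(toy, toy.columns.drop("target"))
print(cleaned)  # the row with feature3 == 100 is gone
```

In this toy example the removed row is the only sample of class 2, which shows why the class counts are worth re-checking after outlier removal on an imbalanced dataset.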
Feature Selection using Correlation Heatmap
Correlation is a statistical measure of how close two variables are to having a linear relationship. Two highly correlated features carry largely redundant information, so their effects on the target are hard to separate and keeping both adds little. Therefore, we prefer to drop one of the two.
We can generate a correlation matrix using:
corr = df.corr()
We can then plot it as a correlation heatmap using:
sns.heatmap(corr)
The resulting correlation heatmap is shown below.
feature1 and feature3 are highly correlated with other features, and dropping them gave us better results.
X = df.drop('target', axis=1)   # features
y = df['target']                # target
X = X.drop(['feature1', 'feature3'], axis=1)   # drop the highly correlated features
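Reading pairs off the heatmap works for nine features; listing them programmatically is an alternative. In the sketch below, `correlated_pairs` and the 0.9 threshold are assumptions for illustration, and which feature of each pair to drop remains a judgment call (the article chose based on results).

```python
import pandas as pd


def correlated_pairs(X: pd.DataFrame, threshold: float = 0.9):
    """Return (feature_a, feature_b, |r|) for pairs with |Pearson r| above threshold."""
    corr = X.corr().abs().to_numpy()
    cols = list(X.columns)
    pairs = []
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):   # upper triangle only: each pair once
            if corr[i, j] > threshold:
                pairs.append((cols[i], cols[j], float(corr[i, j])))
    return sorted(pairs, key=lambda p: p[2], reverse=True)


# Toy features for illustration only: b is almost a rescaled copy of a
toy = pd.DataFrame({
    "a": [1, 2, 3, 4, 5],
    "b": [2, 4, 6, 8, 10.5],
    "c": [5, 3, 4, 1, 2],
})
print(correlated_pairs(toy))  # only the (a, b) pair exceeds 0.9
```

On the article's data the call would be `correlated_pairs(df.drop(columns='target'))`.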
Handling Imbalanced Dataset
SMOTE (Synthetic Minority Oversampling Technique)
SMOTE oversamples the minority class by creating synthetic examples instead of duplicating existing ones. The procedure is as follows:
- Select a random sample from the minority class.
- Find its k nearest neighbors in the minority class (by Euclidean distance) and pick one of them at random.
- Compute the difference between the neighbor and the sample, multiply it by a random number between 0 and 1, and add the result to the sample. The new point lies on the line segment between the two and is added to the minority class as a synthetic sample.
- Repeat the procedure until the desired proportion of the minority class is reached.
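The steps above can be sketched from scratch in NumPy for a single minority class. This is for illustration only; in practice imblearn's SMOTE (used below through SMOTETomek) is the tool to reach for. `smote_sketch` is a hypothetical helper; it computes all pairwise distances (O(n²) memory) and needs at least two minority samples.

```python
import numpy as np


def smote_sketch(X_min: np.ndarray, n_new: int, k: int = 5, seed: int = 0) -> np.ndarray:
    """Create n_new synthetic rows from the minority-class rows X_min (needs >= 2 rows)."""
    rng = np.random.default_rng(seed)
    n = len(X_min)
    k = min(k, n - 1)  # there are only n - 1 other points to choose from
    # Pairwise Euclidean distances between all minority samples
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)               # a point is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]  # k nearest neighbors of each point
    synthetic = np.empty((n_new, X_min.shape[1]))
    for i in range(n_new):                        # step 4: repeat until n_new samples exist
        a = rng.integers(n)                       # step 1: random minority sample
        b = neighbors[a, rng.integers(k)]         # step 2: one of its k nearest neighbors
        gap = rng.random()                        # step 3: random number in [0, 1)
        synthetic[i] = X_min[a] + gap * (X_min[b] - X_min[a])
    return synthetic


# Toy minority class for illustration only
X_min = np.array([[1.0, 1.0], [2.0, 1.5], [1.5, 3.0], [3.0, 2.0]])
new_rows = smote_sketch(X_min, n_new=6, k=2)
print(new_rows.shape)  # (6, 2)
```

Because each new row is a convex combination of two existing minority rows, it always stays inside the range of the original minority data.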
SMOTE-Tomek Links
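SMOTE-Tomek Links combines SMOTE with Tomek Links: SMOTE first generates synthetic samples for the minority classes, and Tomek Links then cleans the data by removing samples that form Tomek links. A Tomek link is a pair of samples from different classes that are each other's nearest neighbors; such pairs lie on or near the class boundary. With sampling_strategy='majority', as in the code below, only samples of the majority class that belong to a Tomek link are removed.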
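The check for mutual nearest neighbors with different labels can be sketched by brute force in NumPy. `tomek_links` is a hypothetical helper for illustration only (the article relies on imblearn's TomekLinks), and ties in nearest-neighbor distance are broken by the lowest index here.

```python
import numpy as np


def tomek_links(X: np.ndarray, y: np.ndarray):
    """Return index pairs (i, j), i < j, that form Tomek links."""
    dists = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)
    nearest = dists.argmin(axis=1)   # index of each sample's nearest neighbor
    links = []
    for i, j in enumerate(nearest):
        # Mutual nearest neighbors that carry different labels
        if i < j and nearest[j] == i and y[i] != y[j]:
            links.append((i, int(j)))
    return links


# Toy 1-D data for illustration only
X = np.array([[0.0], [1.0], [5.0], [5.5], [10.0]])
y = np.array([0, 0, 0, 1, 1])
print(tomek_links(X, y))  # [(2, 3)]
```

Samples 0 and 1 are mutual nearest neighbors too, but they share a label, so they do not form a link.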
from imblearn.combine import SMOTETomek
from imblearn.under_sampling import TomekLinks

# Define SMOTE-Tomek Links
resample = SMOTETomek(tomek=TomekLinks(sampling_strategy='majority'))
X, y = resample.fit_resample(X, y)
A new count plot of the resampled target should now show an (almost) equal number of samples for each class; the Tomek-link step can leave the class it cleans slightly smaller than the rest.
Splitting Data into Train and Test Data
The dataset is split into a training set and a validation set. The training set is used to fit the model, and the validation set is used to check how well the model performs on new, unseen data.
from sklearn.model_selection import train_test_split

train_x, val_x, train_y, val_y = train_test_split(X, y, test_size=0.2)
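Two refinements are worth considering here; both are suggestions rather than what the article did. First, `stratify=y` keeps the class proportions the same in both splits, which matters when some classes are rare. Second, because the SMOTE-Tomek step above runs before the split, a synthetic sample can land in the validation set while the original points it was interpolated from sit in the training set, which tends to make validation scores optimistic; splitting first and calling `resample.fit_resample(train_x, train_y)` on the training part only avoids this. The toy example below (made-up data) shows stratification only:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy imbalanced data for illustration only
X_toy = np.arange(200).reshape(100, 2)
y_toy = np.array([0] * 80 + [1] * 15 + [2] * 5)

train_x, val_x, train_y, val_y = train_test_split(
    X_toy, y_toy,
    test_size=0.2,
    stratify=y_toy,    # keep the class proportions equal in both splits
    random_state=42,   # make the split reproducible
)
print(np.bincount(train_y), np.bincount(val_y))  # [64 12  4] [16  3  1]
```

Both splits keep the 80/15/5 ratio; without `stratify`, the single-digit classes could end up with very few or no validation samples.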
Multiclass Classification using Random Forest Classifier
A random forest consists of a large number of individual decision trees that operate as an ensemble. Each tree outputs a class prediction, and the class with the most votes becomes the model's prediction.
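In scikit-learn specifically, RandomForestClassifier combines its trees by averaging their predicted class probabilities (soft voting) rather than counting one hard vote per tree. The sketch below, on toy data from `make_classification` (not the article's dataset), checks that a fitted forest's `predict_proba` equals the mean of its trees' `predict_proba`, and computes a hard majority vote for comparison; the two usually agree but can differ on borderline samples.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Toy 3-class data for illustration only
X, y = make_classification(n_samples=300, n_features=6, n_informative=4,
                           n_classes=3, random_state=0)
rfc = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# Soft voting: average the class probabilities of all trees
tree_proba = np.stack([tree.predict_proba(X) for tree in rfc.estimators_])
mean_proba = tree_proba.mean(axis=0)             # shape (n_samples, n_classes)
print(np.allclose(mean_proba, rfc.predict_proba(X)))  # True

# Hard voting: each tree casts one vote for its most probable class
votes = tree_proba.argmax(axis=2)                # shape (n_trees, n_samples)
vote_counts = np.stack([np.bincount(v, minlength=len(rfc.classes_)) for v in votes.T])
hard_pred = rfc.classes_[vote_counts.argmax(axis=1)]
print((hard_pred == rfc.predict(X)).mean())      # agreement rate; can be below 1.0
```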
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

rfc = RandomForestClassifier(n_estimators=200)   # ensemble of 200 trees
rfc.fit(train_x, train_y)
rfc_predict = rfc.predict(val_x)
print('Accuracy score:', accuracy_score(val_y, rfc_predict))
The Random Forest Classifier performed well on the validation set, with an accuracy score of 88.62%. However, accuracy alone is not enough to judge a classifier. In the next and final part, we look at several evaluation metrics and compute them with scikit-learn's built-in functions.
Evaluation Metrics
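- Accuracy: the proportion of all predictions that were correct.
- Precision: of the samples predicted as a given class, the proportion that actually belong to that class.
- Sensitivity or Recall: of the samples that actually belong to a given class, the proportion that were correctly predicted as that class.
- F1 Score: the harmonic mean of precision and recall:
$$F_1 = 2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}$$
In a multiclass problem, precision, recall and F1 are computed for each class separately (that class versus all others). classification_report prints these per-class values together with their macro (unweighted) and weighted averages.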
from sklearn.metrics import classification_report

print(classification_report(val_y, rfc_predict))
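To see where the per-class numbers in the report come from, they can be computed by hand from the confusion matrix. The toy labels below are made up for illustration (they are not the article's results), and the hand-computed values are checked against scikit-learn's `precision_recall_fscore_support`, which returns the same per-class quantities that `classification_report` prints.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Toy labels for illustration only
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2, 1])
y_pred = np.array([0, 0, 1, 1, 1, 2, 2, 0, 2, 0])

cm = confusion_matrix(y_true, y_pred)  # rows: true class, columns: predicted class
tp = np.diag(cm)                       # correctly predicted samples per class
precision = tp / cm.sum(axis=0)        # correct / all predicted as the class
recall = tp / cm.sum(axis=1)           # correct / all truly in the class
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm.sum()

# per class: precision 1/2, 2/3, 1; recall 2/3, 2/3, 3/4; F1 4/7, 2/3, 6/7
p, r, f, support = precision_recall_fscore_support(y_true, y_pred)
print(np.allclose(precision, p), np.allclose(recall, r), np.allclose(f1, f))  # True True True
print(accuracy, f1.mean())             # accuracy and macro-averaged F1
```

`f1.mean()` is the unweighted "macro avg" F1 shown by `classification_report`. A class that is never predicted would give a 0/0 precision in the hand computation; scikit-learn handles that case through its `zero_division` parameter.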