Using Deep Q-Learning in the Classification of an Imbalanced Dataset

Moustafa Ayoub
Published in Ixor · Dec 10, 2019

An imbalanced dataset is one of the most common problems faced when applying machine learning to a real-world application. Multiple approaches have been used to address it, at either the algorithm level or the data level. At the algorithm level, class weights can be introduced into the cost function, assigning a higher misclassification cost to the minority class. At the data level, the dataset is usually resampled, either by upsampling the minority class or by downsampling the majority class. In this post, we will see how the concept behind deep Q-learning can be used to tackle the problem of an imbalanced dataset.

Dataset:

At IxorThink, the dataset being used is a set of medical scans for the detection of ductal carcinoma in situ (DCIS), the presence of abnormal cells inside a milk duct in the breast, which is considered the earliest form of breast cancer.

Samples containing such cells appear drastically less often than samples of healthy tissue. This imbalance is a problem when training a machine learning algorithm: it becomes difficult for the algorithm to detect minority-class samples, and missing them carries significant costs.

Formalizing the Process:

We will not dive deep into how deep Q-learning works, but I would recommend the following post for a more detailed explanation of the theory behind a deep Q-network:

https://medium.com/@jonathan_hui/rl-dqn-deep-q-network-e207751f7ae4

If a DQN is applied to a classification problem, the problem can be seen as a guessing game: the agent receives a positive reward if its guess is correct and a negative reward if it is not. Throughout the training process, the agent learns a policy that maximizes the cumulative reward, and thus the overall number of correctly classified samples.

Since the DQN follows a Markov decision process, the interaction between the agent and the environment has to be cast as a sequential process. This is an uncommon framing when the images to be classified are unrelated to each other (we will get back to this when choosing the discount factor).

Reinforcement Learning Interaction In Image Classification

Now let’s see how deep Q-learning can be applied to imbalanced image classification in practice. A sketch of the full interaction loop follows the list below.

1- State: The state s of the environment is the training sample, which in our case is an image.

2- Action: The action a of the agent is the predicted label of the training sample. For simplicity, we will consider only a binary problem, where the agent chooses from the set of actions A = {0,1}: 0 is the majority class and 1 is the minority class.

3- Reward: The reward r is the feedback the environment gives the agent so it can measure its success in classifying the state s correctly. To help the agent learn the optimal classification policy on an imbalanced dataset, the agent receives a higher reward when the input state belongs to the minority class and a lower reward when it belongs to the majority class.

4- Discount factor: The factor 𝛾 ∈ [0,1] weighs the importance of future rewards. Since we are working on image classification, consecutive samples are not correlated and each image needs to be classified correctly on its own, which is why a low value of 𝛾 is the better choice.

5- Exploration rate: The rate ε ∈ [0,1]. When it is set to 1, actions are taken purely by exploration; when it is set to 0, actions purely exploit the agent’s knowledge.

6- Episode: The episode e ends when all the samples in the training set have been presented to the agent for classification.
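Putting these pieces together, the interaction loop might look like the following minimal sketch. Note that `env`, `agent`, and their methods (`reset`, `step`, `predict`, `remember`) are hypothetical interfaces, not code from the actual project; the reward and termination details are filled in over the next sections.

```python
import numpy as np

def run_episode(env, agent, epsilon):
    """One episode: present the training images to the agent one by one."""
    state = env.reset()                    # first image of the (shuffled) training set
    done = False
    total_reward = 0.0
    while not done:
        if np.random.rand() < epsilon:     # explore: guess a random label
            action = np.random.randint(2)  # binary problem: A = {0, 1}
        else:                              # exploit: pick the label with the highest Q-value
            action = int(np.argmax(agent.predict(state)))
        next_state, reward, done = env.step(action)
        agent.remember(state, action, reward, next_state, done)
        total_reward += reward
        state = next_state
    return total_reward
```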

Set the Reward Function:

As mentioned previously, the reward system should be built on the idea that correctly classifying an object of the minority class is the difficult part of an imbalanced set. It therefore assigns higher rewards, and higher punishments, to the classification or misclassification of objects belonging to the minority class.

A rule of thumb for choosing the reward values is to use the ratio of the number of majority-class samples to minority-class samples, which we will refer to as 𝞺.

Reward system of the Imbalanced Image Classification Problem
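As a concrete sketch of this reward scheme: following the reward design in [1], minority-class guesses can earn or cost the full ±1, while majority-class guesses are scaled down by 1/𝞺, so rare samples weigh 𝞺 times more. The exact values are a design choice, and this is one plausible instantiation:

```python
def reward(action, label, rho):
    """Reward for guessing `action` when the true class is `label`.

    rho = (# majority samples) / (# minority samples), so rho >= 1.
    Minority samples (label == 1) earn or cost the full ±1, while
    majority samples are scaled down by 1/rho.
    """
    if label == 1:                              # minority class
        return 1.0 if action == label else -1.0
    return 1.0 / rho if action == label else -1.0 / rho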

Set the Memory:

A key to solving the imbalanced-dataset problem is to split the replay memory equally into sub-memories, one per class. Each sub-memory is then only appended to by samples of its corresponding class, so the minority-class samples, which rarely appear, are never overwritten by majority-class ones.

This way, we guarantee that when we select random samples from the memory to train the agent, the samples are balanced between the different classes of the training set.

Memory Appending Process of the Imbalanced Image Classification Problem
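Here is a minimal sketch of such a class-partitioned replay memory (the class name and capacity are illustrative placeholders, not the actual implementation):

```python
import random
from collections import deque

class BalancedReplayMemory:
    """One bounded FIFO buffer per class, so majority-class transitions
    can never overwrite the rare minority-class ones."""

    def __init__(self, num_classes=2, capacity_per_class=5000):
        self.buffers = [deque(maxlen=capacity_per_class) for _ in range(num_classes)]

    def remember(self, state, action, reward, next_state, done, label):
        # Append to the sub-memory of the sample's true class
        # (the true label is known during training).
        self.buffers[label].append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Draw an equal share of the batch from every class's sub-memory.
        per_class = batch_size // len(self.buffers)
        batch = []
        for buf in self.buffers:
            batch += random.sample(buf, min(per_class, len(buf)))
        random.shuffle(batch)
        return batch
```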

Episode Termination:

The episode e is terminated when the agent has misclassified more than a certain number x of minority-class samples, or has successfully been trained on all the samples of the training set.
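Inside the environment, this termination rule reduces to a simple check, where x is a tunable budget of allowed minority-class mistakes:

```python
def is_done(minority_errors, sample_index, x, num_samples):
    """End the episode after more than x minority-class misclassifications,
    or once the whole training set has been presented."""
    return minority_errors > x or sample_index >= num_samples
```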

Set the DQN:

The value function is the expected cumulative reward from being at a particular state in the environment, and it depends on the policy the agent relies on to pick the action to perform.

The Q-value function is a step beyond the plain value function, since it considers both the state and the action; it likewise depends on the policy the agent relies on to pick its actions.

The optimal Q-value function is denoted by Q∗(s, a) and approximated as Q∗(s, a) ≈ Q(s, a, θ). To approximate Q∗, deep Q-learning introduces the new term θ, which represents all of the weights of the deep neural network used for the approximation.
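For reference, the optimal Q-value function satisfies the Bellman optimality equation, which is what the network with weights θ is trained to approximate:

```latex
Q^{*}(s, a) = \mathbb{E}\left[\, r + \gamma \max_{a'} Q^{*}(s', a') \,\middle|\, s, a \right]
```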

At each time step, the agent randomly recalls some of the stored events, which are used to train the DQN. Each event is the combination of a state-action pair at a given time step, its reward, and the resulting next state.

The architecture of the DQN depends on the complexity of the dataset. A linear activation function is chosen for the output layer, since the DQN acts as the Q-function estimator: for any given state and action it outputs the expected cumulative future reward from the current state onward. The Q-function is therefore a real value, not constrained to a range the way activation functions like softmax are. The loss function is the mean squared error, and the Adam optimizer is used.
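To make this concrete, here is a small Keras sketch of such a network and its replay training step. The convolutional layer sizes and learning rate are placeholders rather than the actual architecture, and the low default discount factor follows the choice argued above:

```python
import numpy as np
from tensorflow.keras import layers, models, optimizers

def build_dqn(input_shape, num_actions=2):
    """A small convolutional Q-network; layer sizes are placeholders."""
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(64, activation="relu"),
        # Linear output: one unconstrained, real-valued Q-value per action.
        layers.Dense(num_actions, activation="linear"),
    ])
    model.compile(loss="mse", optimizer=optimizers.Adam(learning_rate=1e-4))
    return model

def replay_step(model, batch, gamma=0.1):
    """One training step: regress Q(s, a) toward r + gamma * max_a' Q(s', a')."""
    states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
    targets = model.predict(states, verbose=0)
    next_q = model.predict(next_states, verbose=0).max(axis=1)
    # Terminal transitions have no future reward, hence the (1 - dones) mask.
    targets[np.arange(len(batch)), actions] = rewards + gamma * next_q * (1.0 - dones)
    model.fit(states, targets, verbose=0)
```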

Conclusion

A new path for image classification has been used to overcome the problem of an imbalanced dataset, instead of the traditional approaches ranging from data resampling to adding class weights to the cost function. On a highly imbalanced use case, it achieved promising results, and hopefully in the future we will add more features such as multi-class classification and experiment with improved deep reinforcement learning algorithms.

References

[1] Enlu Lin, Qiong Chen, and Xiaoming Qi. “Deep Reinforcement Learning for Imbalanced Classification”. arXiv preprint arXiv:1901.01379v1 (2019).

At IxorThink we are constantly trying to improve our methods to create state-of-the-art solutions. As a software company, we can provide stable and fully developed solutions. Feel free to contact us for more information.
