Mastering CartPole with Enhanced Deep Q-Networks: An In-depth Guide to Equivariant Models

FreaxRuby
3 min readJan 10, 2024

Reinforcement Learning (RL) stands at the forefront of developing intelligent systems capable of learning complex behaviors. My recent foray into this field targets the CartPole problem, a classic RL challenge. Here, I’ll delve deeper into how I employed equivariant models to supercharge the Deep Q-Network (DQN), effectively doubling the learning efficiency. This detailed exploration is designed to cater to enthusiasts at all levels in the RL community.

The CartPole Challenge: Balancing Act in RL

Imagine a simple cart moving along a frictionless track with a pole attached to it. The objective? Prevent the pole from falling over by moving the cart left or right. This seemingly simple task, known as the CartPole problem, encapsulates the core challenges of RL: decision-making under uncertainty and learning from interactions with an environment.

Understanding the State Space

The state of the CartPole system comprises four variables:

1. Cart Position: Horizontal position of the cart on the track.

2. Cart Velocity: Speed and direction of the cart’s movement.

3. Pole Angle: The angle between the pole and the vertical line.

4. Pole Velocity at the Tip: Rate of change of the pole’s angle.

Enhancing DQN with Equivariance: Two-Fold Strategy

Equivariance in RL implies that the model should adapt its outputs in a predictable manner to transformations applied to its inputs. In CartPole, this transformation is the mirroring of the pole’s angle and velocity.

Strategy 1: CustomSymmetricQNet for Network-Level Equivariance

‘CustomSymmetricQNet’ is a tailored neural network that inherently understands the concept of symmetry in the CartPole environment.

The Core: ReflectionSymmetryLayer

This custom layer is the linchpin of the network, designed to reflect specific input features — the pole’s angle and velocity.

class ReflectionSymmetryLayer(nn.Module):
...
def forward(self, x):
reflected_x = x.clone()
reflected_x[:, -2:] = -reflected_x[:, -2:] # Inverting the last two features
return reflected_x
  • Functionality: It clones the input and inverts the last two features. This process effectively creates a mirrored state, helping the network recognize and learn from symmetrical scenarios.

Advantages

  • Efficient Learning: Processes both original and mirrored states, thus gaining double the insights from each interaction.
  • Robust Decision-Making: Equips the network with a more holistic understanding, crucial for dynamic and unpredictable environments.

Strategy 2: SymmetricDQN for Data-Level Equivariance

Here, the focus is on augmenting the training data rather than modifying the network architecture.

Execution

  • Data Mirroring: For every state experienced by the agent, a mirrored state is generated by inverting the pole’s angle and velocity.
  • Training Expansion: The standard `Qnet` is trained on this expanded dataset, covering a broader spectrum of scenarios.
class SymmetricDQN:
...
def get_symmetric_states_actions(self, states, actions):
mirrored_states = self.mirror_states(states) # Generating mirrored states
mirrored_actions = 1 - actions # Adjusting actions accordingly
return mirrored_states, mirrored_actions
...
  • Outcome: This method effectively doubles the variety of training scenarios, enabling the network to learn and adapt more rapidly.

Training and Evaluation

The training involved running multiple episodes, where the agent interacts with the environment, making decisions based on its current policy. By employing CustomSymmetricQNet and SymmetricDQN, the agent exhibited a marked improvement in learning efficiency.

Results

Both methods demonstrated a twofold increase in training speed compared to a traditional DQN. This was reflected in:

  • Faster Policy Convergence: The agent reached optimal decision-making strategies in fewer episodes.
  • Enhanced Pole Balancing Performance: Demonstrated greater proficiency in keeping the pole balanced for extended periods.

--

--