Q-Learning in Rust

Gautam Kumar
5 min read · Dec 1, 2023


Maze-solver example: the agent (A) must avoid obstacles (O) and reach the bottom-right corner of the grid.

Q-Learning has the following components.

  1. Objective: The goal of Q-Learning is to learn a policy that tells an agent which action to take in any given state of the environment.
  2. Environment and Agent: In Q-Learning, an agent interacts with an environment in discrete time steps. At each time step, the agent chooses an action, and the environment responds with a new state and a reward.
  3. Q-Table: The Q-table provides a lookup for the best action to take in each state. The table’s rows represent all possible states, the columns represent the possible actions, and the values in the table (Q-values) represent the expected future reward of taking a given action in a given state.
  4. Learning Process: The agent learns by updating Q-values in the Q-table according to the update rule:

Q_new(state, action) = Q(state, action) + α × (reward + γ × max_a′ Q(next_state, a′) − Q(state, action))

α is the learning rate (how much new information overrides old information), and γ is the discount factor (importance of future rewards).

  5. Exploration vs. Exploitation: Q-Learning involves a balance between exploration (trying new actions) and exploitation (using known information). This is often managed by strategies like ε-greedy, where the agent explores randomly with probability ε and exploits known information with probability 1−ε (see the Rust sketch after this list).

Over time, as the agent explores the environment and updates the Q-table, it starts to develop a policy by choosing the action with the highest Q-value for each state.
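Before the full program, here is a minimal sketch of the two core pieces: the Q-value update and ε-greedy action selection. It assumes the Q-table is a HashMap keyed by (state, action) pairs and that every pair has been initialized, as in the full example below; the function names q_update and epsilon_greedy are just illustrative.

use std::collections::HashMap;
use rand::Rng;
use rand::seq::SliceRandom;

// Same Action enum as in the full example below.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy)]
enum Action {
    Up,
    Down,
    Left,
    Right,
}

type State = (usize, usize);

// One Q-learning update for a single (state, action, reward, next_state) step.
fn q_update(
    q_table: &mut HashMap<(State, Action), f64>,
    state: State,
    action: Action,
    reward: f64,
    next_state: State,
    actions: &[Action],
    alpha: f64,
    gamma: f64,
) {
    // max over all actions a' of Q(next_state, a')
    let next_max = actions
        .iter()
        .map(|&a| *q_table.get(&(next_state, a)).unwrap_or(&0.0))
        .fold(f64::MIN, f64::max);
    let q = q_table.entry((state, action)).or_insert(0.0);
    *q += alpha * (reward + gamma * next_max - *q);
}

// ε-greedy selection: explore with probability ε, otherwise pick the best-known action.
fn epsilon_greedy(
    q_table: &HashMap<(State, Action), f64>,
    state: State,
    actions: &[Action],
    epsilon: f64,
    rng: &mut impl Rng,
) -> Action {
    if rng.gen::<f64>() < epsilon {
        *actions.choose(rng).unwrap()
    } else {
        *actions
            .iter()
            .max_by(|&&a1, &&a2| {
                q_table[&(state, a1)]
                    .partial_cmp(&q_table[&(state, a2)])
                    .unwrap()
            })
            .unwrap()
    }
}

In the full program below, the same logic is inlined in the training loop rather than split into helper functions.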

Let’s take a real-world example and work through it.
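Concretely, the environment below is a 5×5 grid: the agent starts at (0, 0), there are obstacles at (1, 1), (2, 2), and (3, 3), and the goal is at (4, 4). Reaching the goal gives a reward of +100, stepping on an obstacle gives −50, and every other move gives −1; obstacles penalize the agent but do not block movement, and an episode ends only when the goal is reached. The program uses the rand crate, so Cargo.toml needs it as a dependency (for example, rand = "0.8").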

use std::collections::HashMap;
use rand::Rng;
use rand::seq::SliceRandom;

// Environment settings
const GRID_SIZE: usize = 5;

// Hyperparameters
const ALPHA: f64 = 0.1;   // learning rate
const GAMMA: f64 = 0.9;   // discount factor
const EPSILON: f64 = 0.1; // exploration rate
const EPISODES: usize = 1000; // number of training episodes

#[derive(Hash, Eq, PartialEq, Debug, Clone, Copy)]
enum Action {
    Up,
    Down,
    Left,
    Right,
}

fn main() {
    let mut rng = rand::thread_rng();

    // States and actions
    let states: Vec<(usize, usize)> = (0..GRID_SIZE)
        .flat_map(|x| (0..GRID_SIZE).map(move |y| (x, y)))
        .collect();
    let actions = vec![Action::Up, Action::Down, Action::Left, Action::Right];

    // Example obstacles and goal
    let obstacles = vec![(1, 1), (2, 2), (3, 3)];
    let goal = (4, 4);

    // Initialize Q-table with all Q-values set to 0.0
    let mut q_table: HashMap<((usize, usize), Action), f64> = HashMap::new();
    for &state in &states {
        for &action in &actions {
            q_table.insert((state, action), 0.0);
        }
    }

    // Training loop
    for _ in 0..EPISODES {
        let mut state = (0, 0); // Start state
        while state != goal {
            // Select action: exploration vs. exploitation (ε-greedy)
            let action = if rng.gen::<f64>() < EPSILON {
                *actions.choose(&mut rng).unwrap()
            } else {
                *actions
                    .iter()
                    .max_by(|&&a1, &&a2| {
                        q_table[&(state, a1)]
                            .partial_cmp(&q_table[&(state, a2)])
                            .unwrap()
                    })
                    .unwrap()
            };

            // Get next state
            let next_state = get_next_state(state, action);

            // Print current and next state for debugging
            //println!("Current State: {:?}, Next State: {:?}", state, next_state);

            // Reward function
            let reward = if next_state == goal {
                100.0
            } else if obstacles.contains(&next_state) {
                -50.0
            } else {
                -1.0
            };

            // Q-learning update
            let next_max = actions
                .iter()
                .map(|&a| q_table[&(next_state, a)])
                .fold(f64::MIN, f64::max);

            let q_value = q_table.get_mut(&(state, action)).unwrap();
            *q_value += ALPHA * (reward + GAMMA * next_max - *q_value);

            state = next_state;
        }
    }

    // Display results in a matrix format
    println!("Learned Q-values:");
    for x in 0..GRID_SIZE {
        for y in 0..GRID_SIZE {
            print!("State ({}, {}): ", x, y);
            for action in &actions {
                let q_value = q_table.get(&((x, y), *action)).unwrap_or(&0.0);
                print!("{:?}: {:.2}, ", action, q_value);
            }
            println!(); // New line after each state
        }
        println!(); // Extra line to separate rows of the grid
    }
}

// Helper function to get the next state, clamped to the grid bounds
fn get_next_state(state: (usize, usize), action: Action) -> (usize, usize) {
    let (mut x, mut y) = state;
    match action {
        Action::Up => x = x.saturating_sub(1),
        Action::Down => x = usize::min(x + 1, GRID_SIZE - 1),
        Action::Left => y = y.saturating_sub(1),
        Action::Right => y = usize::min(y + 1, GRID_SIZE - 1),
    }
    (x, y)
}

cargo run

Learned Q-values:
State (0, 0): Up: 35.47, Down: 42.61, Left: 35.07, Right: 4.47,
State (0, 1): Up: -1.98, Down: -5.00, Left: -2.04, Right: 12.69,
State (0, 2): Up: -1.40, Down: -1.34, Left: -1.44, Right: 24.98,
State (0, 3): Up: 2.29, Down: 40.60, Left: -0.89, Right: -0.65,
State (0, 4): Up: -0.49, Down: 4.49, Left: -0.51, Right: -0.58,

State (1, 0): Up: 33.54, Down: 48.46, Left: 41.47, Right: -4.49,
State (1, 1): Up: -0.54, Down: 54.41, Left: 3.96, Right: -0.33,
State (1, 2): Up: -0.99, Down: -9.51, Left: -20.51, Right: 5.94,
State (1, 3): Up: -0.52, Down: 0.23, Left: -0.46, Right: 57.38,
State (1, 4): Up: -0.38, Down: 73.20, Left: 3.92, Right: 7.82,

State (2, 0): Up: 38.65, Down: 47.31, Left: 39.26, Right: 54.95,
State (2, 1): Up: -8.04, Down: 62.17, Left: 44.84, Right: -15.45,
State (2, 2): Up: -0.11, Down: 49.88, Left: -0.25, Right: -0.21,
State (2, 3): Up: -0.22, Down: -5.00, Left: -5.60, Right: 8.13,
State (2, 4): Up: 5.86, Down: 87.06, Left: -0.21, Right: -0.20,

State (3, 0): Up: 4.26, Down: -0.63, Left: 16.29, Right: 61.70,
State (3, 1): Up: 52.81, Down: 70.19, Left: 43.77, Right: 45.81,
State (3, 2): Up: -5.12, Down: 15.07, Left: 62.12, Right: -1.82,
State (3, 3): Up: 0.00, Down: 79.23, Left: 5.39, Right: 8.01,
State (3, 4): Up: 3.20, Down: 99.70, Left: -9.43, Right: 8.78,

State (4, 0): Up: -0.43, Down: -0.39, Left: 2.71, Right: 65.93,
State (4, 1): Up: 57.61, Down: 66.17, Left: 47.07, Right: 79.10,
State (4, 2): Up: 46.66, Down: 72.77, Left: 61.31, Right: 89.00,
State (4, 3): Up: 5.49, Down: 82.61, Left: 71.18, Right: 100.00,
State (4, 4): Up: 0.00, Down: 0.00, Left: 0.00, Right: 0.00,
The agent has learned the path to the goal: in each state, the action that leads toward (4, 4) now has the highest Q-value.
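To turn the learned Q-values into an actual path, you can greedily follow the highest-valued action from the start state. The snippet below is a small sketch (not part of the program above) that could be dropped into main after the training loop; it reuses q_table, actions, goal, and get_next_state, and caps the number of steps in case the table is under-trained.

// Greedy rollout: from the start, always take the action with the highest Q-value.
let mut state = (0, 0);
let mut path = vec![state];
for _ in 0..GRID_SIZE * GRID_SIZE {
    if state == goal {
        break;
    }
    let best_action = *actions
        .iter()
        .max_by(|&&a1, &&a2| {
            q_table[&(state, a1)]
                .partial_cmp(&q_table[&(state, a2)])
                .unwrap()
        })
        .unwrap();
    state = get_next_state(state, best_action);
    path.push(state);
}
println!("Greedy path: {:?}", path);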

There are a few fundamental issues with Q-learning:

  1. It's hard to parallelize, because each update depends on the Q-value written in the previous step.
  2. It only works for a closed system, i.e., a small, discrete set of states and actions that can be enumerated in a table (here, 5 × 5 states × 4 actions = 100 entries).

Code for interactive Q-learning: https://github.com/goswamig/rust-ai

My previous blogs on Rust: https://medium.com/@gautamgoswami/q-learning-in-rust-c16cacff1829

