Introduction
Reinforcement Learning (RL) has become one of the most talked-about subfields of Artificial Intelligence, not only in academia but increasingly in engineering practice. Whenever autonomous systems adapt in uncertain environments — whether a robot learns to walk, a fleet of taxis self-organizes, or a data center balances workloads — RL concepts play a role. Yet many software engineers view RL as an impenetrable forest of mathematics, neural networks, and jargon.
This article proposes a different approach: we shrink the world to the size of a tiny arcade game. In this minimalist world, the essential RL principles appear clearly, free of the distractions of industrial complexity. By the end, you will not only understand the famous Q-learning algorithm but also appreciate why reinforcement learning matters for software engineers building the next generation of adaptive systems.
The Arcade World: A Minimal Environment
The game is called Catch. Imagine a grid of width W and height H. A ball appears at the top row and falls straight down, one step at a time. At the bottom, a paddle moves horizontally: left, right, or stay. If the paddle aligns with the ball when it hits the ground, the player scores; otherwise, they fail.
+-------------------+ y = H - 1
| |
| (ball) |
| |
| |
| |
| |
| [paddle] | y = 0
+-------------------+
The rules are minimal, the dynamics predictable. And yet, this world suffices to illustrate everything we need: states, actions, rewards, exploration, exploitation, and policy learning.
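To make these rules concrete, here is a minimal sketch of the step logic in Python, essentially a stripped-down version of the CatchEnv class in the source code at the end of this article:

# Minimal sketch of the Catch dynamics for a grid of width W.
# A trimmed version of CatchEnv.step() from the appended source.
def catch_step(ball_x, ball_y, paddle_x, action, W=7):
    """action in {-1, 0, +1}: move left, stay, move right."""
    paddle_x = max(0, min(W - 1, paddle_x + action))  # clamp paddle to the grid
    ball_y -= 1                                       # ball falls one row
    if ball_y < 0:                                    # ball crossed the bottom row
        reward = 1 if paddle_x == ball_x else -1
        return (ball_x, 0, paddle_x), reward, True    # episode ends
    return (ball_x, ball_y, paddle_x), 0, False       # episode continues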
The Feedback Loop: Agent and Environment
Reinforcement learning is not about training with labeled data or finding clusters. It is about interaction. An agent repeatedly acts in an environment and is judged by the consequences. The cycle looks like this:
        ┌─────────────┐
        │             │
  ┌────►│    Agent    │
  │     │             │
  │     └──────┬──────┘
  │            │ action a_t
  │            ▼
  │     ┌─────────────┐
  │     │             │
  │     │ Environment │
  │     │             │
  │     └──────┬──────┘
  │            │ reward r_t, next state s_{t+1}
  └────────────┘
In Catch:
• State (s): the ball’s x,y position and the paddle’s x position.
• Action (a): move left, stay, move right.
• Reward (r): +1 if caught, −1 if missed, 0 otherwise.
This loop repeats until the ball reaches the bottom. Then the episode ends, and the next begins.
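In code, this loop is only a few lines. The sketch below uses the CatchEnv and QAgent classes defined in the source code at the end of the article, so the interface (reset, step, policy, update) is exactly what you will see there:

# One episode of agent-environment interaction (classes from the appended source).
env = CatchEnv(width=7, height=7)      # defined in the source code below
agent = QAgent(width=7, height=7)      # defined in the source code below
s = env.reset()
done = False
while not done:
    a = agent.policy(s)                  # choose action a_t in state s_t
    s_next, r, done = env.step(a)        # environment returns r_t and s_{t+1}
    agent.update(s, a, r, s_next, done)  # learn from the transition
    s = s_next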
Formalizing the Challenge: The MDP
Professionals like rigor. RL is formally described as a Markov Decision Process (MDP). An MDP is defined by:
• A set of states S.
• A set of actions A.
• A transition function P(s'|s,a).
• A reward function R(s,a).
• A discount factor γ ∈ [0,1].
The Markov property states that the next state depends only on the current state and action, not on the full history. In Catch, this property holds: the ball’s and paddle’s current positions fully determine the future.
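For Catch, this MDP is tiny, which is exactly why a plain lookup table suffices. A back-of-the-envelope sketch, assuming the 7×7 grid used in the source code below:

# Size of the Catch MDP for a 7x7 grid (the configuration used in the source code below).
W, H = 7, 7
num_states = W * H * W           # combinations of (ball_x, ball_y, paddle_x) = 343
num_actions = 3                  # left, stay, right
print(num_states * num_actions)  # 1029 Q-values -- easily stored in a table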
The Dilemma: Exploration vs. Exploitation
Engineers face trade-offs daily: should we refactor for long-term health or ship features now? RL encapsulates a similar tension: explore new actions or exploit known good actions?
ASCII sketch of the dilemma:
Exploration: try something new ──► possible higher future reward
Exploitation: stick to best known ──► reliable immediate reward
The standard technique is ε-greedy:
if random() < ε:
    choose random action        # exploration
else:
    choose argmax_a Q(s,a)      # exploitation
At the start of training, ε is high (lots of exploration). Over episodes, ε decays, allowing convergence to exploitation. The decay is usually exponential or linear, as illustrated below:
ε
^
|\
| \
| \
| \ Exploration ↓
| \__________________ Exploitation ↑
+-----------------------> Episodes
This curve reminds us that exploration is front-loaded: in early episodes, randomness dominates, helping the agent discover options. Later, exploration dwindles, stabilizing behavior into exploitation of the learned policy.
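In code, the decay is a single line executed after each episode. A minimal sketch of the exponential schedule, matching the eps_decay parameter used in the source code below:

# Exponential epsilon decay, as performed by QAgent.decay_epsilon() in the appended source.
eps, eps_end, eps_decay = 1.0, 0.05, 0.995
for episode in range(2000):
    # ... run one epsilon-greedy episode here ...
    eps = max(eps_end, eps * eps_decay)   # decay after every episode
# After roughly 600 episodes, 0.995**600 is about 0.05: exploration has largely faded.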
The Central Idea: Action-Value Functions
The agent must answer: How good is it to take action a in state s? This is quantified by the Q-function:
Q(s,a) = expected cumulative reward from state s
taking action a, then following the best policy
If Q(s, right) > Q(s, left), the agent knows moving right is more promising.
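In the tabular setting, the Q-function is literally a lookup table. A minimal sketch with illustrative numbers, using the same state -> {action: value} layout as the QAgent class below:

# Q stored as a table: state -> {action: value}, as in the QAgent class below.
from collections import defaultdict
Q = defaultdict(lambda: {-1: 0.0, 0: 0.0, 1: 0.0})

s = (3, 2, 5)                          # (ball_x, ball_y, paddle_x)
Q[s][-1] = 0.1                         # illustrative values, not learned ones
Q[s][+1] = 0.7
best_action = max(Q[s], key=Q[s].get)  # -> +1, i.e. move right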
The Q-Learning Update
Q-learning is an off-policy algorithm: it learns the optimal Q-function even while exploring. The update rule:
Q(s,a) ← Q(s,a) + α [ r + γ * max_{a'} Q(s',a') − Q(s,a) ]
ASCII depiction:
┌──────────────────────────────┐
│ Old estimate: Q(s,a) │
└──────────────────────────────┘
│
▼
target = r + γ * max_a' Q(s',a')
│
▼
New Q(s,a) = Old + α * (target - Old)
The learning rate α controls how much new information overrides old beliefs. The discount factor γ weighs immediate rewards against future ones.
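The entire rule fits into one small function. A minimal sketch with a worked numeric example (the numbers are purely illustrative):

# One Q-learning update; alpha and gamma play the roles described above.
def q_update(q_sa, r, max_q_next, alpha=0.2, gamma=0.98, done=False):
    target = r if done else r + gamma * max_q_next
    return q_sa + alpha * (target - q_sa)

# Worked example: old Q(s,a) = 0.0, reward 0, best next-state value 0.5
# target = 0 + 0.98 * 0.5 = 0.49; new Q(s,a) = 0.0 + 0.2 * (0.49 - 0.0) = 0.098
print(q_update(0.0, 0, 0.5))   # ~0.098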
The Learning Curve: From Chaos to Competence
Engineers trust metrics. During training, we record:
• Episode reward: total reward per episode.
• Moving average: smoothed reward trend.
• Evaluation: average greedy performance without exploration.
An ASCII illustration:
Reward
^
| xxxx
| xx xx
| xx xxx
| xxxx xxx
+--------------------------------> Episodes
Early on, rewards oscillate around zero. As learning progresses, the curve climbs toward consistent positive values.
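None of this needs a framework. A minimal sketch of the moving average over the last 50 episodes, mirroring the deque-based bookkeeping in the train() function below:

# Moving average of episode rewards over a sliding window of 50 episodes.
from collections import deque

rewards_per_episode = [-1, -1, 1, -1, 1, 1]   # e.g. collected during training
window = deque(maxlen=50)
moving_avg = []
for r in rewards_per_episode:
    window.append(r)
    moving_avg.append(sum(window) / len(window))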
Seeing is Believing: Visual Replays
Numbers tell part of the story. But showing the agent’s behavior completes the picture. After training, we run the agent greedily and render each frame:
(ball) (ball)
(ball) .
. .
[paddle] [paddle] [paddle]
The paddle moves into position, anticipating the ball’s landing. This visual confirmation reinforces the metrics.
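Even without matplotlib, a frame can be dumped as ASCII. A minimal sketch, assuming the (ball_x, ball_y, paddle_x) state tuple used throughout this article:

# Print one Catch frame as ASCII, top row first (ball = 'o', paddle = '=').
def print_frame(ball_x, ball_y, paddle_x, W=7, H=7):
    for y in range(H - 1, -1, -1):
        row = ""
        for x in range(W):
            if (x, y) == (ball_x, ball_y):
                row += "o"
            elif y == 0 and x == paddle_x:
                row += "="
            else:
                row += "."
        print(row)

print_frame(ball_x=2, ball_y=4, paddle_x=3)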
Why This Matters for Software Engineers
So far, this may look like a toy. But the concepts generalize:
• Adaptive resource allocation: RL can manage virtual machines in a cloud cluster, scaling resources up or down.
• Traffic optimization: signals adjust in real-time, balancing flow.
• Recommendation systems: explore new content but exploit user preferences.
These systems are more complex than Catch, but the feedback loop, the exploration–exploitation trade-off, and the Q-update are the same.
Lessons for Professionals
1. Reward shaping is design: Get the reward wrong, and the agent learns undesirable behavior.
2. Exploration is costly but necessary: Systems may underperform temporarily to discover better policies.
3. Policies emerge rather than being programmed: RL shifts the mindset from deterministic coding to probabilistic optimization.
4. Scaling requires function approximation: Tabular Q-learning suffices for Catch, but deep neural networks (DQN) are needed for larger spaces.
Beyond Tabular Q-Learning: The DQN Leap
Our tabular example does not scale. In modern practice, engineers use Deep Q-Networks (DQNs), where Q(s,a) is approximated by a neural network trained on replayed experiences. The architecture:
State (pixels) ──► ConvNet ──► Fully connected ──► Q-values for actions
This is the method that allowed RL agents to reach human-level performance on Atari games. The essence is unchanged; the function approximator is more powerful.
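In PyTorch, that architecture is only a handful of layers. A minimal sketch, essentially a trimmed version of the ConvDQN class in the pixel-based source code below:

# Minimal Q-network over a stack of 2 pixel frames (trimmed version of ConvDQN below).
import torch
import torch.nn as nn

class TinyDQN(nn.Module):
    def __init__(self, in_channels=2, height=7, width=7, num_actions=3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1), nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * height * width, num_actions),  # one Q-value per action
        )

    def forward(self, x):            # x: (batch, channels, H, W)
        return self.head(self.conv(x))

q_values = TinyDQN()(torch.zeros(1, 2, 7, 7))   # tensor of shape (1, 3)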
Risks and Considerations
Professional deployments must consider:
• Convergence: RL can diverge if hyperparameters are mis-set.
• Safety: Exploration in real systems may cause harm.
• Explainability: Policies learned by Q-functions may be hard to interpret.
Engineers must combine RL with simulation environments, guardrails, and monitoring.
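A guardrail can be as simple as a wrapper that vetoes actions a domain expert has declared unsafe before they ever reach the real system. A purely illustrative sketch (the is_safe predicate and fallback_action are hypothetical placeholders, not part of the source code below):

# Illustrative safety wrapper; is_safe() and fallback_action are hypothetical placeholders.
class GuardedEnv:
    def __init__(self, env, is_safe, fallback_action):
        self.env = env
        self.is_safe = is_safe
        self.fallback_action = fallback_action

    def reset(self):
        return self.env.reset()

    def step(self, action):
        if not self.is_safe(action):       # override unsafe exploration
            action = self.fallback_action
        return self.env.step(action)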
Conclusion: From Arcade to Industry
By observing a paddle learn to catch a ball, we witness reinforcement learning in its purest form. The lesson for software engineers is not the specifics of tabular Q-tables but the architectural mindset: systems can be built to learn from feedback, balancing exploration and exploitation, updating expectations, and improving autonomously.
Today, it’s an arcade toy. Tomorrow, it could be your cloud infrastructure, your production pipeline, or your autonomous vehicle. The cycle is always the same:
observe state → choose action → receive reward → update Q → improve policy
Understanding this cycle equips engineers to participate in the next wave of adaptive software systems.
SOURCE CODE
"NORMAL" VERSION (tabular Q-learning)
#!/usr/bin/env python3
"""
Arcade Catch + Q-learning (tabular) with live training curve and episode replays.
How to run:
    pip install numpy matplotlib
"""
import math
import random
from collections import defaultdict, deque
from typing import Dict, Tuple, List
import numpy as np
import matplotlib.pyplot as plt
# =========================
# Environment: Catch Game
# =========================
class CatchEnv:
"""
A tiny arcade-like game:
- Discrete grid of width W and height H.
- A ball spawns at the top row (y=H-1) at a random x and falls straight down.
- A paddle sits on the bottom row (y=0) and can move LEFT (-1), STAY (0), or RIGHT (+1).
- Episode ends when the ball leaves the grid; reward is +1 if paddle x == ball x, else -1.
State representation: (ball_x, ball_y, paddle_x)
"""
def __init__(self, width: int = 7, height: int = 7, paddle_speed: int = 1, seed: int = 0):
self.W = width
self.H = height
self.paddle_speed = paddle_speed
self.rng = random.Random(seed)
self.reset()
def reset(self) -> Tuple[int, int, int]:
self.ball_x = self.rng.randrange(self.W)
self.ball_y = self.H - 1
self.paddle_x = self.W // 2
return (self.ball_x, self.ball_y, self.paddle_x)
def step(self, action: int) -> Tuple[Tuple[int, int, int], int, bool]:
"""
Apply action in {-1, 0, +1}. Paddle is clamped to [0, W-1].
Ball falls by one row. If ball leaves the grid, episode ends and reward is given.
Returns: (next_state, reward, done)
"""
self.paddle_x = int(np.clip(self.paddle_x + action * self.paddle_speed, 0, self.W - 1))
self.ball_y -= 1
done = False
reward = 0
if self.ball_y < 0:
done = True
reward = 1 if self.paddle_x == self.ball_x else -1
state = (self.ball_x, max(self.ball_y, 0), self.paddle_x)
return state, reward, done
def render_frame(self) -> np.ndarray:
"""
Return a 2D array with ball=2, paddle=1, empty=0 for visualization.
"""
grid = np.zeros((self.H, self.W), dtype=int)
if 0 <= self.ball_y < self.H:
grid[self.ball_y, self.ball_x] = 2
grid[0, self.paddle_x] = 1
return grid
# =========================
# Q-learning Agent
# =========================
class QAgent:
"""
Tabular Q-learning agent with epsilon-greedy exploration.
Q is stored as a dict: state -> {action: value}
"""
def __init__(
self,
width: int,
height: int,
actions: Tuple[int, ...] = (-1, 0, 1),
alpha: float = 0.2,
gamma: float = 0.98,
eps_start: float = 1.0,
eps_end: float = 0.05,
eps_decay: float = 0.995,
seed: int = 0,
):
self.W = width
self.H = height
self.actions = actions
self.alpha = alpha
self.gamma = gamma
self.eps = eps_start
self.eps_end = eps_end
self.eps_decay = eps_decay
self.rng = random.Random(seed)
self.Q: Dict[Tuple[int, int, int], Dict[int, float]] = defaultdict(
lambda: {a: 0.0 for a in self.actions}
)
def policy(self, state: Tuple[int, int, int]) -> int:
"""Epsilon-greedy action selection."""
if self.rng.random() < self.eps:
return self.rng.choice(self.actions)
qvals = self.Q[state]
max_q = max(qvals.values())
best_actions = [a for a, q in qvals.items() if q == max_q]
return self.rng.choice(best_actions)
def greedy_action(self, state: Tuple[int, int, int]) -> int:
"""Greedy (argmax) action selection."""
qvals = self.Q[state]
max_q = max(qvals.values())
best_actions = [a for a, q in qvals.items() if q == max_q]
return self.rng.choice(best_actions)
def update(self, s: Tuple[int, int, int], a: int, r: int, s_next: Tuple[int, int, int], done: bool) -> None:
"""Tabular Q-learning update."""
qsa = self.Q[s][a]
target = r if done else r + self.gamma * max(self.Q[s_next].values())
self.Q[s][a] = qsa + self.alpha * (target - qsa)
def decay_epsilon(self) -> None:
self.eps = max(self.eps_end, self.eps * self.eps_decay)
# =========================
# Training & Evaluation
# =========================
def train(
env: CatchEnv,
agent: QAgent,
episodes: int = 2000,
max_steps: int = 50,
eval_every: int = 200,
eval_episodes: int = 100,
) -> Tuple[List[float], List[float], List[Tuple[int, float]]]:
"""
Train with epsilon-greedy Q-learning. Return (rewards, moving_avg, eval_scores).
eval_scores is a list of (episode_index, avg_greedy_reward).
"""
rewards: List[float] = []
moving_avg: List[float] = []
eval_scores: List[Tuple[int, float]] = []
window = deque(maxlen=50)
for ep in range(1, episodes + 1):
s = env.reset()
total = 0.0
for _ in range(max_steps):
a = agent.policy(s)
s_next, r, done = env.step(a)
agent.update(s, a, r, s_next, done)
s = s_next
total += r
if done:
break
agent.decay_epsilon()
rewards.append(total)
window.append(total)
moving_avg.append(sum(window) / len(window))
if eval_every and (ep % eval_every) == 0:
avg_r = evaluate(env, agent, episodes=eval_episodes, max_steps=max_steps)
eval_scores.append((ep, avg_r))
return rewards, moving_avg, eval_scores
def evaluate(env: CatchEnv, agent: QAgent, episodes: int = 50, max_steps: int = 50) -> float:
"""Evaluate the greedy policy (epsilon=0). Return average episode reward."""
old_eps = agent.eps
agent.eps = 0.0
total = 0.0
for _ in range(episodes):
s = env.reset()
ep_r = 0.0
for _ in range(max_steps):
a = agent.greedy_action(s)
s, r, done = env.step(a)
ep_r += r
if done:
break
total += ep_r
agent.eps = old_eps
return total / episodes
# =========================
# Visualization
# =========================
def show_learning_curve(rewards: List[float], moving_avg: List[float], eval_scores: List[Tuple[int, float]]) -> None:
plt.figure(figsize=(8, 4))
plt.plot(rewards, label="Episode reward")
plt.plot(moving_avg, label="Moving average (last 50)")
if eval_scores:
xs, ys = zip(*eval_scores)
plt.scatter(xs, ys, label="Greedy eval avg", marker="x")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("Q-learning on Catch: training progress")
plt.legend()
plt.tight_layout()
plt.show()
def play_and_render(env: CatchEnv, agent: QAgent, episodes: int = 3, max_steps: int = 50) -> None:
"""
Play a few episodes with the greedy policy and render frames using matplotlib.
"""
old_eps = agent.eps
agent.eps = 0.0
for ep in range(1, episodes + 1):
s = env.reset()
frames: List[np.ndarray] = []
rewards = 0
for _ in range(max_steps):
frames.append(env.render_frame())
a = agent.greedy_action(s)
s, r, done = env.step(a)
rewards += r
if done:
frames.append(env.render_frame())
break
# Show frames one-by-one
for i, frame in enumerate(frames):
plt.figure(figsize=(3, 3))
# invert Y so bottom row plots at bottom
plt.imshow(frame[::-1, :], vmin=0, vmax=2)
plt.xticks(range(env.W))
plt.yticks(range(env.H))
plt.title(f"Episode {ep} – frame {i+1}/{len(frames)}")
plt.grid(True)
plt.tight_layout()
plt.show()
# Episode result
res = "CATCH!" if rewards > 0 else "MISS"
plt.figure(figsize=(3, 0.6))
plt.text(0.5, 0.5, f"Episode {ep} result: {res}", ha="center", va="center")
plt.axis("off")
plt.tight_layout()
plt.show()
agent.eps = old_eps
# =========================
# Main
# =========================
def main() -> None:
W, H = 7, 7
env = CatchEnv(width=W, height=H, seed=42)
agent = QAgent(width=W, height=H, alpha=0.25, gamma=0.98, eps_start=1.0, eps_end=0.05, eps_decay=0.995)
# Max steps per episode ~= grid height + a little slack so the ball reaches the bottom
episodes = 2000
max_steps = H + 2
rewards, moving_avg, eval_scores = train(
env, agent, episodes=episodes, max_steps=max_steps, eval_every=200, eval_episodes=100
)
show_learning_curve(rewards, moving_avg, eval_scores)
play_and_render(env, agent, episodes=3, max_steps=max_steps)
if __name__ == "__main__":
main()
PIXEL VERSION: DQN (learns from pixels directly)
#!/usr/bin/env python3
"""
Arcade Catch (pixels) + DQN (ConvNet) with experience replay and target network.
- Learns from raw grid images (1 channel) optionally stacked over K frames.
- Device auto-detects CUDA / Apple MPS / CPU.
- Visualizes training curve and plays greedy episodes after training.
Run:
pip install numpy matplotlib torch
python arcade_dqn.py
Tip: For better results, increase TRAIN_EPISODES and REPLAY_SIZE.
"""
import math
import random
import time
from collections import deque, namedtuple
from typing import Deque, List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
# =========================
# Environment: Catch (pixels)
# =========================
class CatchEnv:
"""
A tiny arcade-like game rendered as pixels (H x W single-channel image with values {0,1,2}).
- Ball falls downward; paddle moves left/stay/right.
- Reward +1 for catch, -1 otherwise at episode end.
- Observation: np.ndarray[H, W] with integers 0(empty),1(paddle),2(ball).
"""
def __init__(self, width=7, height=7, paddle_speed=1, seed=0):
self.W = width
self.H = height
self.paddle_speed = paddle_speed
self.rng = random.Random(seed)
self.reset()
def reset(self):
self.ball_x = self.rng.randrange(self.W)
self.ball_y = self.H - 1
self.paddle_x = self.W // 2
return self._obs()
def step(self, action:int):
# The action arrives as an index in {0,1,2}; map it to a paddle move in {-1,0,+1}.
mapped = [-1, 0, 1][action]
self.paddle_x = max(0, min(self.W-1, self.paddle_x + mapped * self.paddle_speed))
self.ball_y -= 1
done = False
reward = 0.0
if self.ball_y < 0:
done = True
reward = 1.0 if self.paddle_x == self.ball_x else -1.0
return self._obs(), reward, done
def _obs(self):
grid = np.zeros((self.H, self.W), dtype=np.float32)
if 0 <= self.ball_y < self.H:
grid[self.ball_y, self.ball_x] = 2.0
grid[0, self.paddle_x] = 1.0
return grid
def render_frame(self):
return self._obs()
@property
def action_space_n(self):
return 3 # left, stay, right
# Optional wrapper to stack K consecutive frames -> encode motion
class FrameStack:
def __init__(self, env: CatchEnv, k: int = 2):
self.env = env
self.k = k
self.frames: Deque[np.ndarray] = deque(maxlen=k)
def reset(self):
obs = self.env.reset()
self.frames.clear()
for _ in range(self.k):
self.frames.append(obs)
return self._get_obs()
def step(self, action:int):
obs, r, done = self.env.step(action)
self.frames.append(obs)
return self._get_obs(), r, done
def _get_obs(self):
# shape: (k, H, W)
return np.stack(list(self.frames), axis=0)
def render_frame(self):
return self.env.render_frame()
@property
def action_space_n(self):
return self.env.action_space_n
@property
def H(self):
return self.env.H
@property
def W(self):
return self.env.W
# =========================
# DQN Model (ConvNet)
# =========================
class ConvDQN(nn.Module):
def __init__(self, in_channels: int, height: int, width: int, num_actions: int):
super().__init__()
# Small conv net tailored for tiny 7x7 to ~15x15 inputs
self.body = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
)
# compute conv output size
with torch.no_grad():
dummy = torch.zeros(1, in_channels, height, width)
conv_out = self.body(dummy).view(1, -1).size(1)
self.head = nn.Sequential(
nn.Linear(conv_out, 128),
nn.ReLU(),
nn.Linear(128, num_actions),
)
def forward(self, x):
x = self.body(x)
x = x.view(x.size(0), -1)
return self.head(x)
# =========================
# Replay Buffer
# =========================
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "done"))
class ReplayBuffer:
def __init__(self, capacity: int, device: torch.device):
self.capacity = capacity
self.device = device
self.memory: Deque[Transition] = deque(maxlen=capacity)
def push(self, *args):
self.memory.append(Transition(*args))
def __len__(self):
return len(self.memory)
def sample(self, batch_size: int):
batch = random.sample(self.memory, batch_size)
# Stack into tensors
state = torch.tensor(np.stack([b.state for b in batch], axis=0), dtype=torch.float32, device=self.device)
action = torch.tensor([b.action for b in batch], dtype=torch.long, device=self.device).unsqueeze(1)
reward = torch.tensor([b.reward for b in batch], dtype=torch.float32, device=self.device).unsqueeze(1)
next_state = torch.tensor(np.stack([b.next_state for b in batch], axis=0), dtype=torch.float32, device=self.device)
done = torch.tensor([b.done for b in batch], dtype=torch.float32, device=self.device).unsqueeze(1)
return state, action, reward, next_state, done
# =========================
# Utilities
# =========================
def get_device():
if torch.cuda.is_available():
return torch.device("cuda")
# Apple MPS
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
return torch.device("cpu")
def epsilon_by_episode(ep, eps_start: float, eps_end: float, eps_decay: float):
# Exponential decay over episodes
return max(eps_end, eps_start * (eps_decay ** ep))
# =========================
# Training & Evaluation
# =========================
def select_action(model: nn.Module, state: np.ndarray, eps: float, n_actions: int, device: torch.device):
if random.random() < eps:
return random.randrange(n_actions)
with torch.no_grad():
s = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0) # (1,C,H,W)
q = model(s)
return int(q.argmax(dim=1).item())
def evaluate(env: FrameStack, model: nn.Module, device: torch.device, episodes: int = 50, max_steps: int = 50):
model.eval()
total = 0.0
with torch.no_grad():
for _ in range(episodes):
s = env.reset()
ep_r = 0.0
for _ in range(max_steps):
a = select_action(model, s, eps=0.0, n_actions=env.action_space_n, device=device)
s, r, done = env.step(a)
ep_r += r
if done:
break
total += ep_r
model.train()
return total / episodes
def plot_learning(reward_history, moving_avg, eval_points, eval_scores):
plt.figure(figsize=(8,4))
plt.plot(reward_history, label="Episode reward")
plt.plot(moving_avg, label="Moving avg (50)")
if eval_points:
plt.scatter(eval_points, eval_scores, marker='x', label="Greedy eval avg")
plt.xlabel("Episode")
plt.ylabel("Reward")
plt.title("DQN on Catch (pixels)")
plt.legend()
plt.tight_layout()
plt.show()
def play_and_render(env: FrameStack, model: nn.Module, device: torch.device, episodes: int = 3, max_steps: int = 50):
for ep in range(1, episodes+1):
s = env.reset()
frames = []
rewards = 0.0
for _ in range(max_steps):
frames.append(env.render_frame())
a = select_action(model, s, eps=0.0, n_actions=env.action_space_n, device=device)
s, r, done = env.step(a)
rewards += r
if done:
frames.append(env.render_frame())
break
for i, frame in enumerate(frames):
plt.figure(figsize=(3,3))
plt.imshow(frame[::-1, :], vmin=0, vmax=2)
plt.xticks(range(env.W)); plt.yticks(range(env.H))
plt.title(f"Episode {ep} â frame {i+1}/{len(frames)}")
plt.grid(True); plt.tight_layout(); plt.show()
plt.figure(figsize=(3,0.6))
plt.text(0.5, 0.5, f"Episode {ep} result: {'CATCH!' if rewards>0 else 'MISS'}", ha='center', va='center')
plt.axis('off'); plt.tight_layout(); plt.show()
def train_dqn(
seed: int = 42,
WIDTH: int = 7,
HEIGHT: int = 7,
FRAME_STACK: int = 2,
TRAIN_EPISODES: int = 3000,
MAX_STEPS: int = 10, # ~H+3; ball reaches bottom
BATCH_SIZE: int = 64,
REPLAY_SIZE: int = 5000,
START_LEARNING_AFTER: int = 500,
GAMMA: float = 0.99,
LR: float = 1e-3,
TARGET_UPDATE_EVERY: int = 200, # steps
EPS_START: float = 1.0,
EPS_END: float = 0.05,
EPS_DECAY: float = 0.995, # per episode
EVAL_EVERY: int = 200,
EVAL_EPISODES: int = 100,
SAVE_PATH: str = "dqn_catch.pt",
):
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
base_env = CatchEnv(width=WIDTH, height=HEIGHT, seed=seed)
env = FrameStack(base_env, k=FRAME_STACK)
device = get_device()
print(f"[INFO] Device: {device}")
in_channels = FRAME_STACK
n_actions = env.action_space_n
policy_net = ConvDQN(in_channels, HEIGHT, WIDTH, n_actions).to(device)
target_net = ConvDQN(in_channels, HEIGHT, WIDTH, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
replay = ReplayBuffer(REPLAY_SIZE, device=device)
reward_history: List[float] = []
moving_avg: List[float] = []
eval_points: List[int] = []
eval_scores: List[float] = []
ma_window: Deque[float] = deque(maxlen=50)
global_step = 0
for ep in range(1, TRAIN_EPISODES+1):
s = env.reset()
total_r = 0.0
eps = epsilon_by_episode(ep, EPS_START, EPS_END, EPS_DECAY)
for t in range(MAX_STEPS):
a = select_action(policy_net, s, eps, n_actions, device)
s_next, r, done = env.step(a)
replay.push(s, a, r, s_next, float(done))
s = s_next
total_r += r
global_step += 1
# Learn
if len(replay) >= max(BATCH_SIZE, START_LEARNING_AFTER):
state, action, reward, next_state, done_mask = replay.sample(BATCH_SIZE)
# Q(s,a)
q_values = policy_net(state).gather(1, action) # [B,1]
with torch.no_grad():
# Double DQN trick (optional here): act by policy net, evaluate by target net
next_actions = policy_net(next_state).argmax(dim=1, keepdim=True)
next_q = target_net(next_state).gather(1, next_actions)
target = reward + (1.0 - done_mask) * GAMMA * next_q
loss = nn.SmoothL1Loss()(q_values, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(policy_net.parameters(), 5.0)
optimizer.step()
# Target network update
if (global_step % TARGET_UPDATE_EVERY) == 0:
target_net.load_state_dict(policy_net.state_dict())
if done:
break
reward_history.append(total_r)
ma_window.append(total_r)
moving_avg.append(sum(ma_window)/len(ma_window))
if (ep % EVAL_EVERY) == 0:
avg = evaluate(env, policy_net, device, episodes=EVAL_EPISODES, max_steps=MAX_STEPS)
eval_points.append(ep)
eval_scores.append(avg)
print(f"[EVAL] Ep {ep} avg reward (greedy): {avg:.3f} eps={eps:.3f}")
# Save model
torch.save({"model": policy_net.state_dict(),
"in_channels": in_channels,
"H": HEIGHT,
"W": WIDTH,
"n_actions": n_actions}, SAVE_PATH)
print(f"[INFO] Saved model to {SAVE_PATH}")
# Plots and playback
plot_learning(reward_history, moving_avg, eval_points, eval_scores)
play_and_render(env, policy_net, device, episodes=3, max_steps=MAX_STEPS)
def main():
# Defaults are tuned for demo speed. For stronger learning bump TRAIN_EPISODES to 10k+.
train_dqn(
TRAIN_EPISODES=3000,
REPLAY_SIZE=8000,
START_LEARNING_AFTER=500,
MAX_STEPS=10, # ~height + few steps
FRAME_STACK=2,
TARGET_UPDATE_EVERY=200,
EVAL_EVERY=300,
EVAL_EPISODES=100,
LR=1e-3,
)
if __name__ == "__main__":
main()