Saturday, December 06, 2025

Catch Me If You Can: Reinforcement Learning in a Minimal Arcade World



Introduction


Reinforcement Learning (RL) has become one of the most talked-about subfields of Artificial Intelligence, not only in academia but increasingly in engineering practice. Whenever autonomous systems adapt in uncertain environments — whether a robot learns to walk, a fleet of taxis self-organizes, or a data center balances workloads — RL concepts play a role. Yet many software engineers view RL as an impenetrable forest of mathematics, neural networks, and jargon.


This article proposes a different approach: we shrink the world to the size of a tiny arcade game. In this minimalist world, the essential RL principles appear clearly, free of the distractions of industrial complexity. By the end, you will not only understand the famous Q-learning algorithm but also appreciate why reinforcement learning matters for software engineers building the next generation of adaptive systems.


The Arcade World: A Minimal Environment


The game is called Catch. Imagine a grid of width W and height H. A ball appears at the top row and falls straight down, one step at a time. At the bottom, a paddle moves horizontally: left, right, or stay. If the paddle aligns with the ball when it hits the ground, the player scores; otherwise, they fail.


+-------------------+   y = H - 1
|                   |
|         (ball)    |
|                   |
|                   |
|                   |
|                   |
|      [paddle]     |   y = 0
+-------------------+


The rules are minimal, the dynamics predictable. And yet, this world suffices to illustrate everything we need: states, actions, rewards, exploration, exploitation, and policy learning.



The Feedback Loop: Agent and Environment


Reinforcement learning is not about training with labeled data or finding clusters. It is about interaction. An agent repeatedly acts in an environment and is judged by the consequences. The cycle looks like this:


      ┌─────────────┐
      │             │
      │    Agent    │
      │             │
      └─────┬───────┘
            │ action a_t
            ▼
      ┌─────────────┐
      │             │
      │ Environment │
      │             │
      └─────┬───────┘
            │
   reward r │ state s_{t+1}
            ▼
     (back to the Agent)


In Catch:

State (s): the ball’s x,y position and the paddle’s x position.

Action (a): move left, stay, move right.

Reward (r): +1 if caught, −1 if missed, 0 otherwise.


This loop repeats until the ball reaches the bottom. Then the episode ends, and the next begins.
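
As a minimal sketch of this loop in Python (assuming an environment object with reset() and step() methods like the full listing at the end of this article; the agent.act() and agent.observe() hooks are illustrative names, not part of that listing):

def run_episode(env, agent):
    """Run one episode of the agent-environment loop and return the total reward."""
    state = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        action = agent.act(state)                        # choose action a_t
        next_state, reward, done = env.step(action)      # environment answers with s_{t+1}, r
        agent.observe(state, action, reward, next_state, done)  # learning happens here
        state = next_state
        total_reward += reward
    return total_reward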



Formalizing the Challenge: The MDP


Professionals like rigor. RL is formally described as a Markov Decision Process (MDP). An MDP is defined by:

A set of states S.

A set of actions A.

A transition function P(s'|s,a).

A reward function R(s,a).

A discount factor γ ∈ [0,1].


The Markov property states that the next state depends only on the current state and action, not on the full history. In Catch, this property holds: the ball’s and paddle’s current positions fully determine the future.
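
To make the Markov property concrete, here is an illustrative sketch of Catch's dynamics as pure functions of the current state and action alone (the (ball_x, ball_y, paddle_x) tuple and grid width W mirror the environment code later in this article; paddle speed is assumed to be 1):

def transition(state, action, W):
    """Next state depends only on (state, action), never on earlier history."""
    ball_x, ball_y, paddle_x = state
    paddle_x = max(0, min(W - 1, paddle_x + action))  # clamp paddle to the grid
    ball_y -= 1                                       # ball falls one row
    return (ball_x, ball_y, paddle_x)

def reward(state):
    """+1 for a catch, -1 for a miss, 0 while the ball is still falling."""
    ball_x, ball_y, paddle_x = state
    if ball_y < 0:
        return 1 if paddle_x == ball_x else -1
    return 0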


The Dilemma: Exploration vs. Exploitation


Engineers face trade-offs daily: should we refactor for long-term health or ship features now? RL encapsulates a similar tension: explore new actions or exploit known good actions?


ASCII sketch of the dilemma:


Exploration: try something new ──► possible higher future reward

Exploitation: stick to best known ──► reliable immediate reward


The standard technique is ε-greedy:


if random() < ε:
    choose random action    # exploration
else:
    choose argmax_a Q(s,a)  # exploitation



At the start of training, ε is high (lots of exploration). Over episodes, ε decays, allowing convergence to exploitation. The decay is usually exponential or linear, as illustrated below:


ε
^
|\
| \
|  \
|   \                      Exploration ↓
|    \__________________   Exploitation ↑
+-----------------------> Episodes

This curve reminds us that exploration is front-loaded: in early episodes, randomness dominates, helping the agent discover options. Later, exploration dwindles, stabilizing behavior into exploitation of the learned policy.
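
A sketch of such a schedule (the constants match the defaults used in the listing at the end: eps_start=1.0, eps_end=0.05, eps_decay=0.995):

eps, eps_end, eps_decay = 1.0, 0.05, 0.995
for episode in range(2000):
    # ... run one episode with epsilon-greedy action selection ...
    eps = max(eps_end, eps * eps_decay)   # decay once per episode, never below the floor
# After roughly 600 episodes eps reaches the 0.05 floor (0.995**600 ≈ 0.049).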


The Central Idea: Action-Value Functions


The agent must answer: How good is it to take action a in state s? This is quantified by the Q-function:


Q(s,a) = expected cumulative reward from state s

          taking action a, then following the best policy


If Q(s, right) > Q(s, left), the agent knows moving right is more promising.
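
In a world as small as Catch, the Q-function fits in a plain lookup table. A sketch, mirroring the defaultdict used in the listing below (the example values are made up):

from collections import defaultdict

ACTIONS = (-1, 0, +1)                                 # left, stay, right
Q = defaultdict(lambda: {a: 0.0 for a in ACTIONS})    # state tuple -> {action: value}

def greedy_action(state):
    qvals = Q[state]
    return max(qvals, key=qvals.get)                  # action with the highest estimated value

# Example: if Q[(3, 2, 1)] == {-1: 0.1, 0: 0.3, +1: 0.8}, greedy_action picks +1 (move right).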



The Q-Learning Update



Q-learning is an off-policy algorithm: it learns the optimal Q-function even while exploring. The update rule:


Q(s,a) ← Q(s,a) + α [ r + γ * max_{a'} Q(s',a') − Q(s,a) ]


ASCII depiction:


       ┌──────────────────────────────┐
       │  Old estimate: Q(s,a)        │
       └──────────────────────────────┘
                      │
                      ▼
       target = r + γ * max_a' Q(s',a')
                      │
                      ▼
       New Q(s,a) = Old + α * (target - Old)

The learning rate α controls how much new information overrides old beliefs. The discount factor γ determines how much future rewards count relative to immediate ones.
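
A single worked update with illustrative numbers (α = 0.2, γ = 0.98):

alpha, gamma = 0.2, 0.98
q_old      = 0.50                    # current estimate Q(s,a)
r          = 0.0                     # no reward yet: the ball is still falling
max_q_next = 0.90                    # best estimate among actions in s'
target = r + gamma * max_q_next      # 0.0 + 0.98 * 0.90 = 0.882
q_new  = q_old + alpha * (target - q_old)
print(q_new)                         # 0.50 + 0.2 * 0.382 = 0.5764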



The Learning Curve: From Chaos to Competence


Engineers trust metrics. During training, we record:

Episode reward: total reward per episode.

Moving average: smoothed reward trend.

Evaluation: average greedy performance without exploration.


An ASCII illustration:


Reward
^
|                      xxxxxxxxxx
|                 xxxxx
|          xx  xxx
|    xx  xx  xx
|      xx
+--------------------------------> Episodes


Early on, rewards oscillate around zero. As learning progresses, the curve climbs toward consistent positive values.
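
Tracking these metrics costs only a few lines. A sketch using a fixed-size window, as the training code below does with a 50-episode deque:

from collections import deque

window = deque(maxlen=50)          # keeps only the last 50 episode rewards
rewards, moving_avg = [], []

def record(episode_reward):
    rewards.append(episode_reward)
    window.append(episode_reward)
    moving_avg.append(sum(window) / len(window))   # smoothed trend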



Seeing is Believing: Visual Replays


Numbers tell part of the story. But showing the agent’s behavior completes the picture. After training, we run the agent greedily and render each frame:


Frame 1:        Frame 2:        Frame 3:
   (ball)           .               .
      .           (ball)            .
      .              .            (ball)
[paddle]          [paddle]        [paddle]


The paddle moves into position, anticipating the ball’s landing. This visual confirmation reinforces the metrics.



Why This Matters for Software Engineers


So far, this may look like a toy. But the concepts generalize:

Adaptive resource allocation: RL can manage virtual machines in a cloud cluster, scaling resources up or down.

Traffic optimization: signals adjust in real-time, balancing flow.

Recommendation systems: explore new content but exploit user preferences.


These systems are more complex than Catch, but the feedback loop, the exploration–exploitation trade-off, and the Q-update are the same.



Lessons for Professionals

1. Reward shaping is design: Get the reward wrong, and the agent learns undesirable behavior.

2. Exploration is costly but necessary: Systems may underperform temporarily to discover better policies.

3. Policies emerge rather than being programmed: RL shifts the mindset from deterministic coding to probabilistic optimization.

4. Scaling requires function approximation: Tabular Q-learning suffices for Catch, but deep neural networks (DQN) are needed for larger spaces.




Beyond Tabular Q-Learning: The DQN Leap


Our tabular example does not scale. In modern practice, engineers use Deep Q-Networks (DQNs), where Q(s,a) is approximated by a neural network trained on replayed experiences. The architecture:


   State (pixels) ──► ConvNet ──► Fully connected ──► Q-values for actions


This is the method that allowed RL agents to reach human-level performance on Atari games. The essence is unchanged; the function approximator is more powerful.
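
The "replayed experiences" live in a buffer that is sampled at random, which breaks the correlation between consecutive frames. A minimal sketch (the full PyTorch version appears in the DQN listing at the end):

import random
from collections import deque

class ReplayBuffer:
    """Fixed-size store of (state, action, reward, next_state, done) transitions."""
    def __init__(self, capacity=5000):
        self.memory = deque(maxlen=capacity)   # oldest experiences drop off automatically

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, batch_size=64):
        # uniform random minibatch decorrelates consecutive frames
        return random.sample(self.memory, batch_size)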




Risks and Considerations


Professional deployments must consider:

Convergence: RL can diverge if hyperparameters are mis-set.

Safety: Exploration in real systems may cause harm.

Explainability: Policies learned by Q-functions may be hard to interpret.


Engineers must combine RL with simulation environments, guardrails, and monitoring.



Conclusion: From Arcade to Industry


By observing a paddle learn to catch a ball, we witness reinforcement learning in its purest form. The lesson for software engineers is not the specifics of tabular Q-tables but the architectural mindset: systems can be built to learn from feedback, balancing exploration and exploitation, updating expectations, and improving autonomously.


Today, it’s an arcade toy. Tomorrow, it could be your cloud infrastructure, your production pipeline, or your autonomous vehicle. The cycle is always the same:


observe state → choose action → receive reward → update Q → improve policy


Understanding this cycle equips engineers to participate in the next wave of adaptive software systems.




SOURCE CODE


"NORMAL" VERSION


#!/usr/bin/env python3

"""

Arcade Catch + Q-learning (tabular) with live training curve and episode replays.


How to run:

  pip install numpy matplotlib
"""



import math

import random

from collections import defaultdict, deque

from typing import Dict, Tuple, List


import numpy as np

import matplotlib.pyplot as plt



# =========================

# Environment: Catch Game

# =========================

class CatchEnv:

    """

    A tiny arcade-like game:

    - Discrete grid of width W and height H.

    - A ball spawns at the top row (y=H-1) at a random x and falls straight down.

    - A paddle sits on the bottom row (y=0) and can move LEFT (-1), STAY (0), or RIGHT (+1).

    - Episode ends when the ball leaves the grid; reward is +1 if paddle x == ball x, else -1.


    State representation: (ball_x, ball_y, paddle_x)

    """


    def __init__(self, width: int = 7, height: int = 7, paddle_speed: int = 1, seed: int = 0):

        self.W = width

        self.H = height

        self.paddle_speed = paddle_speed

        self.rng = random.Random(seed)

        self.reset()


    def reset(self) -> Tuple[int, int, int]:

        self.ball_x = self.rng.randrange(self.W)

        self.ball_y = self.H - 1

        self.paddle_x = self.W // 2

        return (self.ball_x, self.ball_y, self.paddle_x)


    def step(self, action: int) -> Tuple[Tuple[int, int, int], int, bool]:

        """

        Apply action in {-1, 0, +1}. Paddle is clamped to [0, W-1].

        Ball falls by one row. If ball leaves the grid, episode ends and reward is given.

        Returns: (next_state, reward, done)

        """

        self.paddle_x = int(np.clip(self.paddle_x + action * self.paddle_speed, 0, self.W - 1))

        self.ball_y -= 1


        done = False

        reward = 0

        if self.ball_y < 0:

            done = True

            reward = 1 if self.paddle_x == self.ball_x else -1


        state = (self.ball_x, max(self.ball_y, 0), self.paddle_x)

        return state, reward, done


    def render_frame(self) -> np.ndarray:

        """

        Return a 2D array with ball=2, paddle=1, empty=0 for visualization.

        """

        grid = np.zeros((self.H, self.W), dtype=int)

        if 0 <= self.ball_y < self.H:

            grid[self.ball_y, self.ball_x] = 2

        grid[0, self.paddle_x] = 1

        return grid



# =========================

# Q-learning Agent

# =========================

class QAgent:

    """

    Tabular Q-learning agent with epsilon-greedy exploration.

    Q is stored as a dict: state -> {action: value}

    """


    def __init__(

        self,

        width: int,

        height: int,

        actions: Tuple[int, ...] = (-1, 0, 1),

        alpha: float = 0.2,

        gamma: float = 0.98,

        eps_start: float = 1.0,

        eps_end: float = 0.05,

        eps_decay: float = 0.995,

        seed: int = 0,

    ):

        self.W = width

        self.H = height

        self.actions = actions

        self.alpha = alpha

        self.gamma = gamma

        self.eps = eps_start

        self.eps_end = eps_end

        self.eps_decay = eps_decay

        self.rng = random.Random(seed)


        self.Q: Dict[Tuple[int, int, int], Dict[int, float]] = defaultdict(

            lambda: {a: 0.0 for a in self.actions}

        )


    def policy(self, state: Tuple[int, int, int]) -> int:

        """Epsilon-greedy action selection."""

        if self.rng.random() < self.eps:

            return self.rng.choice(self.actions)

        qvals = self.Q[state]

        max_q = max(qvals.values())

        best_actions = [a for a, q in qvals.items() if q == max_q]

        return self.rng.choice(best_actions)


    def greedy_action(self, state: Tuple[int, int, int]) -> int:

        """Greedy (argmax) action selection."""

        qvals = self.Q[state]

        max_q = max(qvals.values())

        best_actions = [a for a, q in qvals.items() if q == max_q]

        return self.rng.choice(best_actions)


    def update(self, s: Tuple[int, int, int], a: int, r: int, s_next: Tuple[int, int, int], done: bool) -> None:

        """Tabular Q-learning update."""

        qsa = self.Q[s][a]

        target = r if done else r + self.gamma * max(self.Q[s_next].values())

        self.Q[s][a] = qsa + self.alpha * (target - qsa)


    def decay_epsilon(self) -> None:

        self.eps = max(self.eps_end, self.eps * self.eps_decay)



# =========================

# Training & Evaluation

# =========================

def train(

    env: CatchEnv,

    agent: QAgent,

    episodes: int = 2000,

    max_steps: int = 50,

    eval_every: int = 200,

    eval_episodes: int = 100,

) -> Tuple[List[float], List[float], List[Tuple[int, float]]]:

    """

    Train with epsilon-greedy Q-learning. Return (rewards, moving_avg, eval_scores).

    eval_scores is a list of (episode_index, avg_greedy_reward).

    """

    rewards: List[float] = []

    moving_avg: List[float] = []

    eval_scores: List[Tuple[int, float]] = []

    window = deque(maxlen=50)


    for ep in range(1, episodes + 1):

        s = env.reset()

        total = 0.0

        for _ in range(max_steps):

            a = agent.policy(s)

            s_next, r, done = env.step(a)

            agent.update(s, a, r, s_next, done)

            s = s_next

            total += r

            if done:

                break


        agent.decay_epsilon()

        rewards.append(total)

        window.append(total)

        moving_avg.append(sum(window) / len(window))


        if eval_every and (ep % eval_every) == 0:

            avg_r = evaluate(env, agent, episodes=eval_episodes, max_steps=max_steps)

            eval_scores.append((ep, avg_r))


    return rewards, moving_avg, eval_scores



def evaluate(env: CatchEnv, agent: QAgent, episodes: int = 50, max_steps: int = 50) -> float:

    """Evaluate the greedy policy (epsilon=0). Return average episode reward."""

    old_eps = agent.eps

    agent.eps = 0.0

    total = 0.0

    for _ in range(episodes):

        s = env.reset()

        ep_r = 0.0

        for _ in range(max_steps):

            a = agent.greedy_action(s)

            s, r, done = env.step(a)

            ep_r += r

            if done:

                break

        total += ep_r

    agent.eps = old_eps

    return total / episodes



# =========================

# Visualization

# =========================

def show_learning_curve(rewards: List[float], moving_avg: List[float], eval_scores: List[Tuple[int, float]]) -> None:

    plt.figure(figsize=(8, 4))

    plt.plot(rewards, label="Episode reward")

    plt.plot(moving_avg, label="Moving average (last 50)")

    if eval_scores:

        xs, ys = zip(*eval_scores)

        plt.scatter(xs, ys, label="Greedy eval avg", marker="x")

    plt.xlabel("Episode")

    plt.ylabel("Reward")

    plt.title("Q-learning on Catch: training progress")

    plt.legend()

    plt.tight_layout()

    plt.show()



def play_and_render(env: CatchEnv, agent: QAgent, episodes: int = 3, max_steps: int = 50) -> None:

    """

    Play a few episodes with the greedy policy and render frames using matplotlib.

    """

    old_eps = agent.eps

    agent.eps = 0.0


    for ep in range(1, episodes + 1):

        s = env.reset()

        frames: List[np.ndarray] = []

        rewards = 0

        for _ in range(max_steps):

            frames.append(env.render_frame())

            a = agent.greedy_action(s)

            s, r, done = env.step(a)

            rewards += r

            if done:

                frames.append(env.render_frame())

                break


        # Show frames one-by-one

        for i, frame in enumerate(frames):

            plt.figure(figsize=(3, 3))

            # invert Y so bottom row plots at bottom

            plt.imshow(frame[::-1, :], vmin=0, vmax=2)

            plt.xticks(range(env.W))

            plt.yticks(range(env.H))

            plt.title(f"Episode {ep} – frame {i+1}/{len(frames)}")

            plt.grid(True)

            plt.tight_layout()

            plt.show()


        # Episode result

        res = "CATCH!" if rewards > 0 else "MISS"

        plt.figure(figsize=(3, 0.6))

        plt.text(0.5, 0.5, f"Episode {ep} result: {res}", ha="center", va="center")

        plt.axis("off")

        plt.tight_layout()

        plt.show()


    agent.eps = old_eps



# =========================

# Main

# =========================

def main() -> None:

    W, H = 7, 7

    env = CatchEnv(width=W, height=H, seed=42)

    agent = QAgent(width=W, height=H, alpha=0.25, gamma=0.98, eps_start=1.0, eps_end=0.05, eps_decay=0.995)


    # Max steps ~= grid height plus a little slack so the ball reaches the bottom

    episodes = 2000

    max_steps = H + 2


    rewards, moving_avg, eval_scores = train(

        env, agent, episodes=episodes, max_steps=max_steps, eval_every=200, eval_episodes=100

    )


    show_learning_curve(rewards, moving_avg, eval_scores)

    play_and_render(env, agent, episodes=3, max_steps=max_steps)



if __name__ == "__main__":

    main()




PIXEL VERSION: DQN (accesses pixels directly)


#!/usr/bin/env python3

"""

Arcade Catch (pixels) + DQN (ConvNet) with experience replay and target network.

- Learns from raw grid images (1 channel) optionally stacked over K frames.

- Device auto-detects CUDA / Apple MPS / CPU.

- Visualizes training curve and plays greedy episodes after training.


Run:

  pip install numpy matplotlib torch

  python arcade_dqn.py


Tip: For better results, increase TRAIN_EPISODES and REPLAY_SIZE.

"""


import math

import random

import time

from collections import deque, namedtuple

from typing import Deque, List, Tuple


import numpy as np

import torch

import torch.nn as nn

import torch.optim as optim

import matplotlib.pyplot as plt



# =========================

# Environment: Catch (pixels)

# =========================

class CatchEnv:

    """

    A tiny arcade-like game rendered as pixels (H x W single-channel image with values {0,1,2}).

    - Ball falls downward; paddle moves left/stay/right.

    - Reward +1 for catch, -1 otherwise at episode end.

    - Observation: np.ndarray[H, W] with integers 0(empty),1(paddle),2(ball).

    """

    def __init__(self, width=7, height=7, paddle_speed=1, seed=0):

        self.W = width

        self.H = height

        self.paddle_speed = paddle_speed

        self.rng = random.Random(seed)

        self.reset()


    def reset(self):

        self.ball_x = self.rng.randrange(self.W)

        self.ball_y = self.H - 1

        self.paddle_x = self.W // 2

        return self._obs()


    def step(self, action:int):

        # external action index in {0, 1, 2}; map it to paddle moves {-1, 0, +1} here

        mapped = [-1, 0, 1][action]

        self.paddle_x = max(0, min(self.W-1, self.paddle_x + mapped * self.paddle_speed))

        self.ball_y -= 1


        done = False

        reward = 0.0

        if self.ball_y < 0:

            done = True

            reward = 1.0 if self.paddle_x == self.ball_x else -1.0


        return self._obs(), reward, done


    def _obs(self):

        grid = np.zeros((self.H, self.W), dtype=np.float32)

        if 0 <= self.ball_y < self.H:

            grid[self.ball_y, self.ball_x] = 2.0

        grid[0, self.paddle_x] = 1.0

        return grid


    def render_frame(self):

        return self._obs()


    @property

    def action_space_n(self):

        return 3  # left, stay, right



# Optional wrapper to stack K consecutive frames -> encode motion

class FrameStack:

    def __init__(self, env: CatchEnv, k: int = 2):

        self.env = env

        self.k = k

        self.frames: Deque[np.ndarray] = deque(maxlen=k)


    def reset(self):

        obs = self.env.reset()

        self.frames.clear()

        for _ in range(self.k):

            self.frames.append(obs)

        return self._get_obs()


    def step(self, action:int):

        obs, r, done = self.env.step(action)

        self.frames.append(obs)

        return self._get_obs(), r, done


    def _get_obs(self):

        # shape: (k, H, W)

        return np.stack(list(self.frames), axis=0)


    def render_frame(self):

        return self.env.render_frame()


    @property

    def action_space_n(self):

        return self.env.action_space_n


    @property

    def H(self):

        return self.env.H


    @property

    def W(self):

        return self.env.W



# =========================

# DQN Model (ConvNet)

# =========================

class ConvDQN(nn.Module):

    def __init__(self, in_channels: int, height: int, width: int, num_actions: int):

        super().__init__()

        # Small conv net tailored for tiny 7x7 to ~15x15 inputs

        self.body = nn.Sequential(

            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),

            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),

            nn.ReLU(),

        )

        # compute conv output size

        with torch.no_grad():

            dummy = torch.zeros(1, in_channels, height, width)

            conv_out = self.body(dummy).view(1, -1).size(1)

        self.head = nn.Sequential(

            nn.Linear(conv_out, 128),

            nn.ReLU(),

            nn.Linear(128, num_actions),

        )


    def forward(self, x):

        x = self.body(x)

        x = x.view(x.size(0), -1)

        return self.head(x)



# =========================

# Replay Buffer

# =========================

Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "done"))


class ReplayBuffer:

    def __init__(self, capacity: int, device: torch.device):

        self.capacity = capacity

        self.device = device

        self.memory: Deque[Transition] = deque(maxlen=capacity)


    def push(self, *args):

        self.memory.append(Transition(*args))


    def __len__(self):

        return len(self.memory)


    def sample(self, batch_size: int):

        batch = random.sample(self.memory, batch_size)

        # Stack into tensors

        state = torch.tensor(np.stack([b.state for b in batch], axis=0), dtype=torch.float32, device=self.device)

        action = torch.tensor([b.action for b in batch], dtype=torch.long, device=self.device).unsqueeze(1)

        reward = torch.tensor([b.reward for b in batch], dtype=torch.float32, device=self.device).unsqueeze(1)

        next_state = torch.tensor(np.stack([b.next_state for b in batch], axis=0), dtype=torch.float32, device=self.device)

        done = torch.tensor([b.done for b in batch], dtype=torch.float32, device=self.device).unsqueeze(1)

        return state, action, reward, next_state, done



# =========================

# Utilities

# =========================

def get_device():

    if torch.cuda.is_available():

        return torch.device("cuda")

    # Apple MPS

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():

        return torch.device("mps")

    return torch.device("cpu")



def epsilon_by_episode(ep, eps_start: float, eps_end: float, eps_decay: float):

    # Exponential decay over episodes

    return max(eps_end, eps_start * (eps_decay ** ep))



# =========================

# Training & Evaluation

# =========================

def select_action(model: nn.Module, state: np.ndarray, eps: float, n_actions: int, device: torch.device):

    if random.random() < eps:

        return random.randrange(n_actions)

    with torch.no_grad():

        s = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)  # (1,C,H,W)

        q = model(s)

        return int(q.argmax(dim=1).item())


def evaluate(env: FrameStack, model: nn.Module, device: torch.device, episodes: int = 50, max_steps: int = 50):

    model.eval()

    total = 0.0

    with torch.no_grad():

        for _ in range(episodes):

            s = env.reset()

            ep_r = 0.0

            for _ in range(max_steps):

                a = select_action(model, s, eps=0.0, n_actions=env.action_space_n, device=device)

                s, r, done = env.step(a)

                ep_r += r

                if done:

                    break

            total += ep_r

    model.train()

    return total / episodes


def plot_learning(reward_history, moving_avg, eval_points, eval_scores):

    plt.figure(figsize=(8,4))

    plt.plot(reward_history, label="Episode reward")

    plt.plot(moving_avg, label="Moving avg (50)")

    if eval_points:

        plt.scatter(eval_points, eval_scores, marker='x', label="Greedy eval avg")

    plt.xlabel("Episode")

    plt.ylabel("Reward")

    plt.title("DQN on Catch (pixels)")

    plt.legend()

    plt.tight_layout()

    plt.show()


def play_and_render(env: FrameStack, model: nn.Module, device: torch.device, episodes: int = 3, max_steps: int = 50):

    for ep in range(1, episodes+1):

        s = env.reset()

        frames = []

        rewards = 0.0

        for _ in range(max_steps):

            frames.append(env.render_frame())

            a = select_action(model, s, eps=0.0, n_actions=env.action_space_n, device=device)

            s, r, done = env.step(a)

            rewards += r

            if done:

                frames.append(env.render_frame())

                break


        for i, frame in enumerate(frames):

            plt.figure(figsize=(3,3))

            plt.imshow(frame[::-1, :], vmin=0, vmax=2)

            plt.xticks(range(env.W)); plt.yticks(range(env.H))

            plt.title(f"Episode {ep} – frame {i+1}/{len(frames)}")

            plt.grid(True); plt.tight_layout(); plt.show()


        plt.figure(figsize=(3,0.6))

        plt.text(0.5, 0.5, f"Episode {ep} result: {'CATCH!' if rewards>0 else 'MISS'}", ha='center', va='center')

        plt.axis('off'); plt.tight_layout(); plt.show()



def train_dqn(

    seed: int = 42,

    WIDTH: int = 7,

    HEIGHT: int = 7,

    FRAME_STACK: int = 2,

    TRAIN_EPISODES: int = 3000,

    MAX_STEPS: int = 10,  # ~H+3; ball reaches bottom

    BATCH_SIZE: int = 64,

    REPLAY_SIZE: int = 5000,

    START_LEARNING_AFTER: int = 500,

    GAMMA: float = 0.99,

    LR: float = 1e-3,

    TARGET_UPDATE_EVERY: int = 200,   # steps

    EPS_START: float = 1.0,

    EPS_END: float = 0.05,

    EPS_DECAY: float = 0.995,         # per episode

    EVAL_EVERY: int = 200,

    EVAL_EPISODES: int = 100,

    SAVE_PATH: str = "dqn_catch.pt",

):

    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

    base_env = CatchEnv(width=WIDTH, height=HEIGHT, seed=seed)

    env = FrameStack(base_env, k=FRAME_STACK)


    device = get_device()

    print(f"[INFO] Device: {device}")


    in_channels = FRAME_STACK

    n_actions = env.action_space_n


    policy_net = ConvDQN(in_channels, HEIGHT, WIDTH, n_actions).to(device)

    target_net = ConvDQN(in_channels, HEIGHT, WIDTH, n_actions).to(device)

    target_net.load_state_dict(policy_net.state_dict())

    target_net.eval()


    optimizer = optim.Adam(policy_net.parameters(), lr=LR)

    replay = ReplayBuffer(REPLAY_SIZE, device=device)


    reward_history: List[float] = []

    moving_avg: List[float] = []

    eval_points: List[int] = []

    eval_scores: List[float] = []

    ma_window: Deque[float] = deque(maxlen=50)


    global_step = 0

    for ep in range(1, TRAIN_EPISODES+1):

        s = env.reset()

        total_r = 0.0


        eps = epsilon_by_episode(ep, EPS_START, EPS_END, EPS_DECAY)


        for t in range(MAX_STEPS):

            a = select_action(policy_net, s, eps, n_actions, device)

            s_next, r, done = env.step(a)


            replay.push(s, a, r, s_next, float(done))

            s = s_next

            total_r += r

            global_step += 1


            # Learn

            if len(replay) >= max(BATCH_SIZE, START_LEARNING_AFTER):

                state, action, reward, next_state, done_mask = replay.sample(BATCH_SIZE)


                # Q(s,a)

                q_values = policy_net(state).gather(1, action)  # [B,1]


                with torch.no_grad():

                    # Double DQN trick (optional here): act by policy net, evaluate by target net

                    next_actions = policy_net(next_state).argmax(dim=1, keepdim=True)

                    next_q = target_net(next_state).gather(1, next_actions)

                    target = reward + (1.0 - done_mask) * GAMMA * next_q


                loss = nn.SmoothL1Loss()(q_values, target)


                optimizer.zero_grad()

                loss.backward()

                nn.utils.clip_grad_norm_(policy_net.parameters(), 5.0)

                optimizer.step()


            # Target network update

            if (global_step % TARGET_UPDATE_EVERY) == 0:

                target_net.load_state_dict(policy_net.state_dict())


            if done:

                break


        reward_history.append(total_r)

        ma_window.append(total_r)

        moving_avg.append(sum(ma_window)/len(ma_window))


        if (ep % EVAL_EVERY) == 0:

            avg = evaluate(env, policy_net, device, episodes=EVAL_EPISODES, max_steps=MAX_STEPS)

            eval_points.append(ep)

            eval_scores.append(avg)

            print(f"[EVAL] Ep {ep} avg reward (greedy): {avg:.3f}  eps={eps:.3f}")


    # Save model

    torch.save({"model": policy_net.state_dict(),

                "in_channels": in_channels,

                "H": HEIGHT,

                "W": WIDTH,

                "n_actions": n_actions}, SAVE_PATH)

    print(f"[INFO] Saved model to {SAVE_PATH}")


    # Plots and playback

    plot_learning(reward_history, moving_avg, eval_points, eval_scores)

    play_and_render(env, policy_net, device, episodes=3, max_steps=MAX_STEPS)



def main():

    # Defaults are tuned for demo speed. For stronger learning bump TRAIN_EPISODES to 10k+.

    train_dqn(

        TRAIN_EPISODES=3000,

        REPLAY_SIZE=8000,

        START_LEARNING_AFTER=500,

        MAX_STEPS=10,     # ~height + few steps

        FRAME_STACK=2,

        TARGET_UPDATE_EVERY=200,

        EVAL_EVERY=300,

        EVAL_EPISODES=100,

        LR=1e-3,

    )



if __name__ == "__main__":

    main()
