Monday, June 01, 2026

RecursiveMAS: Teaching AI Agents to Think Together in Secret


INTRODUCTION: THE PROBLEM WITH CHATTY AGENTS

If you have spent any time building multi-agent AI systems, you have probably run into the same frustrating pattern. You wire up a Planner agent, a Critic agent, and a Solver agent. The Planner writes a plan in plain English. The Critic reads that plan, writes a critique in plain English. The Solver reads the critique, writes an answer in plain English. The whole thing feels elegant on a whiteboard. Then you run it and discover that your system is spending enormous amounts of time and compute budget just generating, tokenizing, and re-reading intermediate text that nobody outside the system ever sees.

It is a bit like watching a relay race where each runner, instead of simply passing the baton, stops to write a detailed memo about the baton, hands the memo to the next runner, who then reads the memo, writes their own memo summarizing the first memo, and only then starts running. The overhead is absurd. The information is there — it is just wrapped in an expensive, lossy, text-shaped package.

This is the problem that the paper "RecursiveMAS: Scaling Agent Collaboration through Unified Latent-Space Recursive Computation" (see also the additional documentation and the official GitHub repository) sets out to solve. The authors, from Stanford and UIUC, ask a beautifully simple question: what if agents could skip the memo entirely and just pass the raw thought?

The answer they arrive at is RecursiveMAS, a framework that lets heterogeneous LLM agents collaborate entirely in the continuous vector space — the latent space — that lives inside the models themselves, without converting intermediate reasoning into text until the very last moment. The results are striking: an average accuracy improvement of 8.3% over strong baselines, inference speedups of 1.2x to 2.4x, and token usage reductions of 34.6% to 75.6% across nine benchmarks covering mathematics, science, medicine, search, and code generation.

This tutorial will walk you through every piece of the system, from the conceptual foundations to the mathematical details to working code that you can run against both local models (via Ollama or Hugging Face Transformers) and remote LLM APIs (via OpenAI-compatible endpoints). By the end, you will understand not just what RecursiveMAS does, but why it works, and how you can start building systems inspired by its ideas today.


INSTALLATION AND PROJECT SETUP

Before diving into the code, here is everything you need to install and how to structure your project.

Requirements

# requirements.txt
torch>=2.1.0
transformers>=4.40.0
accelerate>=0.27.0
requests>=2.31.0
sentencepiece>=0.1.99
protobuf>=3.20.0

Install with:

pip install -r requirements.txt

Optional: Ollama (for local model serving without writing HF loading code)

# Install Ollama from https://ollama.ai, then:
ollama serve                          # start the server (keep this running)
ollama pull qwen2.5:1.5b              # ~1 GB, fast on CPU
ollama pull llama3.2:1b               # ~700 MB, very fast on CPU
ollama pull qwen2.5:7b                # ~4.5 GB, good quality, needs GPU

Project File Structure

recursive_mas/
├── requirements.txt
├── recursive_mas.py          # All core classes (the full combined file)
├── demo.py                   # Demo script that imports from recursive_mas.py
└── README.md

All code blocks in this tutorial belong in recursive_mas.py in the order they appear, followed by the demo code in demo.py. A complete, self-contained combined file is provided at the end of Part Six.


CHAPTER ONE: UNDERSTANDING THE LANDSCAPE

Before we dive into RecursiveMAS itself, we need to make sure we share a common vocabulary. If you have built agentic systems before, some of this will be review, but the framing matters for what comes later.

What Is a Multi-Agent System, Really?

A multi-agent system (MAS) is a collection of individual language model agents, each assigned a distinct role or area of expertise, that collaborate to solve a task that would be difficult or impossible for any single agent alone. The intuition is that specialization helps. A model fine-tuned for mathematical reasoning will outperform a generalist model on math problems. A model trained on biomedical literature will do better on medical questions. By combining specialists, you get a system that is smarter than any of its parts.

The paper formalizes this nicely. You have a system S composed of N agents A₁, A₂, ..., Aₙ. Each agent Aᵢ has its own parameters and its own last-layer hidden representations. The system maintains a collective latent state, which is the combined internal representation of what all agents currently "know" about the problem. Given an input question x with a ground truth answer y, the system orchestrates interactions among agents to collaboratively produce a final prediction.

The key insight that motivates the whole paper is captured in what the authors call Recursive Multi-Agent Evolution: a recursive evolution is the progressive refinement of the collective latent state, where each agent adjusts its latent representation and its own reasoning state so that the updated system is better aligned for the given problem. In other words, the system should get smarter with each round of interaction, not just produce one answer and stop.

The Four Collaboration Patterns

The paper identifies four archetypal ways that agents can collaborate, and RecursiveMAS is designed to work with all of them. Understanding these patterns is essential because the framework is deliberately structure-agnostic — it does not care how you arrange your agents, it just makes whatever arrangement you choose work better.

The Sequential Style arranges agents in a chain, where each agent builds on the work of the previous one. The paper uses a Planner–Critic–Solver arrangement: the Planner decomposes the problem into a step-by-step plan, the Critic evaluates that plan and identifies weaknesses, and the Solver uses the refined plan to produce a final answer. This is the most common pattern in practice and the one the paper uses for its primary experiments.

The Mixture Style runs multiple domain-specialized agents in parallel, then aggregates their outputs. The paper uses Math, Code, and Science specialists whose outputs are combined by a Summarizer agent. This pattern is powerful when you do not know in advance which domain a question belongs to, or when a question genuinely spans multiple domains.

The Distillation Style pairs a large, capable Expert model with a smaller, faster Learner model. The Expert provides rich guidance; the Learner absorbs it and produces the final answer more efficiently. This is essentially knowledge distillation happening at inference time, not just training time. The paper shows that RecursiveMAS can improve the Learner by 8.0% while retaining a 1.5x speed advantage over the Expert alone.

The Deliberation Style pairs an inner-thinking Reflector with a Tool-Caller that can invoke external tools like Python interpreters or search APIs. The two agents iteratively exchange, critique, and refine candidate solutions until they reach consensus, after which the Tool-Caller produces the final answer. This is the most complex pattern and the most interesting for building real-world agentic systems.
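One way to make the four patterns concrete is to write each one down as a small directed graph over agent roles. The sketch below is purely illustrative: the `Pattern` class, the `PATTERNS` table, and the `senders_to` helper are hypothetical names invented here, not part of the paper or its repository. The role names and the choice of final agent follow the descriptions above.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Pattern:
    """A collaboration pattern as a directed graph over agent roles (illustrative)."""
    name: str
    edges: Tuple[Tuple[str, str], ...]  # (sender, receiver) pairs
    final_agent: str                    # role that produces the final answer


PATTERNS = {
    "sequential": Pattern(
        "sequential",
        (("Planner", "Critic"), ("Critic", "Solver")),
        "Solver",
    ),
    "mixture": Pattern(
        "mixture",
        (("Math", "Summarizer"), ("Code", "Summarizer"), ("Science", "Summarizer")),
        "Summarizer",
    ),
    "distillation": Pattern(
        "distillation",
        (("Expert", "Learner"),),
        "Learner",
    ),
    "deliberation": Pattern(
        "deliberation",
        # The two agents exchange candidates in both directions.
        (("Reflector", "Tool-Caller"), ("Tool-Caller", "Reflector")),
        "Tool-Caller",
    ),
}


def senders_to(pattern: Pattern, role: str) -> List[str]:
    """Return the roles whose output feeds into `role`, in declaration order."""
    return [sender for sender, receiver in pattern.edges if receiver == role]


for pattern in PATTERNS.values():
    print(pattern.name, "->", senders_to(pattern, pattern.final_agent))
```

Seen this way, "structure-agnostic" just means the framework only needs to know which agent sends to which; every edge is a place where text passing could be replaced by a latent link.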

The Problem with Text-Based Communication

Here is where things get interesting. In all four patterns above, the traditional approach has agents communicate by generating text. Agent A₁ produces a text response, which is fed as a prompt to Agent A₂, which produces another text response, and so on. This seems natural — text is the universal interface of language models, after all.

But this approach has two serious problems that the paper addresses with mathematical rigor.

The first problem is computational efficiency. When an intermediate agent generates text, it must run the full vocabulary projection layer (which maps from the hidden dimension to a vocabulary of tens of thousands of tokens), sample a token, and then the next agent must re-embed that token back into the hidden space. This decode-then-re-encode cycle is expensive. The paper proves formally that text-based recursive MAS has higher runtime complexity than latent-space-based RecursiveMAS, because RecursiveMAS replaces the O(|V|) vocabulary projection with a small linear transformation over the hidden dimension d_h, which is much smaller than the vocabulary size |V|.
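To get a feel for the size of that gap, here is a back-of-the-envelope count of multiply-accumulate operations (MACs) per generated position. Treat it as a sketch under stated assumptions, not as the paper's proof: the function names are made up for illustration, the dimensions (d_h = 2048, |V| = 150,000) are hypothetical round numbers rather than values from the paper, and the count ignores the Transformer stack itself, which both approaches have to run.

```python
def vocab_projection_macs(d_h: int, vocab_size: int) -> int:
    """MACs for projecting one hidden state onto the vocabulary (a d_h x |V| matmul)."""
    return d_h * vocab_size


def inner_link_macs(d_h: int) -> int:
    """MACs for a two-layer link with two d_h x d_h linear layers (biases ignored)."""
    return 2 * d_h * d_h


# Hypothetical sizes, chosen only to illustrate the ratio.
d_h = 2048
vocab_size = 150_000

projection = vocab_projection_macs(d_h, vocab_size)
link = inner_link_macs(d_h)
ratio = projection / link

print(f"vocabulary projection: {projection:,} MACs per position")
print(f"two-layer link:        {link:,} MACs per position")
print(f"ratio:                 {ratio:.1f}x")
```

The ratio simplifies to |V| / (2 · d_h), so the saving per position grows with the vocabulary and shrinks as the hidden dimension grows.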

The second problem is gradient vanishing during training. When you try to train a text-based multi-agent system end-to-end, the gradients have to flow backward through the discrete token sampling operation. Discrete sampling is not differentiable, so in practice the gradient must pass through the softmax distribution over the vocabulary. The paper proves that when tokens are generated with high confidence (which is exactly what you want from a good model), the softmax distribution becomes very peaked, its covariance matrix becomes nearly singular, and the gradient norm collapses toward zero. The RecursiveLink, by contrast, maintains gradient norms that are bounded away from zero by a quantity that depends on the hidden dimension, not on token confidence. This means training actually works.
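The mechanism is easy to see on a toy example. The Jacobian of softmax with respect to its logits is diag(p) − p pᵀ, which is also the covariance matrix of a one-hot sample drawn from p. The sketch below uses a hypothetical four-token vocabulary and plain Python, and the helper names are invented for illustration. It illustrates the effect described above; it does not reproduce the paper's theorem.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian_norm(logits: List[float]) -> float:
    """Frobenius norm of the softmax Jacobian, diag(p) - p p^T."""
    p = softmax(logits)
    squared = 0.0
    for i in range(len(p)):
        for j in range(len(p)):
            entry = (p[i] if i == j else 0.0) - p[i] * p[j]
            squared += entry * entry
    return math.sqrt(squared)


# Raise the logit of one token to make the model more and more confident.
for top_logit in (0.0, 2.0, 5.0, 10.0):
    logits = [top_logit, 0.0, 0.0, 0.0]
    confidence = max(softmax(logits))
    norm = softmax_jacobian_norm(logits)
    print(f"top logit {top_logit:>4}: confidence {confidence:.5f}, Jacobian norm {norm:.6f}")
```

As the top token's probability approaches 1, every entry of the Jacobian approaches 0, so any gradient that has to pass through this matrix is scaled down with it.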

These two theoretical results — the complexity proposition and the gradient stability theorem — are not just academic decoration. They are the mathematical justification for why the whole system is designed the way it is.


CHAPTER TWO: THE ARCHITECTURE OF RECURSIVEMAS

Now let us get into the actual machinery. RecursiveMAS has three main components: the RecursiveLink module (which comes in two flavors, inner and outer), the latent thoughts generation process, and the recursive loop that chains everything together.

The RecursiveLink: The Heart of the System

The RecursiveLink is a small, lightweight neural module — just two linear layers with a residual connection and a GELU activation — that serves as the bridge between agents. It is the only part of the system that gets trained. All the large LLM agent parameters are frozen. This is a crucial design choice: you do not need to retrain your expensive 7B or 13B parameter models. You just train a tiny adapter that knows how to translate between their hidden spaces.

There are two variants of the RecursiveLink, and understanding the difference between them is key to understanding the whole architecture.

The Inner RecursiveLink operates within a single agent. Its job is to take the last-layer hidden state that the agent produces at one autoregressive step and transform it into an input embedding for the next step, so that the agent can continue reasoning in the latent space without ever decoding to text. The formula is:

$$R_{\text{inner}}(h) = h + W_2 \cdot \sigma(W_1 \cdot h)$$

where (h) is the current last-layer hidden state vector, (W_1) and (W_2) are learned linear layers, (\sigma) is the GELU activation function, and the addition of (h) on the left is the residual connection. The residual connection is not just a nice-to-have — it is architecturally important. By adding the original (h) back, the module is forced to learn only the distributional shift (the difference between where the hidden state lives and where the input embedding space lives), rather than learning the entire transformation from scratch. This makes training more stable and faster to converge.

The Outer RecursiveLink operates between agents. Its job is to take the hidden states produced by one agent and transform them into input embeddings for a different agent, which may have a completely different hidden dimension. This is the "heterogeneous" part of the framework — you can connect a 1.7B parameter model to a 7B parameter model, and the outer link handles the dimensional mismatch. The formula adds one more linear layer:

$$R_{\text{outer}}(h) = W_3 \cdot h + W_2 \cdot \sigma(W_1 \cdot h)$$

The difference from the inner link is the (W_3 \cdot h) term in the residual branch. In the inner link, the residual is just (h) (the identity), because the source and target spaces have the same dimension. In the outer link, (W_3) is a linear projection that maps from the source agent's hidden dimension to the target agent's hidden dimension, so the residual branch also performs the dimensional alignment. The nonlinear branch (W_2 \cdot \sigma(W_1 \cdot h)) then learns the fine-grained distributional correction on top of that linear alignment.

Let us look at how this translates into code. The following implementation works with PyTorch and is designed to be clean, readable, and easy to extend:

import torch
import torch.nn as nn
import torch.nn.functional as F


class InnerRecursiveLink(nn.Module):
    """
    The Inner RecursiveLink operates within a single LLM agent.

    It transforms the agent's last-layer hidden state at step t into
    an input embedding for step t+1, enabling the agent to reason
    in the continuous latent space without decoding to text.

    The residual connection is critical: it forces the module to learn
    only the distributional shift, not the full transformation.
    This leads to more stable gradients and faster convergence.

    Architecture:
        R_inner(h) = h + W2 * GELU(W1 * h)

    At initialization, W2 is set to all zeros so the module starts as
    a pure identity transformation (output == input). Training then
    learns the residual correction on top of this stable baseline.

    Args:
        hidden_dim: The hidden dimension of the LLM agent this link
                    is paired with. Must match the model's d_model.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim

        # W1: first linear layer, maps hidden_dim -> hidden_dim
        self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # W2: second linear layer, maps hidden_dim -> hidden_dim
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize weights so the module starts as an exact identity
        transformation.

        W1 is initialized with small random values (gain=0.1) so
        GELU(W1*h) produces small activations initially.
        W2 is initialized to zero so the entire nonlinear branch
        outputs zero at the start, making R_inner(h) = h + 0 = h.
        This identity-at-init property ensures stable early training.
        """
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Transform a last-layer hidden state into the next input embedding.

        Args:
            h: Hidden state tensor.
               Shape: (batch_size, hidden_dim) for a single step, or
                      (batch_size, seq_len, hidden_dim) for a sequence.

        Returns:
            Transformed embedding of the same shape as h, ready to be
            used as input to the next autoregressive step.
        """
        # Nonlinear branch learns the distributional correction
        correction = self.W2(F.gelu(self.W1(h)))
        # Residual connection preserves the original latent semantics
        return h + correction


class OuterRecursiveLink(nn.Module):
    """
    The Outer RecursiveLink bridges two heterogeneous LLM agents.

    It transforms the last-layer hidden states of a source agent into
    input embeddings aligned with the target agent's embedding space.
    This enables seamless cross-agent latent state transfer even when
    the two agents have different hidden dimensions (e.g., a 1.7B model
    talking to a 7B model).

    Architecture:
        R_outer(h) = W3 * h + W2 * GELU(W1 * h)

    The W3 term in the residual branch handles the dimensional alignment
    (a learned linear projection from source_dim to target_dim), while
    the nonlinear branch learns fine-grained distributional correction.

    Args:
        source_dim: Hidden dimension of the source (sending) agent.
        target_dim: Hidden dimension of the target (receiving) agent.
    """

    def __init__(self, source_dim: int, target_dim: int):
        super().__init__()
        self.source_dim = source_dim
        self.target_dim = target_dim

        # W1: projects within source space before the nonlinear activation
        self.W1 = nn.Linear(source_dim, source_dim, bias=True)

        # W2: projects from source space to target space (nonlinear branch)
        self.W2 = nn.Linear(source_dim, target_dim, bias=True)

        # W3: the residual projection that handles dimensional alignment.
        # This is the key structural difference from the inner link.
        # No bias: the bias in W2 already handles the offset.
        self.W3 = nn.Linear(source_dim, target_dim, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize so that W3 provides a reasonable linear baseline
        (Xavier uniform) and the nonlinear branch starts near zero,
        similar to the inner link's identity-at-init strategy.
        """
        nn.init.xavier_uniform_(self.W3.weight, gain=1.0)
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Transform source agent's hidden states into target agent's
        input embedding space.

        Args:
            h: Hidden state tensor from the source agent.
               Shape: (batch_size, seq_len, source_dim) or
                      (batch_size, source_dim).

        Returns:
            Transformed embedding aligned with the target agent's space.
            Shape: (batch_size, seq_len, target_dim) or
                   (batch_size, target_dim), matching the input rank.
        """
        # Linear residual branch: handles dimensional alignment
        linear_branch = self.W3(h)
        # Nonlinear branch: learns fine-grained distributional correction
        nonlinear_branch = self.W2(F.gelu(self.W1(h)))
        return linear_branch + nonlinear_branch

Notice how clean and small this is. The entire RecursiveLink, the module that makes the whole system work, is about 150 lines of code including docstrings and comments. This is one of the most elegant aspects of the paper: the key innovation is architecturally tiny, even though its effects are large.
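To see just how tiny, you can count the trainable parameters directly from the layer shapes used in the two classes above. The helper names below are hypothetical, and the dimensions 1536 and 3584 are assumed example values picked for illustration, not figures from the paper.

```python
def inner_link_params(hidden_dim: int) -> int:
    """Parameters in the inner link: W1 and W2, each hidden_dim x hidden_dim plus bias."""
    per_layer = hidden_dim * hidden_dim + hidden_dim
    return 2 * per_layer


def outer_link_params(source_dim: int, target_dim: int) -> int:
    """Parameters in the outer link: W1 (with bias), W2 (with bias), W3 (no bias)."""
    w1 = source_dim * source_dim + source_dim
    w2 = source_dim * target_dim + target_dim
    w3 = source_dim * target_dim
    return w1 + w2 + w3


# Assumed example dimensions for a small sender and a larger receiver.
small_dim, large_dim = 1536, 3584

print(f"inner link ({small_dim}):              {inner_link_params(small_dim):,}")
print(f"outer link ({small_dim} -> {large_dim}):      {outer_link_params(small_dim, large_dim):,}")
```

Under these assumed sizes each link has a few million parameters, which is orders of magnitude below the billions of frozen parameters in the agents it connects.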

Latent Thoughts Generation: Thinking Without Words

The next concept to understand is latent thoughts generation, which is the process by which an agent reasons in the continuous latent space rather than producing text.

In normal autoregressive generation, an LLM takes an input sequence of tokens, embeds them into vectors, runs them through the Transformer, and at each step produces a probability distribution over the vocabulary from which it samples the next token. The next token is then embedded and fed back in as input, and the cycle continues.

In latent thoughts generation, the process is different. After the Transformer produces the last-layer hidden state (h_t) at step (t), instead of projecting (h_t) to the vocabulary and sampling a token, we pass (h_t) through the Inner RecursiveLink to get (e_{t+1} = R_{\text{inner}}(h_t)). This embedding (e_{t+1}) is then fed directly as the input embedding for step (t+1), bypassing the token sampling entirely. The agent runs for (m) steps this way, producing a sequence of (m) hidden states (H = [h_t, h_{t+1}, \ldots, h_{t+m-1}]) that represent its "latent thoughts": its reasoning, encoded as vectors rather than words.

The paper shows that (m = 80) is a sweet spot: performance improves steadily as (m) increases from 0 to about 80, then plateaus. This is a practically useful finding because it means you do not need to run hundreds of latent steps to get the benefit — a moderate budget of 80 steps is enough for effective collaboration.

The following function illustrates how latent thoughts generation works. Note how freezing is handled: the model's parameters are frozen by setting requires_grad=False on them before the call, not by wrapping the forward pass in torch.no_grad(). The distinction matters. torch.no_grad() stops autograd from recording anything, including the path from one step's embedding through the Transformer to the next step's hidden state, so a loss computed on the returned latent thoughts could never reach the inner link. With frozen parameters and autograd enabled, the model's weights still receive no gradients, but the inner link's weights do.

from typing import Optional


def generate_latent_thoughts(
    model: nn.Module,
    input_embeddings: torch.Tensor,
    inner_link: InnerRecursiveLink,
    num_latent_steps: int = 80,
    prior_latent_state: Optional[torch.Tensor] = None,
    training: bool = False,
) -> torch.Tensor:
    """
    Run an agent in latent-thoughts generation mode.

    Instead of decoding tokens at each step, we use the Inner RecursiveLink
    to feed the last-layer hidden state back as the next input embedding.
    This allows the agent to reason in continuous latent space for
    `num_latent_steps` steps before any text is produced.

    IMPORTANT: gradient flow design
        The model is frozen through requires_grad=False on its parameters,
        not through torch.no_grad(). Under no_grad the model's outputs are
        detached, so gradients could not travel from a later hidden state
        back to the inner link that produced an earlier input embedding.
        When training=True the whole loop runs with autograd enabled;
        when training=False (inference), everything runs under no_grad.

    Args:
        model: A Hugging Face causal LM. Its parameters must be frozen
               (requires_grad=False) before calling this function.
        input_embeddings: The embedded input context for this agent.
                          Shape: (batch_size, context_len, hidden_dim)
        inner_link: The InnerRecursiveLink for this agent.
        num_latent_steps: How many latent reasoning steps to perform.
                          The paper finds m=80 is a good default.
        prior_latent_state: Optional latent thoughts from a previous
                            recursion round, already projected into this
                            agent's embedding space by an outer link.
                            Shape: (batch_size, prior_len, hidden_dim).
                            Prepended to the input context so the agent
                            can condition on previous-round information.
        training: Set True during outer-loop training so gradients flow
                  through the inner_link. Set False for inference.

    Returns:
        latent_thoughts: The sequence of last-layer hidden states produced
                         during latent generation.
                         Shape: (batch_size, num_latent_steps, hidden_dim)
    """
    # If we have latent state from a previous recursion round,
    # prepend it to the input context so the agent conditions on it.
    if prior_latent_state is not None:
        current_embeddings = torch.cat(
            [prior_latent_state, input_embeddings], dim=1
        )
    else:
        current_embeddings = input_embeddings

    latent_thoughts = []

    # Autograd is enabled only when training. The frozen model's weights
    # never receive gradients, but its activations stay differentiable,
    # which is what lets a downstream loss reach the inner link.
    # Note: in training mode the graph spans every step, so memory use
    # grows with num_latent_steps.
    with torch.set_grad_enabled(training):
        for _ in range(num_latent_steps):
            # ---- Frozen model forward pass ----
            outputs = model(
                inputs_embeds=current_embeddings,
                output_hidden_states=True,
                use_cache=False,
            )
            # Last-layer hidden state at the final sequence position.
            # Shape: (batch_size, hidden_dim)
            last_hidden = outputs.hidden_states[-1][:, -1, :]

            # ---- Trainable inner link ----
            # Maps the hidden state to the next input embedding.
            next_embedding = inner_link(last_hidden)

            # Store the hidden state. In training mode it stays attached
            # to the graph, so later steps depend on the inner link.
            latent_thoughts.append(last_hidden.unsqueeze(1))

            # Append the new embedding to the running context.
            # The context grows by one embedding per step.
            current_embeddings = torch.cat(
                [current_embeddings, next_embedding.unsqueeze(1)], dim=1
            )

    # Stack all latent thoughts into a single tensor.
    # Shape: (batch_size, num_latent_steps, hidden_dim)
    return torch.cat(latent_thoughts, dim=1)

The key thing to notice here is the loop structure. At each step, we run the full Transformer forward pass, grab the last-layer hidden state, and pass it through the inner link to get the next input embedding. Autograd is enabled only when training=True, and in that case gradients reach the inner link through the frozen model's activations while the model's own weights stay untouched. The context window grows by one embedding per step. After (m) steps, we have a sequence of (m) hidden state vectors that encode the agent's latent reasoning about the problem.

This is conceptually similar to how Chain-of-Thought prompting works, in that you are giving the model space to reason before committing to an answer. The difference is that instead of generating text tokens, each of which needs a vocabulary projection and a sampling step and then collapses the hidden state to a single discrete symbol, you are generating continuous vectors that carry the full hidden state forward and are cheaper to produce.

Chaining Agents into a Loop

Now we have the two building blocks: the RecursiveLink modules and the latent thoughts generation process. The third piece is how these are combined to form the recursive loop.

The process for a single recursion round goes like this. Agent A₁ receives the input question (as embeddings) and, if this is not the first round, the latent state from the previous round. It runs latent thoughts generation for (m) steps, producing (H_{A_1}). These latent thoughts are then passed through the Outer RecursiveLink to transform them into the embedding space of Agent A₂. Agent A₂ receives both its own input context embeddings and the transformed latent thoughts from A₁, concatenated together. Agent A₂ then runs its own latent thoughts generation, producing (H_{A_2}). This continues through all N agents.

After the last agent Aₙ completes latent thoughts generation, its latent outputs (H_{A_N}) are passed back to the first agent A₁ through another Outer RecursiveLink, closing the loop. This is the "recursive" part: the system's latent answer from round (r) becomes additional context for round (r+1). Each new round can condition on what the system collectively produced in all previous rounds, enabling iterative refinement.

Only after the final recursion round does any text get produced. The last agent Aₙ decodes its latent thoughts into a textual answer using the standard vocabulary projection. All intermediate rounds are entirely in the latent space.

The following diagram represents the information flow for a two-agent system over two recursion rounds:

Round 1:
=========
Question (text) --> [Embed] --> E_A1
                                  |
                                  v
                          [Agent A1 + Inner Link]
                          generates H_A1 (latent)
                                  |
                          [Outer Link A1->A2]
                                  |
                                  v
                E_A2 + R_outer(H_A1) --> [Agent A2 + Inner Link]
                                         generates H_A2 (latent)
                                                |
                                        [Outer Link A2->A1]
                                                |
                          +---------------------+
                          |
                          v
Round 2:
=========
Question (text) --> [Embed] --> E_A1
R_outer(H_A2 from Round 1) ------+
                                  |
                                  v
                          [Agent A1 + Inner Link]
                          generates H_A1' (latent)
                                  |
                          [Outer Link A1->A2]
                                  |
                                  v
                E_A2 + R_outer(H_A1') --> [Agent A2 + Inner Link]
                                          generates H_A2' (latent)
                                                |
                                        [Decode to text]
                                                |
                                                v
                                        FINAL ANSWER

This is a beautiful structure. The question embeddings are always fed in fresh at the start of each round, so the agents never lose sight of what they are trying to answer. But they also receive the accumulated latent wisdom of all previous rounds, allowing them to iteratively refine their reasoning.
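The control flow in the diagram can be sketched without any neural network at all. In the toy below, agents and outer links are plain Python callables and latent states are lists of floats. All names (`run_recursive_rounds`, `make_toy_agent`, `halve`) are hypothetical, and for simplicity every agent sees the same question vector, whereas in the real system each agent embeds the question with its own embedding layer.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Agent = Callable[[Vector, Optional[Vector]], Vector]  # (question, incoming latent) -> latent
Link = Callable[[Vector], Vector]                     # sender space -> receiver space


def run_recursive_rounds(
    agents: Sequence[Agent],
    outer_links: Sequence[Link],
    question: Vector,
    num_rounds: int,
) -> Vector:
    """Run agents in a closed loop; outer_links[i] feeds agent (i + 1) % N."""
    if not agents or len(agents) != len(outer_links):
        raise ValueError("need at least one agent and exactly one outer link per agent")
    if num_rounds < 1:
        raise ValueError("num_rounds must be at least 1")

    incoming: Optional[Vector] = None  # nothing to condition on in round 1
    latent: Vector = []
    last_idx = len(agents) - 1

    for round_idx in range(num_rounds):
        for idx, agent in enumerate(agents):
            # The question is fed in fresh every time; `incoming` carries
            # the latent state handed over by the previous agent.
            latent = agent(question, incoming)
            final_step = round_idx == num_rounds - 1 and idx == last_idx
            if not final_step:
                incoming = outer_links[idx](latent)

    # The last agent's latent state is what would be decoded to text.
    return latent


def make_toy_agent(bias: float) -> Agent:
    """Toy agent: adds a bias and whatever latent state it received."""
    def agent(question: Vector, incoming: Optional[Vector]) -> Vector:
        extra = incoming if incoming is not None else [0.0] * len(question)
        return [q + bias + e for q, e in zip(question, extra)]
    return agent


def halve(latent: Vector) -> Vector:
    """Toy outer link: a fixed linear map."""
    return [0.5 * x for x in latent]


final_latent = run_recursive_rounds(
    [make_toy_agent(1.0), make_toy_agent(2.0)], [halve, halve], [1.0], num_rounds=2
)
print(final_latent)  # [5.0]
```

The second round produces a different result from the first only because the loop-closing link hands the last agent's latent state back to the first agent, which is exactly the recursion the diagram shows.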


CHAPTER THREE: TRAINING THE SYSTEM

One of the most practically important aspects of RecursiveMAS is how it is trained. The answer is elegant: you do not train the large LLM agents at all. You freeze all their parameters and only train the tiny RecursiveLink modules. This means the training cost is dramatically lower than fine-tuning the agents themselves.

The paper reports that RecursiveMAS uses only 13.12 million trainable parameters (0.31% of the total parameter count), compared to 15.29 million for LoRA and 4.21 billion for full supervised fine-tuning. Despite having fewer trainable parameters than LoRA, RecursiveMAS achieves higher accuracy (74.9% average vs. 66.9% for LoRA and 68.6% for full SFT) and uses less GPU memory (15.29 GB vs. 21.67 GB for LoRA and 41.40 GB for full SFT). This is a remarkable result that demonstrates the power of training the right thing rather than training everything.
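A quick sanity check shows how the quoted figures relate to one another. The snippet uses only the numbers from the paragraph above; the variable names are just labels chosen here.

```python
# Figures quoted above (trainable parameters and GPU memory).
recursive_mas_params = 13.12e6
lora_params = 15.29e6
full_sft_params = 4.21e9

recursive_mas_mem_gb = 15.29
lora_mem_gb = 21.67
full_sft_mem_gb = 41.40

# Trainable share relative to the fully fine-tuned parameter count.
trainable_share = 100 * recursive_mas_params / full_sft_params
print(f"trainable share: {trainable_share:.2f}%")

# Memory ratios relative to RecursiveMAS.
lora_mem_ratio = lora_mem_gb / recursive_mas_mem_gb
sft_mem_ratio = full_sft_mem_gb / recursive_mas_mem_gb
print(f"LoRA uses {lora_mem_ratio:.2f}x the memory")
print(f"full SFT uses {sft_mem_ratio:.2f}x the memory")
```

The 0.31% figure falls out of dividing 13.12 million by 4.21 billion, which suggests that the full fine-tuning count and the "total parameter count" refer to the same set of weights.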

The training happens in two stages, which the paper calls the Inner-Outer Loop training paradigm.

Stage One: Inner-Loop Training

The inner loop trains each agent's Inner RecursiveLink independently. The goal is to warm-start each agent so that it can generate useful latent thoughts — so that the hidden states it produces in latent mode are semantically meaningful and aligned with what the model would produce if it were generating text.

The training objective for the inner loop is a cosine similarity loss. For each training example (x, y), you run the agent in latent mode to get an initial hidden state (H_0) (the last-layer hidden state after processing the input context). You also take the ground truth answer (y), pass it through the model's standard input embedding layer to get (\text{Emb}(y)), and compute the cosine similarity between (R_{\text{inner}}(H_0)) and the mean of (\text{Emb}(y)). The loss is:

$$L_{\text{inner}} = 1 - \cos\!\left(R_{\text{inner}}(H_0),\; \overline{\text{Emb}(y)}\right)$$

where (\overline{\text{Emb}(y)}) denotes the mean of the answer token embeddings, representing the semantic "direction" the latent thoughts should point toward. Minimizing this loss encourages the inner link to transform the agent's latent thoughts so that they point in the same direction as the embeddings of the correct answer. This is a form of self-supervised alignment: you are teaching the agent to think in a way that is consistent with its own text generation capabilities.
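Before looking at the PyTorch version, it helps to work the loss by hand on tiny vectors. The sketch below is plain Python with hypothetical helper names and made-up two-dimensional "embeddings". It only illustrates the formula, showing the loss at its three landmark values: 0 when the link output points along the mean answer embedding, 1 when it is orthogonal, and 2 when it points the opposite way.

```python
import math
from typing import List, Sequence


def mean_vector(vectors: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of a list of equal-length vectors."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def inner_loop_loss(
    link_output: Sequence[float],
    answer_embeddings: Sequence[Sequence[float]],
) -> float:
    """1 - cos(link_output, mean of the answer token embeddings)."""
    return 1.0 - cosine_similarity(link_output, mean_vector(answer_embeddings))


# Two made-up answer token embeddings; their mean is [0.5, 0.5].
answer_embeddings = [[1.0, 0.0], [0.0, 1.0]]

for link_output in ([1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]):
    loss = inner_loop_loss(link_output, answer_embeddings)
    print(f"link output {link_output}: loss {loss:.3f}")
```

Because cosine similarity ignores vector length, the loss only constrains the direction of the link's output, not its magnitude.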

The following code shows how to compute this inner-loop training loss and run the training stage:

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import Tuple


def compute_inner_loop_loss(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    input_ids: torch.Tensor,
    target_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the inner-loop training loss for one agent.

    The loss encourages the inner link to produce latent thoughts
    that are semantically aligned with the ground-truth answer
    embeddings, measured by cosine similarity.

    Gradient flow:
        - model parameters: frozen, no gradients computed.
        - inner_link parameters: receive gradients through the
          cosine similarity loss. The key is that inner_link(H_0)
          is called OUTSIDE torch.no_grad(), so PyTorch builds a
          computation graph through W1 and W2 of the inner link.
          H_0 itself is a detached tensor (produced under no_grad),
          which is correct since we do not want gradients to flow
          into the frozen model.

    Args:
        model: The frozen LLM agent.
        inner_link: The Inner RecursiveLink to train.
        input_ids: Tokenized input question. Shape: (batch, seq_len).
        target_ids: Tokenized ground-truth answer. Shape: (batch, ans_len).

    Returns:
        loss: Scalar tensor. Value is in [0, 2]; minimizing it aligns
              the inner link's output with the answer embedding direction.
    """
    embedding_layer = model.get_input_embeddings()

    # Embed the input question context (no grad needed here)
    with torch.no_grad():
        # Shape: (batch, seq_len, hidden_dim)
        input_embeddings = embedding_layer(input_ids)

        # Embed the ground-truth answer — this is our alignment target.
        # Shape: (batch, ans_len, hidden_dim)
        target_embeddings = embedding_layer(target_ids)

        # Run the frozen model on the input context to get the initial
        # last-layer hidden state H_0.
        initial_output = model(
            inputs_embeds=input_embeddings,
            output_hidden_states=True,
            use_cache=False,
        )
        # Shape: (batch, hidden_dim)
        H_0 = initial_output.hidden_states[-1][:, -1, :]

    # Apply the inner link WITH gradient tracking.
    # H_0 is detached (produced under no_grad), but gradients still
    # flow through inner_link's own weight matrices W1 and W2.
    # Shape: (batch, hidden_dim)
    transformed_latent = inner_link(H_0)

    # Target direction: mean of the ground-truth answer embeddings.
    # This is a detached tensor (produced under no_grad above).
    # Shape: (batch, hidden_dim)
    target_direction = target_embeddings.mean(dim=1)

    # Cosine similarity: values in [-1, 1]. We want to maximize it,
    # so we minimize 1 - similarity.
    similarity = F.cosine_similarity(
        transformed_latent, target_direction, dim=-1
    )

    # Average over the batch
    loss = (1.0 - similarity).mean()
    return loss


def train_inner_loop(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    device: str = "cuda",
) -> None:
    """
    Run the inner-loop training stage for one agent.

    Only the inner_link parameters are updated. The model is frozen.
    Call this independently for each agent before running outer-loop
    training.

    Args:
        model: The frozen LLM agent (all parameters must have
               requires_grad=False before calling this function).
        inner_link: The Inner RecursiveLink to train.
        dataloader: DataLoader yielding (input_ids, target_ids) batches,
                    where both tensors are on `device`.
        num_epochs: Number of training epochs.
        learning_rate: Initial learning rate for AdamW.
        device: Device string for moving tensors if needed.
    """
    # Freeze the model defensively, in case the caller has not already
    for param in model.parameters():
        param.requires_grad = False

    # Only optimize the inner link parameters
    optimizer = AdamW(inner_link.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )

    inner_link.train()
    model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

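        for input_ids, target_ids in dataloader:
            # Move the batch to the training device (a no-op if the
            # dataloader already yields tensors on `device`).
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            optimizer.zero_grad()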

            loss = compute_inner_loop_loss(
                model=model,
                inner_link=inner_link,
                input_ids=input_ids,
                target_ids=target_ids,
            )

            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(
                inner_link.parameters(), max_norm=1.0
            )

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Inner Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )

The inner-loop training is done independently for each agent. You train Agent A₁'s inner link, then Agent A₂'s inner link, and so on. Each agent learns to generate latent thoughts that are semantically coherent on its own, before the agents start collaborating.
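If you want to convince yourself that only the link is being trained, a quick sanity check on toy modules is enough. The sketch below is an assumption-laden stand-in: `frozen_model` and `toy_link` are plain `nn.Linear` layers, not a real LLM or the real InnerRecursiveLink. It mirrors the gradient pattern of the inner loop: the hidden state is computed under `torch.no_grad()`, and the link is applied outside it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

frozen_model = nn.Linear(8, 8)          # toy stand-in for the frozen LLM
for p in frozen_model.parameters():
    p.requires_grad = False
toy_link = nn.Linear(8, 8)              # toy stand-in for the inner link

x = torch.randn(4, 8)                   # pretend input context
target_direction = torch.randn(4, 8)    # pretend mean answer embedding

# H_0 is produced without gradient tracking, so it is a detached tensor.
with torch.no_grad():
    h0 = frozen_model(x)

# The link runs with gradient tracking, so its weights get gradients.
similarity = F.cosine_similarity(toy_link(h0), target_direction, dim=-1)
loss = (1.0 - similarity).mean()
loss.backward()

link_has_grads = all(p.grad is not None for p in toy_link.parameters())
model_has_grads = any(p.grad is not None for p in frozen_model.parameters())
print(f"link grads: {link_has_grads}, frozen model grads: {model_has_grads}")
```

The link's weight and bias both receive gradients, while the frozen layer's parameters keep `grad` set to `None`.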

Stage Two: Outer-Loop Training

After the inner loop has warmed up each agent's ability to generate latent thoughts, the outer loop trains the Outer RecursiveLink modules that connect agents to each other. This is where the real magic happens: the entire system is optimized end-to-end as a unified entity.

The outer-loop training objective is a standard cross-entropy loss on the final textual output:

$$L_{\text{outer}} = \text{CrossEntropy}\!\left(S^{(n)}\!\left(S^{(n-1)}\!\left(\cdots S^{(1)}(x)\cdots\right)\right), y\right)$$

where \(S^{(r)}(\cdot)\) denotes the system state after recursion round \(r\). The computation graph is preserved through all recursion rounds, so gradients can flow backward from the final prediction all the way through every outer link in every round. Each outer link receives a gradient signal that reflects its global contribution to the final answer, not just its local behavior.

This is the "shared credit assignment" that the paper refers to. Instead of training each agent in isolation and hoping they work well together, the outer loop trains the connections between agents with full knowledge of how the whole system performs. It is the difference between training individual musicians in isolation and rehearsing the whole orchestra together.
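One practical consequence is easy to miss: for an upstream link to receive credit, gradients have to pass through the frozen agents that sit between it and the loss. Freezing weights with `requires_grad=False` allows this, while wrapping the agent's forward pass in `torch.no_grad()` cuts the path. The following toy sketch illustrates the difference. The `nn.Linear` layers are stand-ins chosen for illustration, not the paper's modules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

upstream_link = nn.Linear(8, 8)    # toy stand-in for an outer link
frozen_agent = nn.Linear(8, 8)     # toy stand-in for a frozen LLM agent
for p in frozen_agent.parameters():
    p.requires_grad = False

x = torch.randn(4, 8)

# Case 1: weights are frozen, but the forward pass is recorded.
# Gradients pass THROUGH the frozen agent and reach the upstream link.
frozen_agent(upstream_link(x)).sum().backward()
grad_reaches_link = upstream_link.weight.grad is not None
agent_weight_untouched = frozen_agent.weight.grad is None

# Case 2: the agent runs under no_grad. Its output is detached, so a
# loss built on it could never send gradients back to the link.
projected = upstream_link(x)
with torch.no_grad():
    cut_output = frozen_agent(projected)
cut_output_requires_grad = cut_output.requires_grad

print(grad_reaches_link, agent_weight_untouched, cut_output_requires_grad)
```

In the first case the link gets a gradient while the frozen weights do not. In the second case the output no longer requires grad, so backpropagation would stop at the agent.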

The following code implements the full RecursiveMASSystem as a PyTorch nn.Module and the outer-loop training procedure. Note the careful use of nn.ModuleList to properly register the RecursiveLink submodules (so they appear in system.parameters() and system.state_dict()), and the separation of frozen agent models from trainable links:

class RecursiveMASSystem(nn.Module):
    """
    A RecursiveMAS system with N agents in sequential style.

    This class wires together N frozen LLM agents with their
    Inner RecursiveLinks and Outer RecursiveLinks, implementing
    the full recursive loop for training and inference.

    Architecture notes:
        - agent_models: stored as a plain Python list (NOT nn.ModuleList)
          because their parameters are frozen and we deliberately do not
          want them to appear in system.parameters() or system.state_dict().
          This keeps the optimizer and checkpoint focused on the links only.
        - inner_links, outer_links: stored as nn.ModuleList so they ARE
          registered as submodules, appear in system.parameters(), and are
          saved/loaded correctly via state_dict().

    Args:
        agent_models: List of N frozen Hugging Face causal LM models.
                      Parameters must be frozen before passing them here.
        inner_links: List of N InnerRecursiveLink modules, one per agent.
        outer_links: List of N OuterRecursiveLink modules. outer_links[i]
                     connects agent i to agent (i+1) % N, so the last
                     element connects the final agent back to the first,
                     closing the recursive loop.
        num_latent_steps: Number of latent reasoning steps per agent
                          per recursion round (default 80, per the paper).
    """

    def __init__(
        self,
        agent_models: list,
        inner_links: list,
        outer_links: list,
        num_latent_steps: int = 80,
    ):
        super().__init__()

        if len(agent_models) < 2:
            raise ValueError(
                "RecursiveMASSystem requires at least 2 agents."
            )
        if len(inner_links) != len(agent_models):
            raise ValueError(
                "Must provide exactly one inner_link per agent."
            )
        if len(outer_links) != len(agent_models):
            raise ValueError(
                "Must provide exactly one outer_link per agent "
                "(outer_links[i] connects agent i to agent (i+1) % N)."
            )

        self.num_agents = len(agent_models)
        self.num_latent_steps = num_latent_steps

        # Store frozen agent models as a plain list.
        # They are NOT registered as nn.Module submodules intentionally:
        # their parameters are frozen and should not appear in
        # self.parameters() to keep the optimizer clean.
        self._agent_models = agent_models

        # Register RecursiveLinks as proper nn.Module submodules so they
        # appear in self.parameters() and self.state_dict().
        self.inner_links = nn.ModuleList(inner_links)
        self.outer_links = nn.ModuleList(outer_links)

        # Verify all agent models are frozen
        for i, model in enumerate(self._agent_models):
            for param in model.parameters():
                if param.requires_grad:
                    raise ValueError(
                        f"Agent {i} has unfrozen parameters. "
                        f"Freeze all agent parameters before constructing "
                        f"RecursiveMASSystem."
                    )

    def _run_agent_latent(
        self,
        agent_idx: int,
        context_embeddings: torch.Tensor,
        training: bool,
    ) -> torch.Tensor:
        """
        Run one agent in latent-thoughts generation mode.

        The agent's weights are frozen (requires_grad=False), so they
        never receive gradients. When training=True the forward pass is
        still recorded, because gradients must pass THROUGH the frozen
        model to reach the inner link (via the embeddings it appended)
        and any upstream outer link (via context_embeddings). Wrapping
        the model call in torch.no_grad() would cut that path. At
        inference, gradient tracking is disabled entirely.

        Note: recording num_latent_steps forward passes is memory-hungry.
        Fewer latent steps or gradient checkpointing may be needed.

        Args:
            agent_idx: Index of the agent to run (0-based).
            context_embeddings: Input context for this agent, already
                                including any transferred latent state
                                from other agents.
                                Shape: (batch, context_len, hidden_dim)
            training: True during outer-loop training; False at inference.

        Returns:
            latent_thoughts: Shape (batch, num_latent_steps, hidden_dim)
        """
        model = self._agent_models[agent_idx]
        inner_link = self.inner_links[agent_idx]

        current_embeddings = context_embeddings
        latent_thoughts = []

        for _ in range(self.num_latent_steps):
            with torch.set_grad_enabled(training):
                outputs = model(
                    inputs_embeds=current_embeddings,
                    output_hidden_states=True,
                    use_cache=False,
                )
                # Shape: (batch, hidden_dim)
                last_hidden = outputs.hidden_states[-1][:, -1, :]

                # Trainable inner link maps the hidden state back into
                # the agent's input embedding space.
                next_emb = inner_link(last_hidden)

            # Store the latent thought. In training mode it stays
            # attached to the graph so downstream losses can reach
            # the links that produced it.
            latent_thoughts.append(last_hidden.unsqueeze(1))

            # Grow the context by one embedding
            current_embeddings = torch.cat(
                [current_embeddings, next_emb.unsqueeze(1)], dim=1
            )

        # Shape: (batch, num_latent_steps, hidden_dim)
        return torch.cat(latent_thoughts, dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        num_recursion_rounds: int = 3,
        target_ids: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Run the full RecursiveMAS forward pass.

        Performs `num_recursion_rounds` rounds of latent collaboration
        among all agents, then computes vocabulary logits with the last
        agent at the end of the final round.

        Args:
            input_ids: Tokenized input question. Shape: (batch, seq_len).
                       The same ids are embedded by every agent, so this
                       sketch assumes all agents share one tokenizer.
            num_recursion_rounds: Number of recursive collaboration rounds.
                                  The paper uses n=3 as the default.
            target_ids: Optional tokenized ground-truth answer, appended
                        to the last agent's context for teacher forcing
                        during training. Shape: (batch, ans_len).
                        Positions set to -100 are embedded as token 0
                        and must be masked in the loss.

        Returns:
            logits: Output logits from the last agent.
                    Shape: (batch, decode_len, vocab_size). When
                    target_ids is given, the logits that predict the
                    answer tokens are logits[:, -(ans_len + 1):-1, :].
        """
        training = self.training  # True if system.train() was called

        # Pre-embed the input question for each agent.
        # Each agent may have a different hidden dimension, so we embed
        # separately using each agent's own embedding layer.
        input_embeddings = []
        for model in self._agent_models:
            with torch.no_grad():
                emb = model.get_input_embeddings()(input_ids)
            input_embeddings.append(emb)

        # feedback_latent[i] holds the most recent outer-link projection
        # addressed to agent i. It is updated in place, so agent i+1 sees
        # agent i's output from the SAME round (sequential style), and
        # agent 0 sees the last agent's output from the previous round.
        # None means no feedback has arrived yet (agent 0 in round 1).
        feedback_latent = [None] * self.num_agents

        def build_context(idx: int) -> torch.Tensor:
            """Prepend any pending feedback to agent idx's embedded input."""
            if feedback_latent[idx] is None:
                return input_embeddings[idx]
            return torch.cat(
                [feedback_latent[idx], input_embeddings[idx]], dim=1
            )

        for round_idx in range(num_recursion_rounds):
            for agent_idx in range(self.num_agents):
                # Run this agent in latent mode on its own input plus
                # whatever feedback it has received so far.
                latent = self._run_agent_latent(
                    agent_idx=agent_idx,
                    context_embeddings=build_context(agent_idx),
                    training=training,
                )

                # Project this agent's latent thoughts into the next
                # agent's embedding space via the outer link.
                next_agent_idx = (agent_idx + 1) % self.num_agents
                with torch.set_grad_enabled(training):
                    projected = self.outer_links[agent_idx](latent)

                feedback_latent[next_agent_idx] = projected

        # After the final round, run one more forward pass on the last
        # agent with its full context (feedback received in the final
        # round + input) to get vocabulary logits.
        last_agent_idx = self.num_agents - 1
        last_model = self._agent_models[last_agent_idx]
        decode_context = build_context(last_agent_idx)

        if target_ids is not None:
            # Teacher forcing: append the embedded answer so the model
            # can be scored on predicting each answer token in turn.
            with torch.no_grad():
                target_embeddings = last_model.get_input_embeddings()(
                    target_ids.clamp(min=0)
                )
            decode_context = torch.cat(
                [decode_context, target_embeddings], dim=1
            )

        with torch.set_grad_enabled(training):
            decode_output = last_model(
                inputs_embeds=decode_context,
                output_hidden_states=False,
                use_cache=False,
            )

        return decode_output.logits


def train_outer_loop(
    system: RecursiveMASSystem,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    num_recursion_rounds: int = 3,
) -> None:
    """
    Run the outer-loop training stage for the full RecursiveMAS system.

    Only the RecursiveLink parameters (inner_links + outer_links) are
    updated. All LLM agent models remain frozen throughout.

    The computation graph is preserved across all recursion rounds, so
    gradients flow back through the outer and inner links of every
    round that feed the final prediction. This is the "shared credit
    assignment" described in the paper.

    Args:
        system: The RecursiveMASSystem to train. Must have been
                constructed with frozen agent models.
        dataloader: DataLoader yielding (input_ids, target_ids) batches:
                    the tokenized question and the tokenized answer.
                    target_ids should use -100 for positions that should
                    not contribute to the loss (standard HF convention).
        num_epochs: Number of training epochs.
        learning_rate: Initial learning rate for AdamW.
        num_recursion_rounds: Number of recursive rounds during training.
    """
    # system.parameters() returns only the RecursiveLink parameters
    # because the agent models are stored as a plain list (not ModuleList)
    # and thus not registered as submodules.
    optimizer = AdamW(system.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    system.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()

            # Full forward pass through all recursion rounds, with the
            # answer appended for teacher forcing.
            logits = system(
                input_ids=input_ids,
                num_recursion_rounds=num_recursion_rounds,
                target_ids=target_ids,
            )

            # The logit at position t predicts the token at position
            # t + 1. The ans_len logits that predict the answer tokens
            # are therefore the slice ending one position before the
            # end of the sequence.
            ans_len = target_ids.shape[1]
            logits_for_loss = logits[:, -(ans_len + 1):-1, :]

            batch_size, _, vocab_size = logits_for_loss.shape
            loss = criterion(
                logits_for_loss.reshape(
                    batch_size * ans_len, vocab_size
                ).float(),
                target_ids.reshape(batch_size * ans_len),
            )

            # Backpropagate through ALL recursion rounds. Gradients
            # reach every RecursiveLink that fed the final prediction.
            loss.backward()

            # Gradient clipping for stability (paper uses AdamW + cosine LR)
            torch.nn.utils.clip_grad_norm_(
                system.parameters(), max_norm=1.0
            )

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Outer Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )
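The registration behavior that RecursiveMASSystem relies on is easy to verify on a toy module. In the sketch below, `ToySystem` and its `nn.Linear` members are hypothetical stand-ins: modules kept in a plain Python list stay invisible to `parameters()` and `state_dict()`, while modules wrapped in `nn.ModuleList` are registered.

```python
import torch.nn as nn


class ToySystem(nn.Module):
    """Toy stand-in that mirrors the storage pattern described above."""

    def __init__(self, frozen_models: list, links: list):
        super().__init__()
        self._frozen_models = frozen_models   # plain list: NOT registered
        self.links = nn.ModuleList(links)     # registered submodules


frozen = [nn.Linear(4, 4), nn.Linear(4, 4)]
links = [nn.Linear(4, 4), nn.Linear(4, 4)]
system = ToySystem(frozen, links)

# Only the links show up in the checkpoint and in the optimizer's view.
state_keys = sorted(system.state_dict().keys())
num_param_tensors = sum(1 for _ in system.parameters())

print(state_keys)
print(num_param_tensors)
```

The state dict contains only `links.*` entries, and `parameters()` yields four tensors (a weight and a bias for each of the two links). The flip side of this design is that calls such as `system.to(device)` or `system.eval()` do not propagate to the unregistered models, so they have to be placed on the right device and put in eval mode separately.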

CHAPTER FOUR: WORKING WITH LOCAL AND REMOTE LLMS

Now that we understand the theory and the training procedure, let us look at how to actually use RecursiveMAS in practice. The paper uses models from the Hugging Face ecosystem (Qwen, LLaMA, Gemma, Mistral), and remote API-based models can take part as well, although only in a reduced, text-based form. We will build a practical wrapper that supports both local models (via Hugging Face Transformers or Ollama) and remote models (via OpenAI-compatible APIs).

The key challenge with remote APIs is that they do not give you access to hidden states. You cannot intercept the last-layer embeddings of GPT-4 or Claude. This means you cannot implement the full RecursiveMAS architecture with remote APIs — you can only implement a text-based approximation. However, understanding this limitation is itself valuable, and for local models you can implement the full system.

We will build a unified interface that handles both cases gracefully.

The Agent Abstraction

The first thing we need is a clean abstraction for an "agent" that works regardless of whether the underlying model is local or remote:

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class AgentResponse:
    """
    Unified response object returned by any agent, local or remote.

    For local agents, hidden_states contains the actual latent vectors
    that RecursiveMAS uses for cross-agent communication. For remote
    agents, hidden_states is None and only the text response is available.

    Attributes:
        text: The decoded text response from the agent.
        hidden_states: Last-layer hidden states if available (local only),
                       otherwise None. Shape: (1, num_positions, hidden_dim).
                       LocalHuggingFaceAgent.generate_text returns only the
                       final generated position, so num_positions is 1 there.
        token_count: Number of tokens generated (for efficiency tracking).
    """
    text: str
    hidden_states: Optional[torch.Tensor]
    token_count: int


class BaseAgent(ABC):
    """
    Abstract base class for all RecursiveMAS agents.

    Concrete implementations handle local Hugging Face models,
    Ollama-served models, and remote OpenAI-compatible API models.
    All agents expose the same interface so the RecursiveMAS orchestrator
    can work with any combination of local and remote agents.
    """

    def __init__(self, name: str, role: str):
        """
        Args:
            name: A human-readable name for this agent (e.g., "Planner").
            role: The agent's role description, used in system prompts.
        """
        self.name = name
        self.role = role

    @abstractmethod
    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response to the given prompt.

        This is the universal interface used by all agents.
        Local agents additionally populate hidden_states in the response.

        Args:
            prompt: The user prompt to respond to.
            system_prompt: Optional system prompt for role conditioning.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (lower = more deterministic).

        Returns:
            AgentResponse with text, optional hidden_states, and token_count.
        """
        pass

    @property
    @abstractmethod
    def supports_latent_transfer(self) -> bool:
        """
        Whether this agent supports latent-space state transfer.

        Returns True for local Hugging Face agents (which expose hidden
        states), and False for remote API agents (which do not).
        """
        pass

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(name={self.name!r}, role={self.role!r})"
        )

Local Agent: Hugging Face Transformers

The local agent implementation uses Hugging Face Transformers and gives us full access to hidden states, enabling the complete RecursiveMAS architecture:

from transformers import AutoTokenizer, AutoModelForCausalLM


class LocalHuggingFaceAgent(BaseAgent):
    """
    A RecursiveMAS agent backed by a local Hugging Face model.

    This agent supports full latent-space state transfer because it
    has direct access to the model's internal hidden states. It is
    the agent type used in the original RecursiveMAS paper with
    models like Qwen3, LLaMA-3, Gemma3, and Mistral.

    The model parameters are always frozen — only the RecursiveLink
    parameters are trained.

    Args:
        model_name_or_path: Hugging Face model identifier or local path.
                            Examples:
                              "Qwen/Qwen2.5-1.5B-Instruct"
                              "meta-llama/Llama-3.2-1B-Instruct"
                              "/path/to/local/model"
        name: Human-readable agent name.
        role: Agent role description for system prompts.
        device: Torch device string. Use "cuda" for NVIDIA GPU,
                "mps" for Apple Silicon, "cpu" for CPU-only.
                Note: "mps" support varies by model; test before deploying.
        torch_dtype: Data type for model weights. torch.float16 is
                     recommended for GPU to save memory. Use torch.float32
                     for CPU or if you encounter numerical issues.
    """

    def __init__(
        self,
        model_name_or_path: str,
        name: str,
        role: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        super().__init__(name=name, role=role)
        self.device = device
        self.model_name = model_name_or_path

        print(f"Loading {name} from {model_name_or_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )

        # Ensure the tokenizer has a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model and immediately freeze all parameters
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )

        # Freeze all model parameters — only RecursiveLinks are trained
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.eval()

        # Cache the hidden dimension for RecursiveLink sizing
        self.hidden_dim = self.model.config.hidden_size

        total_params = sum(p.numel() for p in self.model.parameters())
        print(
            f"  Loaded {name}: hidden_dim={self.hidden_dim}, "
            f"params={total_params:,} (all frozen)"
        )

    @property
    def supports_latent_transfer(self) -> bool:
        """Local models always support latent transfer."""
        return True

    def get_hidden_dim(self) -> int:
        """Return the model's hidden dimension for RecursiveLink sizing."""
        return self.hidden_dim

    def get_raw_model(self) -> nn.Module:
        """
        Return the underlying frozen Hugging Face model.

        Used by RecursiveMASSystem to access the model directly for
        latent-space operations.
        """
        return self.model

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Get the input embeddings for a piece of text.

        Used to embed the question context before latent thoughts generation.

        Args:
            text: The text to embed.

        Returns:
            Embeddings tensor of shape (1, seq_len, hidden_dim).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            embeddings = self.model.get_input_embeddings()(
                inputs["input_ids"]
            )
        return embeddings

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response using standard autoregressive decoding.

        This is used for the final output in the last recursion round,
        or for text-based approximation mode.

        Args:
            prompt: The user message.
            system_prompt: Optional system message for role conditioning.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            AgentResponse with text output. hidden_states contains the
            last-layer hidden states from the final generation step,
            shape (1, 1, hidden_dim) — one position, last layer.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.device)

        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                top_p=0.95,
                do_sample=(temperature > 0),
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode the generated text (excluding the input prompt tokens)
        generated_ids = outputs.sequences[:, input_length:]
        generated_text = self.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True,
        )

        # Extract the last-layer hidden state of the final generation step.
        # outputs.hidden_states is a tuple with one entry per generation
        # step; each entry is itself a tuple with one tensor per layer.
        # With KV caching, the first step covers the whole prompt and
        # later steps cover a single position, so we slice the final
        # position to always get shape (batch_size=1, 1, hidden_dim).
        last_step_hidden = None
        if outputs.hidden_states:
            last_step_hidden = outputs.hidden_states[-1][-1][:, -1:, :]

        token_count = generated_ids.shape[1]

        return AgentResponse(
            text=generated_text,
            hidden_states=last_step_hidden,
            token_count=token_count,
        )

Remote Agent: OpenAI-Compatible API

For remote models, we implement a text-only agent that works with any OpenAI-compatible API. This includes OpenAI itself, Azure OpenAI, local Ollama servers, vLLM servers, and many other providers:

import time
import requests


class RemoteAPIAgent(BaseAgent):
    """
    A RecursiveMAS agent backed by a remote OpenAI-compatible API.

    This agent does NOT support latent-space transfer because remote
    APIs do not expose internal hidden states. It participates in the
    system through text-based communication only, making it suitable
    for the text-based approximation of RecursiveMAS or for the final
    output step where text generation is required anyway.

    Compatible with: OpenAI API, Azure OpenAI, Ollama (REST API),
                     vLLM, LM Studio, and any OpenAI-compatible server.

    Args:
        api_base_url: The base URL of the API endpoint.
                      Examples:
                        "https://api.openai.com/v1"
                        "http://localhost:11434/v1"   (Ollama)
                        "http://localhost:8000/v1"    (vLLM)
        model_id: The model identifier used in API calls.
                  Examples: "gpt-4o-mini", "llama3.2", "qwen2.5:7b"
        api_key: API key for authentication. Use "ollama" for Ollama
                 (it does not require a real key).
        name: Human-readable agent name.
        role: Agent role description.
        timeout: HTTP request timeout in seconds.
        max_retries: Number of retries on transient failures, with
                     exponential backoff between attempts.
    """

    def __init__(
        self,
        api_base_url: str,
        model_id: str,
        api_key: str,
        name: str,
        role: str,
        timeout: int = 60,
        max_retries: int = 3,
    ):
        super().__init__(name=name, role=role)
        self.api_base_url = api_base_url.rstrip("/")
        self.model_id = model_id
        self.timeout = timeout
        self.max_retries = max_retries

        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @property
    def supports_latent_transfer(self) -> bool:
        """Remote API agents do not support latent transfer."""
        return False

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response via the remote API.

        Implements exponential backoff retry logic for robustness
        against transient network failures and rate limiting.

        Args:
            prompt: The user message.
            system_prompt: Optional system message.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            AgentResponse with text output. hidden_states is always None
            for remote agents since we cannot access internal model state.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_id,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.95,
        }

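        last_error = None
        attempts_made = 0
        for attempt in range(self.max_retries):
            attempts_made = attempt + 1
            try:
                response = requests.post(
                    f"{self.api_base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=self.timeout,
                )
                response.raise_for_status()

                data = response.json()
                # "content" may be null in some responses; treat as empty.
                generated_text = (
                    data["choices"][0]["message"].get("content") or ""
                )
                # Prefer the server-reported count. The fallback is only a
                # rough estimate: it counts whitespace-separated words.
                usage = data.get("usage") or {}
                token_count = usage.get(
                    "completion_tokens",
                    len(generated_text.split()),
                )

                return AgentResponse(
                    text=generated_text,
                    hidden_states=None,
                    token_count=token_count,
                )

            except requests.exceptions.RequestException as e:
                last_error = e
                # Retry only failures that can plausibly be transient:
                # network errors (no status), 408, 429, and 5xx responses.
                # Other client errors (e.g., 401, 404) fail fast.
                status = getattr(e.response, "status_code", None)
                retryable = (
                    status is None or status in (408, 429) or status >= 500
                )
                if not retryable:
                    break
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"  [{self.name}] Request failed "
                        f"(attempt {attempt + 1}/{self.max_retries}), "
                        f"retrying in {wait_time}s: {e}"
                    )
                    time.sleep(wait_time)

        raise RuntimeError(
            f"[{self.name}] API request failed after {attempts_made} "
            f"attempt(s). Last error: {last_error}"
        )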


class OllamaAgent(RemoteAPIAgent):
    """
    Convenience subclass for agents backed by a local Ollama server.

    Ollama provides an OpenAI-compatible REST API at localhost:11434,
    making it easy to run models like Llama, Mistral, Qwen, and Gemma
    locally without writing GPU-intensive model loading code.

    Usage:
        # 1. Install Ollama: https://ollama.ai
        # 2. Start the server:
        #      ollama serve
        # 3. Pull a model:
        #      ollama pull qwen2.5:1.5b
        # 4. Create an agent:
        agent = OllamaAgent(
            model_id="qwen2.5:1.5b",
            name="Planner",
            role="You plan step-by-step solutions.",
        )

    Args:
        model_id: The Ollama model name (e.g., "qwen2.5:1.5b",
                  "llama3.2:1b", "qwen2.5:7b").
        name: Human-readable agent name.
        role: Agent role description.
        host: Ollama server hostname (default: "localhost").
        port: Ollama server port (default: 11434).
        timeout: HTTP request timeout in seconds (default: 120,
                 longer than RemoteAPIAgent default because local
                 models can be slow on CPU).
        max_retries: Number of retries on failure (default: 3).
    """

    def __init__(
        self,
        model_id: str,
        name: str,
        role: str,
        host: str = "localhost",
        port: int = 11434,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        super().__init__(
            api_base_url=f"http://{host}:{port}/v1",
            model_id=model_id,
            api_key="ollama",
            name=name,
            role=role,
            timeout=timeout,
            max_retries=max_retries,
        )
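
To make the retry behavior concrete, here is a minimal sketch of the wait schedule that the 2 ** attempt rule in RemoteAPIAgent produces. The helper name backoff_schedule is hypothetical and is not part of recursive_mas.py; it only mirrors the loop above, in which no wait follows the final attempt.

```python
from typing import List


def backoff_schedule(max_retries: int) -> List[int]:
    """Return the waits (in seconds) between consecutive attempts.

    Hypothetical helper that mirrors RemoteAPIAgent's retry loop: the
    wait after failed attempt k (0-based) is 2 ** k seconds, and there
    is no wait after the final attempt.
    """
    return [2 ** attempt for attempt in range(max(max_retries - 1, 0))]


if __name__ == "__main__":
    schedule = backoff_schedule(3)
    print(schedule)       # [1, 2]
    print(sum(schedule))  # 3
```

With the default max_retries=3, the backoff itself adds only three seconds; most of the time spent on a failing call comes from the requests themselves, which is why the timeout argument matters more than the retry count when a server is unreachable.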

CHAPTER FIVE: THE ORCHESTRATOR

With our agent abstractions in place, we can now build the orchestrator that implements the RecursiveMAS collaboration patterns. The orchestrator is responsible for managing the recursion rounds, routing latent states between agents (when available), and producing the final answer.

For the text-based approximation (which works with both local and remote agents), the orchestrator passes text between agents. For the full latent-space version (which requires local agents), it passes hidden state tensors through the RecursiveLink modules.
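
Because the agents may have different hidden sizes, each outer link has to be sized for the pair of agents it connects. The sketch below works out those sizes under the ring convention used in this article, where outer_links[i] connects agent i to agent (i + 1) % N. The function name outer_link_shapes and the hidden sizes in the example are illustrative assumptions, not values from the paper.

```python
from typing import List, Tuple


def outer_link_shapes(hidden_dims: List[int]) -> List[Tuple[int, int]]:
    """Return (source_dim, target_dim) for each outer link in the ring.

    hidden_dims[i] is the hidden size of agent i. Link i carries latents
    from agent i to agent (i + 1) % N, so the last link wraps around to
    the first agent.
    """
    n = len(hidden_dims)
    return [(hidden_dims[i], hidden_dims[(i + 1) % n]) for i in range(n)]


if __name__ == "__main__":
    # Made-up hidden sizes for a Planner, Critic, and Solver.
    for i, (src, dst) in enumerate(outer_link_shapes([1024, 2048, 4096])):
        print(f"outer_links[{i}]: {src} -> {dst}")
```

Each pair can be passed directly as OuterRecursiveLink(source_dim, target_dim) when building the list of links.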

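The following orchestrator implements the Sequential Style pattern, which is the most common and the one used for the paper's primary experiments. It detects at construction time whether full latent-space mode is possible (all agents are local and links are provided), but its solve() method always runs the text-based path; latent-space inference is driven through RecursiveMASSystem.forward() instead: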

from typing import List, Optional, Tuple


class SequentialRecursiveMASOrchestrator:
    """
    Orchestrates a sequential-style RecursiveMAS collaboration.

    Manages a chain of agents (e.g., Planner -> Critic -> Solver) through
    multiple recursion rounds. It records in use_latent_mode whether
    latent-space transfer is possible (all agents local and links
    provided), but solve() always runs text-based transfer; full
    latent-space inference is driven through RecursiveMASSystem.forward().

    This implements the "light" and "scaled" sequential configuration
    from the paper: a Planner, Critic, and Solver in a chain, iterated
    over multiple recursion rounds.

    Args:
        agents: List of BaseAgent instances in pipeline order.
                Minimum 2 agents required.
        inner_links: List of InnerRecursiveLink modules, one per agent.
                     Required only for latent-space mode (all local agents).
                     Pass None to force text-based mode.
        outer_links: List of OuterRecursiveLink modules, one per agent.
                     outer_links[i] connects agent i to agent (i+1) % N.
                     Required only for latent-space mode.
                     Pass None to force text-based mode.
        num_latent_steps: Number of latent reasoning steps per agent
                          per round (default 80, per the paper).
        num_recursion_rounds: Number of recursive collaboration rounds
                              (default 3, per the paper).
    """

    def __init__(
        self,
        agents: List[BaseAgent],
        inner_links: Optional[List[InnerRecursiveLink]] = None,
        outer_links: Optional[List[OuterRecursiveLink]] = None,
        num_latent_steps: int = 80,
        num_recursion_rounds: int = 3,
    ):
        if len(agents) < 2:
            raise ValueError(
                "Sequential RecursiveMAS requires at least 2 agents."
            )

        self.agents = agents
        self.inner_links = inner_links
        self.outer_links = outer_links
        self.num_latent_steps = num_latent_steps
        self.num_recursion_rounds = num_recursion_rounds

        # Determine whether we can use full latent-space mode.
        # Requires: all agents are local AND links are provided.
        self.use_latent_mode = (
            all(agent.supports_latent_transfer for agent in agents)
            and inner_links is not None
            and outer_links is not None
            and len(inner_links) == len(agents)
            and len(outer_links) == len(agents)
        )

        mode = "latent-space" if self.use_latent_mode else "text-based"
        print(
            f"SequentialRecursiveMAS initialized with {len(agents)} agents "
            f"in {mode} mode over {num_recursion_rounds} recursion rounds."
        )

    def _build_agent_prompt(
        self,
        agent: BaseAgent,
        question: str,
        prior_context: Optional[str],
        round_idx: int,
        agent_idx: int,
    ) -> Tuple[str, str]:
        """
        Build the system and user prompts for an agent in text mode.

        In text mode, prior context from previous agents or rounds is
        included in the prompt as explicit text. This is the text-based
        approximation of what RecursiveMAS does in latent space.

        Args:
            agent: The agent to build prompts for.
            question: The original question.
            prior_context: Text output from the previous agent or round.
                           None if this is the first agent in round 1.
            round_idx: Current recursion round index (0-based).
            agent_idx: This agent's position in the pipeline (0-based).

        Returns:
            Tuple of (system_prompt, user_prompt).
        """
        # agent.role is a free-form description (often a full sentence),
        # so it is placed on its own line rather than spliced into one.
        system_prompt = (
            f"You are the {agent.name} agent in a recursive multi-agent "
            f"system.\nRole: {agent.role}\n"
            f"You are agent {agent_idx + 1} of {len(self.agents)}, "
            f"participating in recursion round {round_idx + 1} "
            f"of {self.num_recursion_rounds}. "
            "Collaborate carefully and build on the work of other agents."
        )

        if prior_context:
            user_prompt = (
                "Here is context from the previous agent (or, for the first "
                "agent in a round, from the end of the previous round):\n\n"
                f"---\n{prior_context}\n---\n\n"
                "Given this context, provide your response to the question:\n\n"
                f"{question}"
            )
        else:
            user_prompt = (
                f"Please respond to the following question:\n\n{question}"
            )

        return system_prompt, user_prompt

    def solve_text_mode(self, question: str) -> str:
        """
        Solve a question using text-based recursive collaboration.

        This is the fallback mode used when any agent does not support
        latent transfer, or when no RecursiveLink modules are provided.
        It approximates RecursiveMAS by passing text between agents
        across multiple recursion rounds.

        While less efficient than latent-space mode (no token savings,
        no gradient stability benefits), this still benefits from the
        recursive multi-round structure and works with any combination
        of local and remote agents.

        Args:
            question: The question to answer.

        Returns:
            The final text answer from the last agent in the last round.
        """
        if not self.agents:
            return ""

        print(
            f"\nSolving in text mode over "
            f"{self.num_recursion_rounds} rounds..."
        )

        # The last agent's output from the previous round feeds back
        # into the first agent at the start of the next round.
        previous_round_output: Optional[str] = None
        final_answer: str = ""

        for round_idx in range(self.num_recursion_rounds):
            print(
                f"  Round {round_idx + 1}/{self.num_recursion_rounds}"
            )
            # Within a round, each agent sees the previous agent's output
            current_context: Optional[str] = previous_round_output

            for agent_idx, agent in enumerate(self.agents):
                system_prompt, user_prompt = self._build_agent_prompt(
                    agent=agent,
                    question=question,
                    prior_context=current_context,
                    round_idx=round_idx,
                    agent_idx=agent_idx,
                )

                response = agent.generate_text(
                    prompt=user_prompt,
                    system_prompt=system_prompt,
                    max_new_tokens=512,
                    temperature=0.6,
                )

                print(
                    f"    [{agent.name}] Generated "
                    f"{response.token_count} tokens"
                )

                # This agent's output becomes context for the next agent
                current_context = response.text

                # Track the final agent's output as the answer
                if agent_idx == len(self.agents) - 1:
                    final_answer = response.text

            # The last agent's output feeds back to the first agent
            # in the next round
            previous_round_output = final_answer

        return final_answer

    def solve(self, question: str) -> str:
        """
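        Solve a question using RecursiveMAS.

        This orchestrator always runs the text-based path. When latent-space
        mode would be possible (all agents are local and links are provided),
        it prints a note pointing to RecursiveMASSystem.forward(), which is
        where full latent-space inference is driven.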

        Args:
            question: The question to answer.

        Returns:
            The final answer string.
        """
        if self.use_latent_mode:
            # Full latent-space RecursiveMAS requires direct model access
            # via the RecursiveMASSystem nn.Module. For simplicity in this
            # orchestrator, we note that latent-space inference should be
            # driven through RecursiveMASSystem.forward() directly.
            # The orchestrator's text mode is provided for quick prototyping
            # with any agent type.
            print(
                "Note: Full latent-space inference is available via "
                "RecursiveMASSystem.forward(). The orchestrator uses "
                "text-based mode for broad compatibility."
            )
        return self.solve_text_mode(question)
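
To see the order in which text mode visits agents and hands context along, here is a stripped-down sketch of the same loop using stub agents instead of real models. The names run_text_rounds and make_stub are hypothetical and exist only for this illustration; each stub simply wraps whatever it receives, so the nesting of the final string shows the full chain of handoffs.

```python
from typing import Callable, List, Optional, Tuple

# A stub agent is a (name, respond) pair; respond takes the question and
# the prior context (None for the very first call) and returns text.
StubAgent = Tuple[str, Callable[[str, Optional[str]], str]]


def run_text_rounds(
    agents: List[StubAgent], question: str, num_rounds: int
) -> Tuple[str, List[str]]:
    """Simplified stand-in for solve_text_mode, with a call trace."""
    trace: List[str] = []
    context: Optional[str] = None
    answer = ""
    for round_idx in range(num_rounds):
        for name, respond in agents:
            answer = respond(question, context)
            trace.append(f"r{round_idx + 1}:{name}")
            # Each output becomes the next agent's context. After the
            # last agent, it carries over into the next round.
            context = answer
    return answer, trace


def make_stub(name: str) -> Callable[[str, Optional[str]], str]:
    """Build a stub that wraps its context (or the question) in its name."""
    def respond(question: str, context: Optional[str]) -> str:
        return f"{name}({context if context is not None else question})"
    return respond


if __name__ == "__main__":
    stubs = [(n, make_stub(n)) for n in ("Planner", "Critic", "Solver")]
    final, calls = run_text_rounds(stubs, "Q", num_rounds=2)
    print(calls)
    print(final)  # Solver(Critic(Planner(Solver(Critic(Planner(Q))))))
```

The nested output makes the feedback edge visible: the Planner in round two receives the Solver's round-one output, which is the text-mode counterpart of the last agent's latent state being routed back to the first agent.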

CHAPTER SIX: THE COMPLETE COMBINED FILE

Here is the complete, self-contained recursive_mas.py file that combines all code from Chapters Two through Five in the correct dependency order. Copy this entire block into recursive_mas.py and you are ready to run.

"""
recursive_mas.py
================
Complete implementation of RecursiveMAS-inspired multi-agent collaboration.

Based on: "RecursiveMAS: Scaling Agent Collaboration through Unified
Latent-Space Recursive Computation" (2604.25917)

This file contains:
  - InnerRecursiveLink and OuterRecursiveLink modules
  - generate_latent_thoughts() utility function
  - compute_inner_loop_loss() and train_inner_loop() for Stage 1 training
  - RecursiveMASSystem nn.Module and train_outer_loop() for Stage 2 training
  - AgentResponse dataclass and BaseAgent abstract class
  - LocalHuggingFaceAgent for local Hugging Face models
  - RemoteAPIAgent and OllamaAgent for remote/Ollama-served models
  - SequentialRecursiveMASOrchestrator for running the full pipeline

Requirements:
  pip install "torch>=2.1.0" "transformers>=4.40.0" "accelerate>=0.27.0" "requests>=2.31.0"

Usage:
  See demo.py for a complete runnable example.
"""

# ============================================================
# Standard library imports
# ============================================================
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple

# ============================================================
# Third-party imports
# ============================================================
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Transformers imports are inside LocalHuggingFaceAgent.__init__
# to make the file importable even if transformers is not installed
# (useful when only using remote agents).


# ============================================================
# Section 1: RecursiveLink Modules
# ============================================================

class InnerRecursiveLink(nn.Module):
    """
    The Inner RecursiveLink operates within a single LLM agent.

    Transforms the agent's last-layer hidden state at step t into
    an input embedding for step t+1, enabling the agent to reason
    in continuous latent space without decoding to text.

    Architecture:
        R_inner(h) = h + W2 * GELU(W1 * h)

    At initialization, W2 is all zeros so the module starts as a
    pure identity (output == input). Training learns the residual
    correction on top of this stable baseline.

    Args:
        hidden_dim: Hidden dimension of the paired LLM agent.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: Shape (batch, hidden_dim) or (batch, seq, hidden_dim).
        Returns:
            Same shape as h.
        """
        return h + self.W2(F.gelu(self.W1(h)))


class OuterRecursiveLink(nn.Module):
    """
    The Outer RecursiveLink bridges two heterogeneous LLM agents.

    Transforms last-layer hidden states of a source agent into input
    embeddings aligned with the target agent's embedding space.

    Architecture:
        R_outer(h) = W3 * h + W2 * GELU(W1 * h)

    W3 handles dimensional alignment (source_dim -> target_dim).
    The nonlinear branch learns fine-grained distributional correction.

    Args:
        source_dim: Hidden dimension of the source agent.
        target_dim: Hidden dimension of the target agent.
    """

    def __init__(self, source_dim: int, target_dim: int):
        super().__init__()
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.W1 = nn.Linear(source_dim, source_dim, bias=True)
        self.W2 = nn.Linear(source_dim, target_dim, bias=True)
        self.W3 = nn.Linear(source_dim, target_dim, bias=False)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.W3.weight, gain=1.0)
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: Shape (batch, seq, source_dim) or (batch, source_dim).
        Returns:
            Shape (batch, seq, target_dim) or (batch, target_dim).
        """
        return self.W3(h) + self.W2(F.gelu(self.W1(h)))


# ============================================================
# Section 2: Latent Thoughts Generation
# ============================================================

def generate_latent_thoughts(
    model: nn.Module,
    input_embeddings: torch.Tensor,
    inner_link: InnerRecursiveLink,
    num_latent_steps: int = 80,
    prior_latent_state: Optional[torch.Tensor] = None,
    training: bool = False,
) -> torch.Tensor:
    """
    Run an agent in latent-thoughts generation mode.

    The model's weights are frozen, so they never receive gradients.
    When training=True, autograd still traces through the model's
    activations so that gradients can travel back across latent steps
    to inner_link. When training=False, everything runs without
    gradient tracking.

    Args:
        model: Frozen Hugging Face causal LM.
        input_embeddings: Shape (batch, context_len, hidden_dim).
        inner_link: The InnerRecursiveLink for this agent.
        num_latent_steps: Latent steps per agent (paper default: 80).
        prior_latent_state: Optional projected latent from previous round.
                            Shape (batch, prior_len, hidden_dim).
        training: True during outer-loop training; False at inference.

    Returns:
        latent_thoughts: Shape (batch, num_latent_steps, hidden_dim).
    """
    if prior_latent_state is not None:
        current_embeddings = torch.cat(
            [prior_latent_state, input_embeddings], dim=1
        )
    else:
        current_embeddings = input_embeddings

    latent_thoughts = []

    for _ in range(num_latent_steps):
        # Track gradients only while training. Wrapping the frozen model
        # in no_grad during training would cut the graph at every step,
        # and inner_link would receive no gradient from later steps.
        # Note: this keeps each step's activations in memory, so training
        # memory grows with num_latent_steps.
        with torch.set_grad_enabled(training):
            outputs = model(
                inputs_embeds=current_embeddings,
                output_hidden_states=True,
                use_cache=False,
            )
            # Shape: (batch, hidden_dim)
            last_hidden = outputs.hidden_states[-1][:, -1, :]
            next_emb = inner_link(last_hidden)

        # Store this step's last-layer hidden state
        latent_thoughts.append(last_hidden.unsqueeze(1))

        # Grow the context by one embedding per step
        current_embeddings = torch.cat(
            [current_embeddings, next_emb.unsqueeze(1)], dim=1
        )

    return torch.cat(latent_thoughts, dim=1)


# ============================================================
# Section 3: Inner-Loop Training
# ============================================================

def compute_inner_loop_loss(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    input_ids: torch.Tensor,
    target_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the inner-loop training loss for one agent.

    Loss = 1 - cosine_similarity(R_inner(H_0), mean(Emb(y)))

    Gradients flow through inner_link.W1 and inner_link.W2 only.
    The model is frozen; H_0 is produced under torch.no_grad().

    Args:
        model: Frozen LLM agent.
        inner_link: InnerRecursiveLink to train.
        input_ids: Shape (batch, seq_len).
        target_ids: Shape (batch, ans_len).

    Returns:
        Scalar loss tensor in [0, 2].
    """
    embedding_layer = model.get_input_embeddings()

    with torch.no_grad():
        input_embeddings = embedding_layer(input_ids)
        target_embeddings = embedding_layer(target_ids)
        initial_output = model(
            inputs_embeds=input_embeddings,
            output_hidden_states=True,
            use_cache=False,
        )
        # Shape: (batch, hidden_dim)
        H_0 = initial_output.hidden_states[-1][:, -1, :]

    # inner_link called OUTSIDE no_grad — gradients flow through W1, W2
    transformed_latent = inner_link(H_0)

    # Target: mean of answer token embeddings (detached)
    target_direction = target_embeddings.mean(dim=1)

    similarity = F.cosine_similarity(
        transformed_latent, target_direction, dim=-1
    )
    return (1.0 - similarity).mean()


def train_inner_loop(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
) -> None:
    """
    Run the inner-loop training stage for one agent.

    Call this independently for each agent before outer-loop training.

    Args:
        model: Frozen LLM agent (all params must have requires_grad=False).
        inner_link: InnerRecursiveLink to train.
        dataloader: Yields (input_ids, target_ids) batches.
        num_epochs: Training epochs.
        learning_rate: Initial learning rate for AdamW with cosine schedule.
    """
    for param in model.parameters():
        param.requires_grad = False

    optimizer = AdamW(inner_link.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )

    inner_link.train()
    model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            loss = compute_inner_loop_loss(model, inner_link,
                                           input_ids, target_ids)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                inner_link.parameters(), max_norm=1.0
            )
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Inner Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )


# ============================================================
# Section 4: RecursiveMASSystem and Outer-Loop Training
# ============================================================

class RecursiveMASSystem(nn.Module):
    """
    A RecursiveMAS system with N agents in sequential style.

    Frozen agent models are stored as a plain Python list (NOT
    nn.ModuleList) so they do not appear in self.parameters() or
    self.state_dict(). This keeps the optimizer and checkpoint
    focused on the RecursiveLink modules only.

    Inner and outer links ARE stored as nn.ModuleList so they are
    properly registered as submodules and appear in self.parameters().

    Args:
        agent_models: List of N frozen Hugging Face causal LM models.
        inner_links: List of N InnerRecursiveLink modules.
        outer_links: List of N OuterRecursiveLink modules.
                     outer_links[i] connects agent i to agent (i+1) % N.
        num_latent_steps: Latent steps per agent per round (default 80).
    """

    def __init__(
        self,
        agent_models: list,
        inner_links: list,
        outer_links: list,
        num_latent_steps: int = 80,
    ):
        super().__init__()

        if len(agent_models) < 2:
            raise ValueError("RecursiveMASSystem requires at least 2 agents.")
        if len(inner_links) != len(agent_models):
            raise ValueError("Must provide exactly one inner_link per agent.")
        if len(outer_links) != len(agent_models):
            raise ValueError("Must provide exactly one outer_link per agent.")

        self.num_agents = len(agent_models)
        self.num_latent_steps = num_latent_steps

        # Frozen models: plain list, NOT registered as submodules
        self._agent_models = agent_models

        # Trainable links: registered as submodules via nn.ModuleList
        self.inner_links = nn.ModuleList(inner_links)
        self.outer_links = nn.ModuleList(outer_links)

        # Verify all agent models are frozen
        for i, model in enumerate(self._agent_models):
            for param in model.parameters():
                if param.requires_grad:
                    raise ValueError(
                        f"Agent {i} has unfrozen parameters. "
                        f"Call model.requires_grad_(False) before "
                        f"constructing RecursiveMASSystem."
                    )

    def _run_agent_latent(
        self,
        agent_idx: int,
        context_embeddings: torch.Tensor,
        training: bool,
    ) -> torch.Tensor:
        """
        Run one agent in latent-thoughts generation mode.

        Args:
            agent_idx: Index of the agent (0-based).
            context_embeddings: Shape (batch, context_len, hidden_dim).
            training: True during outer-loop training.

        Returns:
            Shape (batch, num_latent_steps, hidden_dim).
        """
        model = self._agent_models[agent_idx]
        inner_link = self.inner_links[agent_idx]
        current_embeddings = context_embeddings
        latent_thoughts = []

        for _ in range(self.num_latent_steps):
            # The model's weights are frozen, but during training autograd
            # must still trace through its activations; otherwise no
            # gradient can travel back from later steps to the links.
            # This keeps each step's activations in memory while training.
            with torch.set_grad_enabled(training):
                outputs = model(
                    inputs_embeds=current_embeddings,
                    output_hidden_states=True,
                    use_cache=False,
                )
                last_hidden = outputs.hidden_states[-1][:, -1, :]
                next_emb = inner_link(last_hidden)

            latent_thoughts.append(last_hidden.unsqueeze(1))
            current_embeddings = torch.cat(
                [current_embeddings, next_emb.unsqueeze(1)], dim=1
            )

        return torch.cat(latent_thoughts, dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        num_recursion_rounds: int = 3,
    ) -> torch.Tensor:
        """
        Run the full RecursiveMAS forward pass.

        Args:
            input_ids: Shape (batch, seq_len).
            num_recursion_rounds: Number of recursive rounds (default 3).

        Returns:
            logits: Shape (batch, decode_context_len, vocab_size).
        """
        training = self.training

        # Pre-embed the input for each agent (each may have a different
        # hidden_dim). This assumes all agents share one tokenizer, so the
        # same input_ids are valid for every agent's embedding table.
        input_embeddings = []
        for model in self._agent_models:
            with torch.no_grad():
                emb = model.get_input_embeddings()(input_ids)
            input_embeddings.append(emb)

        # feedback_latent[i]: the most recent outer-link-projected latent
        # addressed to agent i. The list is updated in place, so within a
        # round agent i+1 sees agent i's output, and agent 0 sees the last
        # agent's output from the previous round. Entries are None until
        # the first projection arrives.
        feedback_latent: List[Optional[torch.Tensor]] = [None] * self.num_agents
        final_logits = None

        for round_idx in range(num_recursion_rounds):
            for agent_idx in range(self.num_agents):
                # Build context: latest feedback (if any) + fresh input
                if feedback_latent[agent_idx] is not None:
                    context = torch.cat(
                        [feedback_latent[agent_idx],
                         input_embeddings[agent_idx]],
                        dim=1,
                    )
                else:
                    context = input_embeddings[agent_idx]

                # Run this agent in latent mode
                latent = self._run_agent_latent(
                    agent_idx=agent_idx,
                    context_embeddings=context,
                    training=training,
                )

                # Project latent into the next agent's embedding space
                next_agent_idx = (agent_idx + 1) % self.num_agents
                outer_link = self.outer_links[agent_idx]

                if training:
                    projected = outer_link(latent)
                else:
                    with torch.no_grad():
                        projected = outer_link(latent)

                # Hand the projected latent to the next agent. For the
                # last agent, this wraps around to agent 0 and is consumed
                # at the start of the next round.
                feedback_latent[next_agent_idx] = projected

            # After the final round, decode text from the last agent
            if round_idx == num_recursion_rounds - 1:
                last_agent_idx = self.num_agents - 1
                last_model = self._agent_models[last_agent_idx]

                # Reconstruct the last agent's decode context
                if feedback_latent[last_agent_idx] is not None:
                    decode_context = torch.cat(
                        [feedback_latent[last_agent_idx],
                         input_embeddings[last_agent_idx]],
                        dim=1,
                    )
                else:
                    decode_context = input_embeddings[last_agent_idx]

                if training:
                    decode_output = last_model(
                        inputs_embeds=decode_context,
                        output_hidden_states=False,
                        use_cache=False,
                    )
                else:
                    with torch.no_grad():
                        decode_output = last_model(
                            inputs_embeds=decode_context,
                            output_hidden_states=False,
                            use_cache=False,
                        )
                final_logits = decode_output.logits

        return final_logits


def train_outer_loop(
    system: RecursiveMASSystem,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    num_recursion_rounds: int = 3,
) -> None:
    """
    Run the outer-loop training stage for the full RecursiveMAS system.

    Only RecursiveLink parameters are updated. Agent models stay frozen.
    Gradients flow through every outer link and inner link in every round
    (shared credit assignment across the full recursive computation graph).

    Args:
        system: RecursiveMASSystem with frozen agent models.
        dataloader: Yields (input_ids, target_ids) batches.
                    Use -100 in target_ids for positions to ignore.
        num_epochs: Training epochs.
        learning_rate: Initial learning rate for AdamW with cosine schedule.
        num_recursion_rounds: Recursive rounds during training (default 3).
    """
    # system.parameters() returns only RecursiveLink params because
    # agent models are stored as a plain list (not nn.ModuleList)
    optimizer = AdamW(system.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    system.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()

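            # training=True keeps the autograd graph alive; forward() wraps
            # its computations in torch.no_grad() when the flag is False.
            logits = system(
                input_ids=input_ids,
                num_recursion_rounds=num_recursion_rounds,
                training=True,
            )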

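            # Align logits with target: take the last seq_len positions.
            # The logit at position t is scored against target_ids[:, t], so
            # the dataloader must supply targets already shifted for
            # next-token prediction.
            seq_len = target_ids.shape[1]
            logits_for_loss = logits[:, -seq_len:, :]
            target_ids = target_ids.to(logits_for_loss.device)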

            batch_size, ans_len, vocab_size = logits_for_loss.shape
            loss = criterion(
                logits_for_loss.reshape(batch_size * ans_len, vocab_size),
                target_ids.reshape(batch_size * ans_len),
            )

            # Backpropagate through all recursion rounds
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                system.parameters(), max_norm=1.0
            )
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Outer Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )


# ============================================================
# Section 5: Agent Abstractions
# ============================================================

@dataclass
class AgentResponse:
    """
    Unified response from any agent, local or remote.

    Attributes:
        text: Decoded text response.
        hidden_states: Last-layer hidden states (local agents only).
                       Shape: (1, 1, hidden_dim) for the last generated
                       token's last layer, or None for remote agents.
        token_count: Number of tokens generated.
    """
    text: str
    hidden_states: Optional[torch.Tensor]
    token_count: int


class BaseAgent(ABC):
    """Abstract base class for all RecursiveMAS agents."""

    def __init__(self, name: str, role: str):
        self.name = name
        self.role = role

    @abstractmethod
    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        pass

    @property
    @abstractmethod
    def supports_latent_transfer(self) -> bool:
        pass

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(name={self.name!r}, role={self.role!r})"
        )


class LocalHuggingFaceAgent(BaseAgent):
    """
    RecursiveMAS agent backed by a local Hugging Face model.

    Supports full latent-space state transfer via hidden states.
    All model parameters are frozen at construction time.

    Args:
        model_name_or_path: HF model ID or local path.
        name: Agent name.
        role: Role description for system prompts.
        device: "cuda", "mps", or "cpu". Note: MPS support varies by model.
        torch_dtype: torch.float16 recommended for GPU; torch.float32 for CPU.
    """

    def __init__(
        self,
        model_name_or_path: str,
        name: str,
        role: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        super().__init__(name=name, role=role)
        self.device = device
        self.model_name = model_name_or_path

        # Import here so the file is importable without transformers installed
        from transformers import AutoTokenizer, AutoModelForCausalLM

        print(f"Loading {name} from {model_name_or_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )

        # Freeze all parameters immediately
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.hidden_dim = self.model.config.hidden_size
        total_params = sum(p.numel() for p in self.model.parameters())
        print(
            f"  Loaded {name}: hidden_dim={self.hidden_dim}, "
            f"params={total_params:,} (all frozen)"
        )

    @property
    def supports_latent_transfer(self) -> bool:
        return True

    def get_hidden_dim(self) -> int:
        return self.hidden_dim

    def get_raw_model(self) -> nn.Module:
        """Return the underlying frozen HF model for use in RecursiveMASSystem."""
        return self.model

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Embed text using the model's input embedding layer.

        Returns:
            Shape (1, seq_len, hidden_dim).
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
        return embeddings

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate text using standard autoregressive decoding.

        Returns:
            AgentResponse. hidden_states shape: (1, 1, hidden_dim)
            representing the last layer of the last generated token.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.device)
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                top_p=0.95,
                do_sample=(temperature > 0),
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        generated_ids = outputs.sequences[:, input_length:]
        generated_text = self.tokenizer.decode(
            generated_ids[0], skip_special_tokens=True
        )

        # outputs.hidden_states: tuple[tuple[Tensor]]
        # Outer tuple: one entry per generation step
        # Inner tuple: embedding output plus one entry per Transformer layer
        # [-1][-1]: last layer at the final generation step
        # The first step covers the whole prompt, so slice the last position
        # to guarantee shape (batch=1, seq_pos=1, hidden_dim)
        last_step_hidden = None
        if outputs.hidden_states:
            last_step_hidden = outputs.hidden_states[-1][-1][:, -1:, :]

        return AgentResponse(
            text=generated_text,
            hidden_states=last_step_hidden,
            token_count=generated_ids.shape[1],
        )


class RemoteAPIAgent(BaseAgent):
    """
    RecursiveMAS agent backed by a remote OpenAI-compatible API.

    Does NOT support latent-space transfer (no access to hidden states).
    Works with: OpenAI, Azure OpenAI, Ollama, vLLM, LM Studio, etc.

    Args:
        api_base_url: API base URL, e.g. "https://api.openai.com/v1".
        model_id: Model identifier, e.g. "gpt-4o-mini".
        api_key: API key. Use "ollama" for Ollama (no real key needed).
        name: Agent name.
        role: Role description.
        timeout: HTTP timeout in seconds.
        max_retries: Retry attempts with exponential backoff.
    """

    def __init__(
        self,
        api_base_url: str,
        model_id: str,
        api_key: str,
        name: str,
        role: str,
        timeout: int = 60,
        max_retries: int = 3,
    ):
        super().__init__(name=name, role=role)
        self.api_base_url = api_base_url.rstrip("/")
        self.model_id = model_id
        self.timeout = timeout
        self.max_retries = max_retries
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @property
    def supports_latent_transfer(self) -> bool:
        return False

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_id,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.95,
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    f"{self.api_base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                data = response.json()
                generated_text = data["choices"][0]["message"]["content"]
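                # Some servers return "usage": null; fall back to a rough
                # whitespace word count when no token usage is reported.
                usage = data.get("usage") or {}
                token_count = usage.get(
                    "completion_tokens", len(generated_text.split())
                )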
                return AgentResponse(
                    text=generated_text,
                    hidden_states=None,
                    token_count=token_count,
                )
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"  [{self.name}] Attempt {attempt + 1}/"
                        f"{self.max_retries} failed, retrying in "
                        f"{wait_time}s: {e}"
                    )
                    time.sleep(wait_time)

        raise RuntimeError(
            f"[{self.name}] All {self.max_retries} attempts failed. "
            f"Last error: {last_error}"
        )


class OllamaAgent(RemoteAPIAgent):
    """
    Agent backed by a local Ollama server.

    Ollama exposes an OpenAI-compatible REST API at localhost:11434.

    Setup:
        ollama serve
        ollama pull qwen2.5:1.5b   # or any supported model

    Args:
        model_id: Ollama model name, e.g. "qwen2.5:1.5b", "llama3.2:1b".
        name: Agent name.
        role: Role description.
        host: Ollama server host (default "localhost").
        port: Ollama server port (default 11434).
        timeout: HTTP timeout in seconds (default 120 for CPU inference).
        max_retries: Retry attempts (default 3).
    """

    def __init__(
        self,
        model_id: str,
        name: str,
        role: str,
        host: str = "localhost",
        port: int = 11434,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        super().__init__(
            api_base_url=f"http://{host}:{port}/v1",
            model_id=model_id,
            api_key="ollama",
            name=name,
            role=role,
            timeout=timeout,
            max_retries=max_retries,
        )


# ============================================================
# Section 6: Orchestrator
# ============================================================

class SequentialRecursiveMASOrchestrator:
    """
    Orchestrates sequential-style RecursiveMAS collaboration.

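    Manages a chain of agents (Planner -> Critic -> Solver -> ...) through
    multiple recursion rounds. Records whether latent-space mode is possible
    (all agents local, with one inner and one outer link per agent), but
    solve() always runs the text-based pipeline; latent-space inference
    goes through RecursiveMASSystem.forward().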

    Args:
        agents: List of BaseAgent instances in pipeline order (min 2).
        inner_links: InnerRecursiveLink per agent (latent mode only).
        outer_links: OuterRecursiveLink per agent (latent mode only).
        num_latent_steps: Latent steps per agent per round (default 80).
        num_recursion_rounds: Recursive rounds (default 3).
    """

    def __init__(
        self,
        agents: List[BaseAgent],
        inner_links: Optional[List[InnerRecursiveLink]] = None,
        outer_links: Optional[List[OuterRecursiveLink]] = None,
        num_latent_steps: int = 80,
        num_recursion_rounds: int = 3,
    ):
        if len(agents) < 2:
            raise ValueError(
                "Sequential RecursiveMAS requires at least 2 agents."
            )
        self.agents = agents
        self.inner_links = inner_links
        self.outer_links = outer_links
        self.num_latent_steps = num_latent_steps
        self.num_recursion_rounds = num_recursion_rounds

        self.use_latent_mode = (
            all(agent.supports_latent_transfer for agent in agents)
            and inner_links is not None
            and outer_links is not None
            and len(inner_links) == len(agents)
            and len(outer_links) == len(agents)
        )

        mode = "latent-space" if self.use_latent_mode else "text-based"
        print(
            f"SequentialRecursiveMAS: {len(agents)} agents, {mode} mode, "
            f"{num_recursion_rounds} recursion rounds."
        )

    def _build_agent_prompt(
        self,
        agent: BaseAgent,
        question: str,
        prior_context: Optional[str],
        round_idx: int,
        agent_idx: int,
    ) -> Tuple[str, str]:
        system_prompt = (
            f"You are a {agent.role} in a recursive multi-agent system. "
            f"You are agent {agent_idx + 1} of {len(self.agents)}, "
            f"in recursion round {round_idx + 1} of {self.num_recursion_rounds}. "
            f"Build carefully on the work of other agents."
        )
        if prior_context:
            user_prompt = (
                f"Context from the previous agent:\n\n"
                f"---\n{prior_context}\n---\n\n"
                f"Your response to the question:\n\n{question}"
            )
        else:
            user_prompt = f"Please respond to:\n\n{question}"
        return system_prompt, user_prompt

    def solve_text_mode(self, question: str) -> str:
        """
        Solve using text-based recursive collaboration.

        Works with any combination of local and remote agents.
        Each agent sees the previous agent's text output within a round;
        the last agent's output feeds back to the first agent in the
        next round.

        Args:
            question: The question to answer.

        Returns:
            Final text answer from the last agent in the last round.
            Returns empty string if agents list is empty or rounds is 0.
        """
        if not self.agents or self.num_recursion_rounds == 0:
            return ""

        print(
            f"\nSolving in text mode: "
            f"{self.num_recursion_rounds} rounds, "
            f"{len(self.agents)} agents per round."
        )

        previous_round_output: Optional[str] = None
        final_answer: str = ""

        for round_idx in range(self.num_recursion_rounds):
            print(f"  Round {round_idx + 1}/{self.num_recursion_rounds}")
            current_context: Optional[str] = previous_round_output

            for agent_idx, agent in enumerate(self.agents):
                system_prompt, user_prompt = self._build_agent_prompt(
                    agent=agent,
                    question=question,
                    prior_context=current_context,
                    round_idx=round_idx,
                    agent_idx=agent_idx,
                )
                response = agent.generate_text(
                    prompt=user_prompt,
                    system_prompt=system_prompt,
                    max_new_tokens=512,
                    temperature=0.6,
                )
                print(
                    f"    [{agent.name}] {response.token_count} tokens"
                )
                current_context = response.text
                if agent_idx == len(self.agents) - 1:
                    final_answer = response.text

            previous_round_output = final_answer

        return final_answer

    def solve(self, question: str) -> str:
        """
        Solve a question using RecursiveMAS.

        Always runs text-based mode. When latent-space mode is available,
        it first prints a pointer to RecursiveMASSystem.forward().

        Args:
            question: The question to answer.

        Returns:
            The final answer string.
        """
        if self.use_latent_mode:
            print(
                "Note: Full latent-space inference is available via "
                "RecursiveMASSystem.forward(). The orchestrator uses "
                "text-based mode for broad compatibility."
            )
        return self.solve_text_mode(question)
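
To see how context threads through the rounds in text mode, here is a minimal, self-contained sketch that mirrors the loop in solve_text_mode with agent names standing in for real models. The function trace_text_mode and its labels are hypothetical and are not part of recursive_mas.py; the sketch only records who reads whose output.

```python
from typing import List, Optional, Tuple


def trace_text_mode(
    agent_names: List[str], num_rounds: int
) -> List[Tuple[int, str, Optional[str]]]:
    """Return (round, agent, source_of_context) for every agent call.

    Mirrors the control flow of solve_text_mode: inside a round each agent
    reads the previous agent's output; across rounds the first agent reads
    the last agent's output from the round before.
    """
    trace = []
    previous_round_source: Optional[str] = None  # nothing before round 1

    for round_idx in range(num_rounds):
        current_source = previous_round_source
        for name in agent_names:
            trace.append((round_idx + 1, name, current_source))
            # This agent's output becomes the next agent's context.
            current_source = f"{name}@round{round_idx + 1}"
        # The last agent's output seeds the next round.
        previous_round_source = current_source

    return trace


if __name__ == "__main__":
    for round_no, agent, source in trace_text_mode(
        ["Planner", "Critic", "Solver"], num_rounds=2
    ):
        print(f"round {round_no}: {agent:<8} reads {source}")
```

With three agents and two rounds, the only call that receives no context is the Planner in round 1; the Planner in round 2 reads the Solver's round-1 output.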

CHAPTER SEVEN: THE DEMO SCRIPT

Save the following as demo.py in the same directory as recursive_mas.py:

"""
demo.py
=======
Demonstration of RecursiveMAS sequential collaboration.

Run with:
    python demo.py --mode ollama
    python demo.py --mode openai
    python demo.py --mode ollama --question "Your question here"

Prerequisites:
    Ollama mode:
        ollama serve
        ollama pull qwen2.5:1.5b
        ollama pull llama3.2:1b

    OpenAI mode:
        export OPENAI_API_KEY="sk-..."
"""

import os
import argparse
from typing import List

# Import everything from the combined module
from recursive_mas import (
    BaseAgent,
    OllamaAgent,
    RemoteAPIAgent,
    SequentialRecursiveMASOrchestrator,
)


def create_ollama_agents() -> List[BaseAgent]:
    """
    Create a three-agent sequential system using local Ollama models.

    Uses small, fast models that run on a laptop (CPU is fine):
      - Planner:  qwen2.5:1.5b  (~1 GB)
      - Critic:   llama3.2:1b   (~700 MB)
      - Solver:   qwen2.5:1.5b  (~1 GB, same model, different role)
    """
    planner = OllamaAgent(
        model_id="qwen2.5:1.5b",
        name="Planner",
        role=(
            "expert problem decomposer who breaks complex questions into "
            "clear, structured step-by-step plans"
        ),
    )
    critic = OllamaAgent(
        model_id="llama3.2:1b",
        name="Critic",
        role=(
            "rigorous evaluator who identifies flaws, gaps, and improvements "
            "in proposed plans and solutions"
        ),
    )
    solver = OllamaAgent(
        model_id="qwen2.5:1.5b",
        name="Solver",
        role=(
            "precise problem solver who produces final, well-reasoned answers "
            "based on the plan and critique provided"
        ),
    )
    return [planner, critic, solver]


def create_openai_agents(api_key: str) -> List[BaseAgent]:
    """
    Create a three-agent sequential system using the OpenAI API.

    Uses gpt-4o-mini for all three roles for cost efficiency.
    """
    base_url = "https://api.openai.com/v1"
    model = "gpt-4o-mini"

    planner = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Planner",
        role=(
            "expert problem decomposer who breaks complex questions into "
            "clear, structured step-by-step plans"
        ),
    )
    critic = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Critic",
        role=(
            "rigorous evaluator who identifies flaws, gaps, and improvements "
            "in proposed plans and solutions"
        ),
    )
    solver = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Solver",
        role=(
            "precise problem solver who produces final, well-reasoned answers "
            "based on the plan and critique provided"
        ),
    )
    return [planner, critic, solver]


def run_demo(agents: List[BaseAgent], question: str) -> None:
    """Run a RecursiveMAS demo with the given agents and question."""
    num_rounds = 3

    print("\n" + "=" * 60)
    print("RecursiveMAS Sequential Collaboration Demo")
    print("=" * 60)
    print(f"Question : {question}")
    print(f"Agents   : {[a.name for a in agents]}")
    print(f"Rounds   : {num_rounds}")
    print("-" * 60)

    orchestrator = SequentialRecursiveMASOrchestrator(
        agents=agents,
        num_recursion_rounds=num_rounds,
    )

    answer = orchestrator.solve(question)

    print("\n" + "=" * 60)
    print("FINAL ANSWER:")
    print("-" * 60)
    print(answer)
    print("=" * 60)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="RecursiveMAS Sequential Collaboration Demo",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--mode",
        choices=["ollama", "openai"],
        default="ollama",
        help="Backend to use for agents (default: ollama).",
    )
    parser.add_argument(
        "--question",
        type=str,
        default=(
            "A train travels from City A to City B at 60 km/h and returns "
            "at 40 km/h. What is the average speed for the entire journey?"
        ),
        help="The question to answer.",
    )
    args = parser.parse_args()

    if args.mode == "ollama":
        print("Using local Ollama agents.")
        print("Ensure Ollama is running ('ollama serve') and models are pulled:")
        print("  ollama pull qwen2.5:1.5b")
        print("  ollama pull llama3.2:1b")
        agents = create_ollama_agents()

    elif args.mode == "openai":
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            raise ValueError(
                "Set the OPENAI_API_KEY environment variable to use OpenAI. "
                "Example: export OPENAI_API_KEY='sk-...'"
            )
        print("Using OpenAI API agents (gpt-4o-mini).")
        agents = create_openai_agents(api_key)

    else:
        raise ValueError(f"Unknown mode: {args.mode}")

    run_demo(agents=agents, question=args.question)


if __name__ == "__main__":
    main()
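
The default question has a known closed-form answer, which makes it a convenient sanity check for whatever the agents return. Over equal distances the average speed is the harmonic mean of the two speeds, not their arithmetic mean. The helper name average_round_trip_speed below is illustrative and is not part of demo.py.

```python
def average_round_trip_speed(speed_out: float, speed_back: float) -> float:
    """Average speed over a round trip with equal distance each way.

    total distance / total time = 2d / (d/v1 + d/v2) = 2*v1*v2 / (v1 + v2)
    """
    return 2 * speed_out * speed_back / (speed_out + speed_back)


if __name__ == "__main__":
    expected = average_round_trip_speed(60, 40)
    print(f"Expected answer: {expected} km/h")  # 48.0, not the naive 50
```

Small models often answer 50 km/h here, so comparing the final answer against 48 km/h is a quick way to judge whether the extra rounds helped.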

CHAPTER EIGHT: PERFORMANCE RESULTS AND WHAT THEY MEAN

Let us take a step back and look at what the paper actually reports in terms of numbers, because the results are genuinely impressive and worth understanding in detail.

The paper evaluates RecursiveMAS across nine benchmarks. On MATH500, a collection of 500 competition-style math problems, RecursiveMAS achieves 88.0% accuracy. To put that in context, a single agent with LoRA fine-tuning achieves 83.1%, and the strongest text-based recursive baseline (Recursive-TextMAS) achieves 85.8%. RecursiveMAS beats the text-based recursive baseline by 2.2 percentage points while being faster and using fewer tokens.

On GPQA-Diamond, a graduate-level science benchmark (biology, physics, and chemistry) that is difficult even for domain experts, RecursiveMAS achieves 66.2% compared to 62.5% for TextGrad and 61.6% for Recursive-TextMAS. On LiveCodeBench, a competitive programming benchmark, RecursiveMAS achieves 42.9% compared to 39.8% for TextGrad and 38.7% for Recursive-TextMAS.

On the AIME 2025 and AIME 2026 benchmarks (the American Invitational Mathematics Examination, a very difficult competition exam), RecursiveMAS achieves 86.7% on both. On AIME 2025, the best single-agent baseline and Recursive-TextMAS both reach 73.3%, so RecursiveMAS improves on the text-based recursive baseline by 13.4 percentage points.

The efficiency gains are equally striking. As the number of recursion rounds increases from 1 to 3, RecursiveMAS becomes progressively faster relative to the text-based baseline, because each additional round of latent collaboration adds very little overhead (just the RecursiveLink transformations) compared to the text-based approach which must decode and re-encode all intermediate outputs. By round 3, RecursiveMAS is generating 34.6% to 75.6% fewer tokens than the text-based baseline while achieving higher accuracy.

The following table summarizes the key comparison from the paper:

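| Method | MATH500 | AIME25 | GPQA-D | LiveCode | MedQA |
|---|---|---|---|---|---|
| Single Agent (LoRA) | 83.1% | 70.0% | 62.0% | 37.4% | 76.1% |
| Single Agent (Full-SFT) | 83.2% | 73.3% | 62.8% | 38.6% | 77.0% |
| Mixture-of-Agents (MoA) | 79.8% | 60.0% | 47.6% | 27.0% | 57.5% |
| TextGrad | 84.9% | 73.3% | 62.5% | 39.8% | 77.2% |
| LoopLM | 84.6% | 66.7% | 48.1% | 24.9% | 56.4% |
| Recursive-TextMAS | 85.8% | 73.3% | 61.6% | 38.7% | 77.0% |
| RecursiveMAS | 88.0% | 86.7% | 66.2% | 42.9% | 79.3% |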

RecursiveMAS has the highest score in every column of this table, often by a substantial margin. The improvement over Recursive-TextMAS (which has the same structure but uses text instead of latent states) is particularly telling because it isolates the contribution of the latent-space communication itself.
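
The margin over Recursive-TextMAS can be read straight off the table. A small sketch, with the two rows copied from above into dictionaries whose names are chosen here for illustration:

```python
# Scores in percent, copied from the comparison table.
recursive_text_mas = {"MATH500": 85.8, "AIME25": 73.3, "GPQA-D": 61.6,
                      "LiveCode": 38.7, "MedQA": 77.0}
recursive_mas = {"MATH500": 88.0, "AIME25": 86.7, "GPQA-D": 66.2,
                 "LiveCode": 42.9, "MedQA": 79.3}

# Difference in percentage points, rounded to the table's precision.
margins = {
    name: round(recursive_mas[name] - recursive_text_mas[name], 1)
    for name in recursive_mas
}

for name, gain in margins.items():
    print(f"{name:<9} +{gain:.1f} points")
```

The smallest margin in these five columns is 2.2 points (MATH500) and the largest is 13.4 points (AIME25).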

The RecursiveLink design ablation is also worth examining. The paper compares four architectural variants:

| RecursiveLink Design | MATH500 | GPQA-D | LiveCodeBench |
|---|---|---|---|
| 1-Layer (no residual) | 84.4% | 63.2% | 40.1% |
| 1-Layer + Residual | 86.7% | 65.3% | 41.4% |
| 2-Layer (no residual) | 85.6% | 64.5% | 40.5% |
| 2-Layer + Residual (ours) | 88.0% | 66.2% | 42.9% |

The residual connection is clearly important — adding it to a 1-layer design (going from 84.4% to 86.7% on MATH500) gives a bigger boost than adding a second layer without the residual (going from 84.4% to 85.6%). The full 2-layer residual design is best across the board. This validates the design intuition: the residual connection forces the module to learn only the distributional shift, which is easier to learn and more stable to train.
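
The same comparison can be checked on all three benchmarks rather than MATH500 alone. The sketch below copies the ablation rows into a dictionary (the short design labels are made up here) and computes the gain from adding only the residual versus adding only a second layer.

```python
# Rows copied from the RecursiveLink ablation:
# design -> (MATH500, GPQA-D, LiveCodeBench), in percent
ablation = {
    "1-layer": (84.4, 63.2, 40.1),
    "1-layer+residual": (86.7, 65.3, 41.4),
    "2-layer": (85.6, 64.5, 40.5),
    "2-layer+residual": (88.0, 66.2, 42.9),
}


def gain(after: str, before: str) -> tuple:
    """Per-benchmark difference in accuracy points, rounded to one decimal."""
    return tuple(
        round(a - b, 1) for a, b in zip(ablation[after], ablation[before])
    )


residual_gain = gain("1-layer+residual", "1-layer")  # add the residual only
depth_gain = gain("2-layer", "1-layer")              # add a second layer only

print("residual:", residual_gain)
print("depth   :", depth_gain)
```

On these numbers the residual alone adds more than the extra layer alone on every benchmark in the table, not only on MATH500.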

The latent thoughts length ablation is also instructive:

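| Latent Steps (m) | MATH500 | GPQA-D | LiveCodeBench |
|---|---|---|---|
| 0 (no latent) | 83.3% | 61.4% | 38.1% |
| 16 | 84.9% | 62.0% | 40.3% |
| 32 | 85.2% | 62.8% | 40.7% |
| 48 | 85.6% | 63.6% | 41.4% |
| 64 | 86.8% | 64.1% | 42.0% |
| 80 | 86.8% | 64.2% | 42.5% |
| 96 | 86.5% | 64.5% | 42.2% |
| 112 | 86.9% | 64.3% | 42.6% |
| 128 | 86.7% | 64.4% | 42.6% |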

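Performance improves from m=0 up to roughly m=64 to 80 and then plateaus: beyond m=80, no score moves more than 0.3 points away from its m=80 value. This means you do not need to run hundreds of latent steps; 80 is enough. Even m=16 gives a meaningful improvement over no latent thoughts at all (+1.6 points on MATH500, +0.6 on GPQA-D, +2.2 on LiveCodeBench), which is encouraging for resource-constrained deployments.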
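
One way to turn the ablation into a budget decision is to pick the smallest m that stays within a chosen tolerance of the best score on every benchmark. The sketch below does that with the rows copied from the table; the function name and the tolerance values are assumptions made for illustration, not something the paper prescribes.

```python
# Rows copied from the latent-step ablation:
# m -> (MATH500, GPQA-D, LiveCodeBench), in percent
ablation = {
    0: (83.3, 61.4, 38.1),
    16: (84.9, 62.0, 40.3),
    32: (85.2, 62.8, 40.7),
    48: (85.6, 63.6, 41.4),
    64: (86.8, 64.1, 42.0),
    80: (86.8, 64.2, 42.5),
    96: (86.5, 64.5, 42.2),
    112: (86.9, 64.3, 42.6),
    128: (86.7, 64.4, 42.6),
}


def smallest_sufficient_m(table: dict, tolerance: float) -> int:
    """Smallest m whose score is within `tolerance` points of the best
    score on every benchmark."""
    num_benchmarks = len(next(iter(table.values())))
    best = [max(row[i] for row in table.values()) for i in range(num_benchmarks)]
    for m in sorted(table):
        if all(best[i] - table[m][i] <= tolerance for i in range(num_benchmarks)):
            return m
    raise ValueError("no setting satisfies the tolerance")


print(smallest_sufficient_m(ablation, tolerance=0.5))
print(smallest_sufficient_m(ablation, tolerance=1.0))
```

With a half-point tolerance the answer is m=80; relaxing it to one point drops the requirement to m=64.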


CHAPTER NINE: THE BIGGER PICTURE AND WHAT COMES NEXT

RecursiveMAS represents a genuinely new way of thinking about multi-agent AI systems. Instead of treating agents as black boxes that communicate through text, it treats the entire multi-agent system as a single unified computation that happens to be distributed across multiple models. The RecursiveLink is the key enabler: a tiny, trainable module that knows how to translate between the hidden spaces of different models.

This has several implications that are worth thinking about carefully.

The first implication is about the nature of agent communication. When we build text-based multi-agent systems, we implicitly assume that text is the right medium for agents to share information. RecursiveMAS challenges this assumption. Text is great for communicating with humans, but between AI agents, continuous vectors are richer, faster, and more gradient-friendly. The paper's theoretical results make this precise: text-based communication introduces gradient vanishing and computational overhead that latent-space communication avoids.

The second implication is about the scalability of multi-agent systems. The paper shows a clear scaling law: more recursion rounds means better performance, and this improvement is consistent across all benchmarks and all collaboration patterns. This is exciting because it means you can trade compute for accuracy in a predictable way, just as you can with larger models or longer context windows. The "recursion depth" becomes a new axis of scaling.

The third implication is about the cost of intelligence. RecursiveMAS achieves its best results with only 13.12 million trainable parameters — a tiny fraction of the total parameter count of the agents it connects. This suggests that a lot of the "intelligence" in a multi-agent system is not in the individual agents themselves, but in how they communicate. Training better communication channels (the RecursiveLinks) is more efficient than training better agents.
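
For a sense of scale, the quick-reference appendix lists those 13.12 million parameters as 0.31% of the total. A back-of-the-envelope sketch of what that implies follows; the resulting total is derived arithmetic from two rounded figures, not a number quoted from the paper.

```python
trainable_params = 13.12e6      # RecursiveLink parameters, as reported
trainable_share = 0.31 / 100    # share of all parameters, as reported

# Derived, approximate: total size of the system and its frozen part.
implied_total = trainable_params / trainable_share
frozen_params = implied_total - trainable_params

print(f"Implied total parameters : {implied_total / 1e9:.2f}B")
print(f"Frozen agent parameters  : {frozen_params / 1e9:.2f}B")
```

Because 0.31% is itself rounded, the total should be read as roughly 4.2 billion parameters, nearly all of them frozen.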

The fourth implication is about heterogeneity. RecursiveMAS works with agents of different sizes and from different model families. The Outer RecursiveLink handles the dimensional mismatch between, say, a 1.7B parameter Qwen model and a 7B parameter Gemma model. This means you can build systems that mix and match models based on their strengths, without worrying about compatibility.
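
A minimal sketch of how the outer link changes width, following the appendix formula R_outer(h) = W3·h + W2·σ(W1·h). It uses plain Python lists and toy sizes so the shapes are easy to follow; tanh stands in for σ as an assumption, and the weights are hand-picked for readability rather than trained.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def outer_link(h: Vector, w1: Matrix, w2: Matrix, w3: Matrix) -> Vector:
    """R_outer(h) = W3 * h + W2 * sigma(W1 * h), tanh standing in for sigma.

    W3 is a linear shortcut that maps the sender's hidden size to the
    receiver's, taking the place of the identity term in the inner link.
    """
    shortcut = matvec(w3, h)
    correction = matvec(w2, [math.tanh(v) for v in matvec(w1, h)])
    return [s + c for s, c in zip(shortcut, correction)]


# Toy sizes: sender hidden size 3, receiver hidden size 2 (illustrative only).
h_sender = [1.0, 2.0, 3.0]
w1 = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # 3 -> 2, zeroed for a clean check
w2 = [[1.0, 0.0], [0.0, 1.0]]             # 2 -> 2
w3 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # 3 -> 2, keeps the first two dims

h_receiver = outer_link(h_sender, w1, w2, w3)
print(len(h_sender), "->", len(h_receiver), h_receiver)
```

With the nonlinear branch zeroed out, the output is just the W3 shortcut, which shows why the outer link needs W3 where the inner link can use a plain identity: the two hidden sizes differ.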

For developers building agentic systems today, the most immediately applicable lessons from this paper are these. First, consider whether your agents really need to communicate through text, or whether a more direct form of information transfer would be better. Second, think about your multi-agent system as a single unified entity that should be optimized as a whole, not as a collection of independent agents that happen to talk to each other. Third, recognize that recursive refinement — having agents iterate on their answers across multiple rounds — is a powerful and underutilized technique. And fourth, remember that the connections between agents (the "links") may be just as important as the agents themselves.

The code in this tutorial gives you a foundation to start experimenting with these ideas. The text-based orchestrator works today with any combination of local and remote models. The RecursiveLink modules and the latent-space generation code give you the building blocks for the full system when you have access to local models. And the agent abstraction makes it easy to swap in different models as you experiment.

The field of agentic AI is moving fast, and RecursiveMAS points toward a future where the boundaries between individual models and multi-agent systems become increasingly blurred. The "system" is the model, and the "model" is the system. That is a genuinely exciting direction, and one that developers with a solid understanding of the fundamentals — which you now have — are well positioned to explore.

Happy building.


APPENDIX: QUICK REFERENCE

RecursiveMAS Key Numbers (from the paper)


| Metric | Value |
|---|---|
| Average accuracy improvement over baselines | 8.3% |
| Inference speedup range | 1.2x to 2.4x |
| Token usage reduction range | 34.6% to 75.6% |
| Optimal latent thought steps (m) | ~80 |
| Trainable parameters (RecursiveLinks only) | 13.12M (0.31% of total) |
| GPU memory for training | 15.29 GB |
| Benchmarks evaluated | 9 |
| Collaboration patterns supported | 4 |
| Model families tested | 4 (Qwen, LLaMA, Gemma, Mistral) |

Formulas

$$R_{\text{inner}}(h) = h + W_2 \cdot \sigma(W_1 \cdot h)$$

$$R_{\text{outer}}(h) = W_3 \cdot h + W_2 \cdot \sigma(W_1 \cdot h)$$

$$L_{\text{inner}} = 1 - \cos\!\left(R_{\text{inner}}(H_0),\; \overline{\text{Emb}(y)}\right)$$

$$L_{\text{outer}} = \text{CrossEntropy}\!\left(S^{(n)}\!\left(\cdots S^{(1)}(x)\cdots\right),\; y\right)$$
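
As a concrete reading of the inner loss, here is a small pure-Python sketch. It treats the link output as a single vector and the overline as the mean over target-token embeddings; how the paper pools over positions is not restated in this appendix, so that pooling is an assumption, and the numbers are toy values.

```python
import math
from typing import List

Vector = List[float]


def mean_vector(vectors: List[Vector]) -> Vector:
    """Element-wise mean of equally sized vectors."""
    count = len(vectors)
    return [sum(column) / count for column in zip(*vectors)]


def cosine_similarity(a: Vector, b: Vector) -> float:
    """Cosine of the angle between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def inner_loss(link_output: Vector, target_embeddings: List[Vector]) -> float:
    """L_inner = 1 - cos(R_inner(H_0), mean(Emb(y)))."""
    return 1.0 - cosine_similarity(link_output, mean_vector(target_embeddings))


# Toy values (illustrative only): two target-token embeddings of width 2.
target_embeddings = [[1.0, 0.0], [0.0, 1.0]]             # mean is [0.5, 0.5]
aligned = inner_loss([2.0, 2.0], target_embeddings)      # same direction
orthogonal = inner_loss([1.0, -1.0], target_embeddings)  # perpendicular

print(f"aligned    : {aligned:.6f}")
print(f"orthogonal : {orthogonal:.6f}")
```

The loss depends only on direction, not magnitude: a link output pointing the same way as the mean target embedding scores near 0, and a perpendicular one scores 1.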

Recommended Hyperparameters (from the paper)

| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 1e-4 with cosine schedule |
| Batch size | 4 |
| Temperature (reasoning tasks) | 0.6 |
| Temperature (code generation) | 0.2 |
| Top-p | 0.95 |
| Training recursion rounds | 3 |
| Inference recursion rounds | 3 |
| Gradient clip norm | 1.0 |
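
To make "1e-4 with cosine schedule" concrete, the sketch below evaluates the closed form that PyTorch documents for CosineAnnealingLR, with a minimum learning rate of 0 (that scheduler's default). The function name cosine_lr and the step count of 300 are illustrative assumptions.

```python
import math


def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4,
              min_lr: float = 0.0) -> float:
    """Cosine-annealed learning rate at `step` (0 <= step <= total_steps).

    min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * step / total_steps))
    """
    progress = step / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total_steps = 300  # hypothetical: 3 epochs x 100 batches
for step in (0, 150, 300):
    print(f"step {step:>3}: lr = {cosine_lr(step, total_steps):.2e}")
```

The rate starts at the base value, passes half of it at the midpoint, and reaches the minimum at the final step, which is why T_max in train_outer_loop is set to the total number of optimizer steps.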