Tuesday, June 02, 2026

ARTIFICIAL INTELLIGENCE AGENTS FOR JUPYTER NOTEBOOK GENERATION



A notable development in artificial intelligence is the emergence of LLM-based agents that can autonomously generate executable Jupyter Notebooks from natural language prompts. These agents let users, from data scientists to business analysts, rapidly prototype analyses, visualize data, and explore complex datasets without writing the code themselves. This article examines the architecture, constituents, and implementation details required to build such an agent, with an emphasis on supporting diverse LLM deployments and hardware configurations.

INTRODUCTION

An LLM-based agent for Jupyter Notebook generation represents a significant leap in productivity and accessibility within data science and software development. Instead of manually writing code, users describe their analytical goals or programming tasks in plain English. The agent interprets these intentions, plans a series of actions, generates the necessary code, optionally executes it to check the results, and compiles everything into a structured, runnable Jupyter Notebook. This makes advanced computational tasks accessible to a broader audience and shortens development cycles for experienced practitioners. The system described here supports both local and remote Large Language Models (LLMs) and is designed to run across several GPU architectures, including NVIDIA CUDA, AMD ROCm, Apple Metal Performance Shaders (MPS), and Intel integrated and discrete GPUs.

CORE ARCHITECTURE OF THE NOTEBOOK GENERATION AGENT

The architecture of an LLM-based agent for generating Jupyter Notebooks is inherently modular, comprising several interconnected components that work in concert to fulfill user requests. At its heart lies an orchestration layer that leverages the reasoning capabilities of a Large Language Model. This layer interacts with a suite of specialized tools, an execution environment, and a robust LLM integration layer that abstracts away the complexities of different LLM providers and hardware backends. The final output is a well-structured Jupyter Notebook, ready for immediate use or further refinement.

Figure 1: High-Level Agent Architecture


CONSTITUENTS AND THEIR DETAILS

Let us explore each constituent of this architecture in detail.

  1. User Interface and Prompt Engineering

    The user interface serves as the primary gateway for interaction, allowing users to submit their requests in natural language. This interface can range from a simple command-line tool to a sophisticated web application. Effective prompt engineering is crucial here, as the clarity and specificity of the user's prompt directly impact the agent's ability to generate accurate and relevant notebooks. The agent's internal prompt, which guides the LLM, will often include instructions on the desired output format (e.g., Python code, markdown for explanations), available tools, and constraints.

    Example of a user prompt: "Analyze the 'sales_data.csv' file. Show the top 5 products by total sales. Create a line plot of monthly sales trends and save the notebook as 'sales_analysis.ipynb'."

  2. Agent Orchestration Layer

    This layer acts as the brain of the agent, responsible for interpreting the user's prompt, devising a plan to achieve the stated goal, executing that plan using available tools, and refining the approach based on observations. It embodies the "Plan, Act, Observe, Refine" loop.

    • Planning: The LLM analyzes the user's request and breaks it down into a sequence of smaller, manageable tasks. For instance, "analyze sales data" might become "load data", "calculate total sales per product", "identify top 5 products", "aggregate sales by month", "generate plot code", "assemble notebook".
    • Acting: The agent invokes specific tools (e.g., a code interpreter, a file reader) to perform the planned tasks. The LLM generates the arguments or code for these tools.
    • Observing: The agent receives feedback from the tools, such as the output of executed code, error messages, or data summaries.
    • Refining: Based on the observations, the LLM adjusts its plan, corrects errors, or generates further steps to move closer to the goal. This iterative process is fundamental to the agent's intelligence and robustness.

    A conceptual snippet for the agent's core loop might look like this:

    import ast
    import nbformat
    from nbformat.v4 import new_notebook, new_code_cell, new_markdown_cell
    
    class NotebookAgent:
        def __init__(self, llm_connector, tools):
            self.llm = llm_connector # Any object with an invoke(prompt) -> str method
            self.tools = tools # Mapping of tool name -> tool object exposing run(**kwargs)
            self.notebook_cells = [] # Stores generated cells
    
        def generate_notebook_from_prompt(self, user_prompt, filename="generated_notebook.ipynb"):
            self.notebook_cells = [] # Start each request with an empty notebook
            # Initial planning phase using the LLM
            initial_plan_prompt = f"""
            You are an expert data scientist agent. Your goal is to generate a Jupyter Notebook
            that fulfills the user's request. Break down the user's request into a series of
            steps, including data loading, processing, analysis, visualization, and notebook
            assembly. List the steps clearly, one per line.
    
            User Request: {user_prompt}
            """
            plan_response = self.llm.invoke(initial_plan_prompt)
            current_plan = self._parse_plan(plan_response)
    
            for step in current_plan:
                # For each step, generate a code cell, a markdown cell, or a tool call
                previous_sources = "\n\n".join(cell["source"] for cell in self.notebook_cells)
                action_prompt = f"""
                Based on the overall plan and the current step, reply with exactly one of:
                "CODE: <python code>", "MARKDOWN: <text>", or "TOOL: <tool_name>(<keyword arguments>)".
                Overall Plan: {current_plan}
                Current Step: {step}
                Previous Cells: {previous_sources}
                """
                action_response = self.llm.invoke(action_prompt)
                action_type, content = self._parse_action(action_response)
    
                if action_type == "code":
                    self.notebook_cells.append({"cell_type": "code", "source": content})
                    # Optionally execute the code and feed the observation back:
                    # output = self.tools["code_interpreter"].execute(content)
                    # self._process_tool_output(output)
                elif action_type == "markdown":
                    self.notebook_cells.append({"cell_type": "markdown", "source": content})
                elif action_type == "tool_invocation":
                    tool_name, tool_args = self._parse_tool_invocation(content)
                    if tool_name in self.tools:
                        tool_output = self.tools[tool_name].run(**tool_args)
                        # Process tool_output, potentially add to notebook or inform LLM
                        self._process_tool_output(tool_output)
                    else:
                        print(f"Error: Tool '{tool_name}' not found.")
                else:
                    print(f"Warning: could not parse the action for step '{step}', skipping.")
    
            # Final assembly and saving of the notebook
            return self._assemble_and_save_notebook(self.notebook_cells, filename)
    
        def _parse_plan(self, llm_output):
            # Simple line-based parsing for demonstration; a production agent would
            # request structured output (e.g., JSON) or parse numbered lists with regex.
            steps = [line.strip() for line in llm_output.splitlines() if line.strip()]
            print(f"Parsed plan: {steps}")
            return steps
    
        def _parse_action(self, llm_output):
            # Expected forms: "CODE: print('Hello')", "MARKDOWN: # Introduction",
            # or "TOOL: file_reader(path='data.csv')"
            text = llm_output.strip()
            for prefix, action_type in (("CODE:", "code"), ("MARKDOWN:", "markdown"), ("TOOL:", "tool_invocation")):
                if text.startswith(prefix):
                    return action_type, text[len(prefix):].strip()
            return "unknown", text
    
        def _parse_tool_invocation(self, content):
            # Parse "tool_name(key='value', ...)" with ast instead of eval, so only
            # literal keyword arguments are accepted and nothing is executed.
            call = ast.parse(content, mode="eval").body
            if not isinstance(call, ast.Call) or not isinstance(call.func, ast.Name):
                raise ValueError(f"Malformed tool invocation: {content}")
            tool_args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
            return call.func.id, tool_args
    
        def _process_tool_output(self, output):
            # Placeholder for processing tool output, e.g., feeding it back to the LLM
            print(f"Tool output processed: {output}")
    
        def _assemble_and_save_notebook(self, cells, filename):
            # Build and save the .ipynb file with nbformat
            notebook = new_notebook()
            for cell in cells:
                make_cell = new_code_cell if cell["cell_type"] == "code" else new_markdown_cell
                notebook.cells.append(make_cell(cell["source"]))
            nbformat.write(notebook, filename)
            print(f"Saved notebook to {filename} with {len(cells)} cells.")
            return filename
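    The _parse_plan method above treats every non-empty line as a step. LLMs usually answer with a numbered or bulleted list, often wrapped in a preamble such as "Here is the plan:". The following is a minimal sketch, assuming one step per line, of a slightly more tolerant parser: if any line carries a list marker, only marked lines are kept and the markers are stripped. parse_plan is a hypothetical helper, not part of any library.

```python
import re

# Leading list markers such as "1.", "2)", "-", "*", or "•" followed by whitespace.
_LIST_MARKER = re.compile(r"^\s*(?:\d+[.)]|[-*•])\s+")

def parse_plan(llm_output: str) -> list[str]:
    """Turn an LLM's plan text into a list of step strings."""
    lines = [line for line in llm_output.splitlines() if line.strip()]
    marked = [line for line in lines if _LIST_MARKER.match(line)]
    # If the model produced a list, ignore any preamble or closing remarks around it.
    source = marked if marked else lines
    steps = [_LIST_MARKER.sub("", line).strip() for line in source]
    return [step for step in steps if step]  # drop lines that held only a marker

plan_text = """Here is the plan:
1. Load data from sales_data.csv
2) Calculate total sales per product
- Plot monthly sales
"""
print(parse_plan(plan_text))
# ['Load data from sales_data.csv', 'Calculate total sales per product', 'Plot monthly sales']
```

    Requesting JSON output from the model and parsing it with json.loads is a more robust alternative when the LLM supports structured output.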
    
  3. LLM Integration Layer

    This is a critical component that abstracts the complexities of interacting with various LLMs, whether they are hosted remotely (e.g., OpenAI, Anthropic) or run locally (e.g., Llama 2, Mixtral). It also manages the underlying hardware configuration so that local inference can use whichever GPU is available, regardless of vendor.

    • Remote LLMs: For remote models, this layer handles API key management, rate limiting, request/response serialization, and error handling. It provides a unified interface regardless of the specific API endpoint.
    • Local LLMs: For local models, this layer manages model loading, memory allocation, and device placement. It needs to support various local inference engines and frameworks.
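    Rate limiting for remote LLMs is commonly handled by retrying with exponential backoff. Below is a minimal, library-agnostic sketch; call_with_backoff is a hypothetical helper, and the default retry_on=(Exception,) is a placeholder you would narrow to the provider's transient error types.

```python
import random
import time

def call_with_backoff(fn, *args, max_retries=5, base_delay=1.0, retry_on=(Exception,), **kwargs):
    """Call fn(*args, **kwargs), retrying with exponential backoff on retry_on errors."""
    for attempt in range(max_retries + 1):
        try:
            return fn(*args, **kwargs)
        except retry_on:
            if attempt == max_retries:
                raise  # Out of retries: surface the last error to the caller
            # Wait base_delay * 2**attempt seconds (1, 2, 4, ... by default) plus up to 10% jitter
            delay = base_delay * (2 ** attempt)
            time.sleep(delay + random.uniform(0, 0.1 * delay))

# Hypothetical wiring inside LLMConnector.invoke for the remote case:
# response = call_with_backoff(self.llm_instance.chat.completions.create,
#                              model=self.model_name, messages=messages,
#                              retry_on=(openai.RateLimitError, openai.APIConnectionError))
```

    Note that the official OpenAI Python client already retries some failures on its own (configurable through its max_retries option), so an outer wrapper like this mainly matters for longer outages or for providers without built-in retries.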

    The key challenge here is supporting diverse GPU architectures. This layer must intelligently detect available hardware and configure the LLM inference engine accordingly.

    • NVIDIA CUDA: The most widely supported backend, used by PyTorch and TensorFlow and by libraries such as llama-cpp-python when compiled with CUDA support. Detection typically uses torch.cuda.is_available().
    • AMD ROCm: AMD's open-source GPU platform. PyTorch and TensorFlow offer ROCm builds, and llama-cpp-python can be compiled with ROCm (HIP) support. ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, so torch.cuda.is_available() returns True and torch.version.hip is set; outside PyTorch, the ROCM_PATH environment variable can indicate a ROCm installation.
    • Apple MPS (Metal Performance Shaders): Apple's framework for accelerating machine learning on Apple Silicon. PyTorch supports MPS via torch.backends.mps.is_available().
    • Intel GPUs (integrated and discrete): Intel provides oneAPI and specific optimizations for PyTorch and TensorFlow. Detection might involve torch.xpu.is_available() or checking for Intel-specific libraries.

    A simplified LLMConnector class demonstrating this abstraction:

    import os
    import torch
    from openai import OpenAI
    from llama_cpp import Llama # For local GGUF models
    
    class LLMConnector:
        def __init__(self, model_type="local", model_name="llama2-7b-chat.Q4_K_M.gguf", api_key=None, base_url=None):
            self.model_type = model_type
            self.model_name = model_name
            self.api_key = api_key
            self.base_url = base_url
            self.llm_instance = None
            self._initialize_llm()
    
        def _initialize_llm(self):
            if self.model_type == "remote":
                if not self.api_key:
                    raise ValueError("API key is required for remote LLM.")
                self.llm_instance = OpenAI(api_key=self.api_key, base_url=self.base_url)
                print(f"Initialized remote LLM: {self.model_name}")
            elif self.model_type == "local":
                model_path = os.path.join("models", self.model_name)
                if not os.path.exists(model_path):
                    raise FileNotFoundError(f"Local model not found at {model_path}")
    
                # Decide how many layers to offload. The torch checks below only show which
                # hardware is present; llama-cpp-python can use a GPU only if it was built
                # with the matching backend (e.g., CUDA, HIP/ROCm, or Metal).
                if torch.cuda.is_available():
                    # ROCm builds of PyTorch also report AMD GPUs here (torch.version.hip is set).
                    backend = "ROCm" if getattr(torch.version, "hip", None) else "CUDA"
                    print(f"{backend} GPU detected.")
                    n_gpu_layers = -1 # Offload all layers to the GPU
                elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                    print("Apple MPS detected.")
                    n_gpu_layers = -1 # Offload all layers to the GPU
                elif os.getenv("ROCM_PATH") or (hasattr(torch, 'xpu') and torch.xpu.is_available()):
                    # ROCm installed without a ROCm build of PyTorch, or an Intel XPU (oneAPI)
                    print("ROCm or Intel XPU detected.")
                    n_gpu_layers = -1 # Offload all layers to the GPU
                else:
                    print("No suitable GPU detected or configured, running on CPU.")
                    n_gpu_layers = 0 # Run on CPU
    
                try:
                    self.llm_instance = Llama(
                        model_path=model_path,
                        n_ctx=4096, # Context window size
                        n_gpu_layers=n_gpu_layers, # Number of layers to offload to GPU
                        verbose=False # Suppress llama.cpp verbose output
                    )
                    print(f"Initialized local LLM: {self.model_name} with {n_gpu_layers} GPU layers.")
                except Exception as e:
                    if n_gpu_layers == 0:
                        raise # Already on CPU, so there is nothing to fall back to
                    print(f"Error initializing local LLM with GPU offload: {e}. Falling back to CPU.")
                    self.llm_instance = Llama(
                        model_path=model_path,
                        n_ctx=4096,
                        n_gpu_layers=0, # Force CPU
                        verbose=False
                    )
    
            else:
                raise ValueError(f"Unsupported LLM model type: {self.model_type}")
    
        def invoke(self, prompt, max_tokens=1024, temperature=0.7):
            if self.model_type == "remote":
                try:
                    response = self.llm_instance.chat.completions.create(
                        model=self.model_name,
                        messages=[{"role": "user", "content": prompt}],
                        max_tokens=max_tokens,
                        temperature=temperature
                    )
                    return response.choices[0].message.content
                except Exception as e:
                    print(f"Error invoking remote LLM: {e}")
                    raise
            elif self.model_type == "local":
                try:
                    response = self.llm_instance.create_chat_completion(
                        messages=[{"role": "user", "content": prompt}],
                        max_tokens=max_tokens,
                        temperature=temperature
                    )
                    return response["choices"][0]["message"]["content"]
                except Exception as e:
                    print(f"Error invoking local LLM: {e}")
                    raise
            return "" # Should not reach here
    

    This LLMConnector shows how LLM interaction can be abstracted. For local models, it tries to detect GPU resources from NVIDIA, Apple, AMD, or Intel. Setting n_gpu_layers=-1 tells llama-cpp-python to offload all layers to the GPU, but offloading only takes effect if the package was built with the matching backend; the torch-based checks indicate what hardware is present, not what llama.cpp can use. For transformers-based models, this layer would instead manage explicit device placement with model.to("cuda"), model.to("mps"), or model.to("xpu") (ROCm builds of PyTorch also use the "cuda" device name).

  4. Tooling Layer

    The tooling layer provides the agent with capabilities beyond pure text generation. These tools are essentially functions or modules that the LLM can call to interact with the external environment, perform computations, or access data.

    Common tools include:

    • Code Interpreter: Executes Python code in a sandboxed environment. This is crucial for data loading, manipulation, statistical analysis, and plotting.
    • File System Access: Reads and writes files, lists directories.
    • Data Access: Connects to databases, APIs, or cloud storage.
    • Visualization Libraries: Generates plots and charts (e.g., Matplotlib, Seaborn, Plotly).
    • Internet Search: Fetches information from the web (e.g., for finding specific library usage or data formats).

    Each tool should have a clear description that the LLM can understand, along with defined input parameters and expected output formats.
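    One way to give each tool a description the LLM can read is to keep a small registry of tool specifications and render it into the prompt. The sketch below is an assumption, not an established API: ToolSpec and render_tool_descriptions are hypothetical names, and tools expose run(**kwargs) with a basic check for missing parameters. APIs with native function calling, such as OpenAI's, describe parameters with JSON Schema instead.

```python
import os
from dataclasses import dataclass, field
from typing import Callable

@dataclass
class ToolSpec:
    """Hypothetical description of one tool the agent may call."""
    name: str
    description: str
    parameters: dict[str, str]  # parameter name -> human-readable type/meaning
    func: Callable[..., str] = field(repr=False)

    def run(self, **kwargs) -> str:
        # Report missing arguments as text so the LLM can observe and correct them.
        missing = [p for p in self.parameters if p not in kwargs]
        if missing:
            return f"ERROR: missing parameters {missing} for tool '{self.name}'"
        return self.func(**kwargs)

def render_tool_descriptions(tools: dict[str, ToolSpec]) -> str:
    """Format the registry as plain text for inclusion in the LLM prompt."""
    lines = []
    for spec in tools.values():
        params = ", ".join(f"{name}: {desc}" for name, desc in spec.parameters.items())
        lines.append(f"- {spec.name}({params}): {spec.description}")
    return "\n".join(lines)

def list_directory(path: str) -> str:
    return "\n".join(sorted(os.listdir(path)))

tools = {
    "list_directory": ToolSpec(
        name="list_directory",
        description="List the files in a directory.",
        parameters={"path": "str, directory to list"},
        func=list_directory,
    ),
}
print(render_tool_descriptions(tools))
# - list_directory(path: str, directory to list): List the files in a directory.
```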

    Example of a CodeInterpreter tool:

    import contextlib
    import io
    import traceback
    
    class CodeInterpreter:
        def __init__(self, sandbox_mode=False):
            self.sandbox_mode = sandbox_mode # Reserved for an isolated backend; unused here
            # A single namespace keeps state across executions. Separate globals/locals
            # dicts would break functions that reference top-level names defined earlier.
            self.namespace = {}
    
        def execute(self, code_string):
            stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
            try:
                # Capture stdout/stderr while running in-process. This is NOT a sandbox;
                # true isolation needs subprocesses, Docker, or similar.
                with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf):
                    exec(code_string, self.namespace)
            except Exception:
                return f"EXECUTION FAILED: {traceback.format_exc()}\nOUTPUT: {stdout_buf.getvalue()}"
            output, error = stdout_buf.getvalue(), stderr_buf.getvalue()
            if error:
                # stderr may hold warnings rather than fatal errors
                return f"STDERR: {error}\nOUTPUT: {output}"
            return f"SUCCESS: {output}"
    
    # Example usage within the agent
    # code_interpreter = CodeInterpreter()
    # result = code_interpreter.execute("import pandas as pd\ndf = pd.DataFrame({'col': [1,2,3]})\nprint(df)")
    # print(result)
    

    For production environments, the CodeInterpreter must be robustly sandboxed, perhaps by running code in a separate process, a Docker container, or a dedicated Jupyter kernel managed via jupyter_client. This prevents malicious code execution and isolates dependencies.
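    As a first step beyond in-process exec, generated code can run in a separate Python process with a wall-clock timeout. The sketch below assumes a hypothetical run_isolated helper built on the standard subprocess module; it provides process isolation and a fresh working directory, but it is not a security sandbox, since the child runs with the same user permissions.

```python
import subprocess
import sys
import tempfile

def run_isolated(code: str, timeout: float = 30.0) -> dict:
    """Run code in a fresh Python interpreter inside a throwaway working directory."""
    with tempfile.TemporaryDirectory() as workdir:
        try:
            proc = subprocess.run(
                [sys.executable, "-c", code],
                cwd=workdir,           # files written by the code land in a temp dir
                capture_output=True,   # collect stdout and stderr separately
                text=True,
                timeout=timeout,       # the child is killed if it exceeds this
            )
        except subprocess.TimeoutExpired:
            return {"status": "timeout", "stdout": "", "stderr": f"Timed out after {timeout}s"}
    status = "success" if proc.returncode == 0 else "error"
    return {"status": status, "stdout": proc.stdout, "stderr": proc.stderr}

result = run_isolated("print(21 * 2)")
print(result["status"], result["stdout"].strip())
```

    Each call starts a new interpreter, so variables do not persist between calls; a Jupyter kernel driven through jupyter_client keeps state across cells and is the closer match for multi-step notebooks.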

  5. Notebook Generation Logic

    Once the agent has generated code snippets, markdown explanations, and potentially executed some steps to gather results, these pieces need to be assembled into a coherent Jupyter Notebook. The nbformat library is the standard Python library for reading, writing, and manipulating .ipynb files.

    The agent will construct a list of notebook cells, each containing either code or markdown. For code cells, it might also include execution outputs if the code was run internally for verification or to provide context.

    import nbformat
    from nbformat.v4 import new_notebook, new_code_cell, new_markdown_cell
    
    class NotebookAssembler:
        def __init__(self):
            pass
    
        def assemble_notebook(self, cells, filename="generated_notebook.ipynb"):
            """
            Assembles a list of cells into a Jupyter Notebook file.
    
            Args:
                cells (list): A list of dictionaries, each representing a cell.
                              Example: [{"cell_type": "code", "source": "print('Hello')"},
                                        {"cell_type": "markdown", "source": "# Introduction"}]
                filename (str): The name of the output .ipynb file.
            Returns:
                str: The path to the generated notebook file.
            """
            notebook = new_notebook()
            for cell_data in cells:
                if cell_data["cell_type"] == "code":
                    cell = new_code_cell(cell_data["source"])
                    # If execution outputs were captured, they could be added here
                    # cell.outputs = [...]
                elif cell_data["cell_type"] == "markdown":
                    cell = new_markdown_cell(cell_data["source"])
                else:
                    print(f"Warning: Unknown cell type '{cell_data['cell_type']}', skipping.")
                    continue
                notebook.cells.append(cell)
    
            try:
                with open(filename, 'w', encoding='utf-8') as f:
                    nbformat.write(notebook, f)
                print(f"Notebook successfully saved to {filename}")
                return filename
            except Exception as e:
                print(f"Error saving notebook to {filename}: {e}")
                raise
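    To show what the assembler produces, and where captured execution output would go, the sketch below builds a minimal nbformat v4.5 notebook using only the standard json module. It mirrors the structure that nbformat.v4.new_code_cell and nbformat.v4.new_output produce; in practice you would use those helpers and nbformat.validate. build_notebook and the optional "output" key on cell dictionaries are assumptions for this illustration.

```python
import json
import uuid

def _cell_id() -> str:
    # nbformat 4.5 requires a unique id per cell (letters, digits, '-' or '_').
    return uuid.uuid4().hex[:8]

def build_notebook(cells: list[dict]) -> dict:
    """Convert {'cell_type', 'source', optional 'output'} dicts into nbformat v4 JSON."""
    nb_cells = []
    for cell in cells:
        if cell["cell_type"] == "code":
            outputs = []
            if cell.get("output"):
                # Captured stdout becomes a "stream" output, as Jupyter records print() calls.
                outputs.append({"output_type": "stream", "name": "stdout", "text": cell["output"]})
            nb_cells.append({
                "cell_type": "code",
                "id": _cell_id(),
                "metadata": {},
                "execution_count": None,
                "outputs": outputs,
                "source": cell["source"],
            })
        elif cell["cell_type"] == "markdown":
            nb_cells.append({"cell_type": "markdown", "id": _cell_id(), "metadata": {}, "source": cell["source"]})
    return {"cells": nb_cells, "metadata": {}, "nbformat": 4, "nbformat_minor": 5}

notebook = build_notebook([
    {"cell_type": "markdown", "source": "# Sales analysis"},
    {"cell_type": "code", "source": "print(1 + 1)", "output": "2\n"},
])
# This JSON text can be written directly to a .ipynb file.
ipynb_text = json.dumps(notebook, indent=1)
```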
    
  6. Execution Environment (for Verification and Testing)

    While the agent generates the notebook, it's often beneficial for it to execute portions of the generated code internally to verify correctness, gather outputs, and inform subsequent steps. This execution must occur in a controlled, sandboxed environment to prevent security risks and manage dependencies.

    Options for an execution environment include:

    • Isolated Python subprocess calls.
    • Dedicated Docker containers, providing strong isolation.
    • Using jupyter_client to programmatically interact with a Jupyter kernel. This allows executing cells and capturing rich outputs, similar to how a user would interact with a notebook.

    The execution environment should also manage dependencies. Before running generated code, it might need to install required libraries (e.g., pandas, matplotlib).

  7. GPU/Hardware Abstraction

    As highlighted in the LLM Integration Layer, supporting diverse GPU architectures is paramount for broad applicability. The strategy involves:

    • Detection: Programmatically identify the available hardware (NVIDIA, AMD, Apple, Intel). Libraries like torch offer functions such as torch.cuda.is_available() (which also covers AMD GPUs on ROCm builds of PyTorch), torch.backends.mps.is_available(), and torch.xpu.is_available() for Intel GPUs. Environment variables like ROCM_PATH can also be indicative.
    • Configuration: Based on detection, configure the LLM inference engine.
      • For llama_cpp_python, this means setting n_gpu_layers appropriately during Llama object instantiation.
      • For transformers models, it involves moving the model to the correct device: model.to("cuda") (NVIDIA, or AMD under ROCm), model.to("mps"), model.to("xpu"), or model.to("cpu").
      • For models that require specific backend installations (e.g., PyTorch with ROCm or Intel oneAPI), the system should guide the user on prerequisites or attempt to use a CPU fallback if GPU is unavailable or misconfigured.
    • Fallback: Always provide a CPU fallback mechanism if GPU acceleration is not available or encounters errors. This ensures the agent remains functional, albeit with potentially slower performance.
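    The detection, configuration, and fallback steps can be combined into a small device-selection helper for transformers-style models. This is a sketch under stated assumptions: select_torch_device is a hypothetical name, the preference order (CUDA/ROCm, then MPS, then XPU) is a choice rather than a rule, and torch.xpu exists only in recent PyTorch releases, hence the getattr guards. A tiny tensor allocation probes each backend so a misconfigured driver falls through to the next option.

```python
def select_torch_device() -> str:
    """Return a PyTorch device string, falling back to "cpu"."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch not installed: CPU-only fallback

    mps = getattr(torch.backends, "mps", None)
    xpu = getattr(torch, "xpu", None)
    candidates = [
        ("cuda", torch.cuda.is_available()),               # NVIDIA, and AMD on ROCm builds
        ("mps", mps is not None and mps.is_available()),   # Apple Silicon
        ("xpu", xpu is not None and xpu.is_available()),   # Intel GPUs
    ]
    for name, available in candidates:
        if not available:
            continue
        try:
            torch.zeros(1, device=name)  # Probe: fails fast on a broken driver or runtime
            return name
        except RuntimeError:
            continue  # Try the next backend
    return "cpu"

device = select_torch_device()
print(f"Selected device: {device}")
# A transformers model would then be placed with model.to(device).
```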

DETAILED IMPLEMENTATION ASPECTS

  1. Prompt Design for Notebook Generation

    Crafting effective prompts for the LLM is an art and a science. The agent's internal prompt should guide the LLM to:

    • Understand the user's intent.
    • Identify necessary tools.
    • Generate correct and executable code.
    • Provide clear markdown explanations.
    • Format the output appropriately for notebook cells.
    • Handle potential errors gracefully.

    The prompt should include:

    • Role definition: "You are an expert Python programmer and data scientist."
    • Task description: "Your goal is to generate a Jupyter Notebook to analyze data."
    • Available tools: "You have access to a CodeInterpreter tool to run Python code and a NotebookAssembler to save the final notebook."
    • Output format instructions: "Generate code cells prefixed with 'CODE:' and markdown cells with 'MARKDOWN:'. If you need to execute code to obtain information, call the CodeInterpreter with a 'TOOL:' request and wait for its output before continuing."
    • Constraints: "Ensure all necessary imports are at the beginning of the code cells. Provide comments for complex logic."

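    The prompt elements listed above can be assembled programmatically so that the tool list and constraints stay in sync with the agent's configuration. A minimal sketch follows; build_system_prompt is a hypothetical helper, and the "exactly one prefix per response" rule is an assumption chosen to match a line-prefix action parser.

```python
def build_system_prompt(tool_descriptions: dict[str, str], constraints: list[str]) -> str:
    """Assemble role, task, tools, output format, and constraints into one prompt."""
    tools = "\n".join(f"- {name}: {desc}" for name, desc in tool_descriptions.items())
    rules = "\n".join(f"- {rule}" for rule in constraints)
    return (
        "You are an expert Python programmer and data scientist.\n"
        "Your goal is to generate a Jupyter Notebook to analyze data.\n\n"
        f"Available tools:\n{tools}\n\n"
        "Output format: prefix code cells with 'CODE:', markdown cells with 'MARKDOWN:', "
        "and tool calls with 'TOOL:'. Emit exactly one of these per response.\n\n"
        f"Constraints:\n{rules}\n"
    )

prompt = build_system_prompt(
    {"CodeInterpreter": "runs Python code and returns its output",
     "NotebookAssembler": "saves the final notebook"},
    ["Put all necessary imports at the top of the code cells.", "Comment complex logic."],
)
print(prompt)
```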
  2. Agent Loop: Plan, Act, Observe, Refine

    The iterative nature of the agent's operation is key to its intelligence.

    • Plan: The LLM generates a high-level plan.
    • Act: The LLM generates code or tool calls based on the plan.
    • Observe: The CodeInterpreter or other tools execute the action and return results (output, errors, data).
    • Refine: The LLM analyzes the observations. If successful, it proceeds to the next plan step. If an error occurs, it attempts to debug and correct the code, or adjust the plan. This feedback loop is what makes the agent robust.

  3. Code Execution and Sandboxing

    As previously discussed, executing arbitrary code generated by an LLM requires strict sandboxing.

    • Security: Prevent access to sensitive files, network resources, or system commands. Docker containers are a common solution; because containers share the host kernel, highly untrusted code may warrant a virtual machine instead.
    • Dependency Management: Each execution environment should have its own set of dependencies. The agent might need to infer and install required libraries (e.g., pip install pandas) before running the analysis code.
    • State Management: For a multi-step analysis, the execution environment needs to maintain state (e.g., variables defined in one cell should be accessible in subsequent cells). This is naturally handled by a single Jupyter kernel or by passing state explicitly between sandbox runs.

  4. Handling Dependencies

    The generated notebooks will inevitably rely on various Python libraries (e.g., pandas, matplotlib, scikit-learn). The agent should:

    • Explicitly include import statements in the generated code.
    • Potentially suggest or automatically add %pip install <library_name> commands in the notebook's initial cells if it detects missing dependencies in the execution environment. The %pip magic is generally preferable to !pip because it installs into the environment of the running kernel.
    • The execution environment itself must be configured with common data science libraries or have the capability to install them on demand.

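    Detecting missing dependencies can be done statically by parsing the generated cells for import statements and checking whether each top-level module can be found. The sketch below assumes a hypothetical find_missing_imports helper built on the standard ast and importlib modules. Two caveats: cells containing IPython magics (such as %pip) are not valid Python and would need those lines stripped before parsing, and import names do not always match pip package names (for example, sklearn is installed as scikit-learn).

```python
import ast
import importlib.util

def find_missing_imports(code_cells: list[str]) -> list[str]:
    """Return top-level modules imported by the cells that cannot be found."""
    modules = set()
    for source in code_cells:
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                modules.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
                modules.add(node.module.split(".")[0])  # skip relative imports
    # find_spec returns None when a top-level module cannot be located
    return sorted(m for m in modules if importlib.util.find_spec(m) is None)

cells = ["import os\nimport json", "from collections import Counter\nimport not_a_real_package_xyz"]
missing = find_missing_imports(cells)
# Candidate first cell for the notebook, or None if nothing is missing
install_cell = "%pip install " + " ".join(missing) if missing else None
print(install_cell)
```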
  5. Error Handling and Debugging

    LLMs can make mistakes. The agent must be designed to handle errors gracefully.

    • Capture Errors: The CodeInterpreter must capture stdout, stderr, and exceptions.
    • Feedback to LLM: Error messages and stack traces should be fed back to the LLM in the "Observe" phase.
    • Correction Loop: The LLM should then attempt to debug the code, generate a corrected version, or modify its plan. This might involve prompting the LLM with the error message and the problematic code, asking it to identify and fix the issue.
    • User Notification: If the agent cannot resolve an error after several attempts, it should inform the user.

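    The capture, feedback, correction, and notification steps can be sketched as a bounded retry loop. In this sketch, generate_with_retries is a hypothetical helper; llm is any callable from prompt to code (for example, a wrapper around LLMConnector.invoke), and execute is any callable returning (ok, output), such as a wrapper around a sandboxed runner. The demonstration uses a stand-in LLM and plain exec purely for illustration.

```python
from typing import Callable, Optional

def generate_with_retries(
    llm: Callable[[str], str],
    execute: Callable[[str], tuple[bool, str]],
    task: str,
    max_attempts: int = 3,
) -> tuple[Optional[str], list[str]]:
    """Ask for code, run it, and feed errors back until it succeeds or attempts run out."""
    log = []
    prompt = f"Write Python code for this task. Reply with code only.\nTask: {task}"
    for attempt in range(1, max_attempts + 1):
        code = llm(prompt)
        ok, output = execute(code)
        log.append(f"attempt {attempt}: {'ok' if ok else 'failed'}")
        if ok:
            return code, log
        # Observe -> Refine: show the model its code and the error, and ask for a fix.
        prompt = (
            f"The following code for the task '{task}' failed.\n"
            f"Code:\n{code}\n\nError:\n{output}\n\n"
            "Return a corrected version. Reply with code only."
        )
    return None, log  # Caller should notify the user after exhausting attempts

# Stand-in LLM for demonstration: the first reply has a bug, the second is fixed.
replies = iter(["print(undefined_name)", "print('fixed')"])
fake_llm = lambda prompt: next(replies)

def run_code(code: str) -> tuple[bool, str]:
    try:
        exec(code, {})  # in-process for the demo only; use a sandbox in practice
        return True, ""
    except Exception as e:
        return False, f"{type(e).__name__}: {e}"

code, log = generate_with_retries(fake_llm, run_code, "print a greeting")
```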
  6. Security Considerations

    Running LLM-generated code poses significant security risks.

    • Sandboxing: This is the most critical measure. Isolate code execution in containers or virtual machines.
    • Resource Limits: Limit CPU, memory, and execution time to prevent denial-of-service attacks or runaway processes.
    • Input Validation: While the agent processes natural language, any direct file paths or external resource URLs provided by the user or generated by the LLM should be carefully validated.
    • Least Privilege: The execution environment should run with the minimum necessary permissions.
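    For the input-validation point, a common pattern is to resolve every file path against a dedicated workspace directory and reject anything that escapes it. The sketch below assumes a hypothetical resolve_within helper using pathlib (Path.is_relative_to requires Python 3.9+). Because Path.resolve() normalizes ".." segments and follows symlinks, both "../" traversal and absolute paths outside the workspace are rejected.

```python
import tempfile
from pathlib import Path

def resolve_within(workspace: str, user_path: str) -> Path:
    """Resolve user_path relative to workspace and refuse anything outside it."""
    root = Path(workspace).resolve()
    candidate = (root / user_path).resolve()  # an absolute user_path replaces root entirely
    if not candidate.is_relative_to(root):
        raise PermissionError(f"Path escapes workspace: {user_path}")
    return candidate

with tempfile.TemporaryDirectory() as ws:
    print(resolve_within(ws, "sales_data.csv"))
    try:
        resolve_within(ws, "../../etc/passwd")
    except PermissionError as e:
        print(e)
```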

CONCLUSION

Building an LLM-based agent for Jupyter Notebook generation is a complex yet highly rewarding endeavor. By meticulously designing the agent's orchestration, abstracting LLM interactions, providing robust tooling, and ensuring secure code execution across diverse hardware, we can create a powerful system that significantly enhances productivity and accessibility in data science and development. The ability to seamlessly switch between local and remote LLMs, coupled with comprehensive GPU support, ensures the agent's versatility and performance for a wide range of users and computational environments. Such an agent moves us closer to a future where natural language is a primary interface for complex computational tasks, empowering more individuals to harness the power of data and AI.

ADDENDUM: FULL RUNNING EXAMPLE

This full running example demonstrates a complete NotebookAgent that can process a user prompt, generate Python code, and assemble a Jupyter Notebook. It includes the LLMConnector, CodeInterpreter, and NotebookAssembler components, integrated into a cohesive system. For this example, the LLMConnector is configured to use a local LLM, and the CodeInterpreter runs in a simplified in-process mode; a production system would require robust sandboxing.

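First, install the necessary libraries: pip install openai llama-cpp-python nbformat pandas matplotlib torch (torch is used here only for GPU detection). A default pip install builds llama-cpp-python for CPU, with Metal enabled on macOS; GPU offload on CUDA, ROCm, or Intel hardware requires building the package with the corresponding backend enabled, as described in the llama-cpp-python documentation.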

You will also need a local GGUF model file, for example, llama2-7b-chat.Q4_K_M.gguf. Place this file in a directory named models relative to where your script runs. You can download such models from Hugging Face (e.g., TheBloke's repositories).

We will use a simple sales_data.csv file for our example. Create this file in the same directory as your Python script:

sales_data.csv

Date,Product,Sales
2023-01-01,Product A,100
2023-01-01,Product B,150
2023-01-02,Product A,120
2023-01-02,Product C,200
2023-01-03,Product B,180
2023-01-03,Product A,110
2023-02-01,Product A,90
2023-02-01,Product C,250
2023-02-02,Product B,160
2023-02-02,Product A,130
2023-03-01,Product A,110
2023-03-01,Product B,170
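Because the dataset is small, the figures the generated notebook should report can be checked independently. The standard-library sketch below is an addition for verification, not part of the agent: it embeds the same rows and computes per-product and per-month totals. Product A and Product B tie at 660, so their relative order in a "top 5" list depends on the tie-breaking behavior of whatever sort the generated code uses.

```python
import csv
import io
from collections import defaultdict

# Same rows as sales_data.csv, embedded so the sketch runs on its own.
SALES_CSV = """Date,Product,Sales
2023-01-01,Product A,100
2023-01-01,Product B,150
2023-01-02,Product A,120
2023-01-02,Product C,200
2023-01-03,Product B,180
2023-01-03,Product A,110
2023-02-01,Product A,90
2023-02-01,Product C,250
2023-02-02,Product B,160
2023-02-02,Product A,130
2023-03-01,Product A,110
2023-03-01,Product B,170
"""

product_totals = defaultdict(int)
monthly_totals = defaultdict(int)
for row in csv.DictReader(io.StringIO(SALES_CSV)):
    sales = int(row["Sales"])
    product_totals[row["Product"]] += sales
    monthly_totals[row["Date"][:7]] += sales  # "YYYY-MM" month key

# Python's sort is stable, so tied products keep their first-seen order.
top_products = sorted(product_totals.items(), key=lambda kv: kv[1], reverse=True)[:5]
print(top_products)          # [('Product A', 660), ('Product B', 660), ('Product C', 450)]
print(dict(monthly_totals))  # {'2023-01': 860, '2023-02': 630, '2023-03': 280}
```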

Now, here is the complete Python code for the agent:

import os
import io
import sys
import traceback
import pandas as pd
import matplotlib.pyplot as plt
import nbformat
from nbformat.v4 import new_notebook, new_code_cell, new_markdown_cell
import torch
from openai import OpenAI
from llama_cpp import Llama # For local GGUF models

# --- 1. LLM Integration Layer ---
class LLMConnector:
    """
    Connects to various LLMs, abstracting local and remote inference.
    Handles GPU detection and configuration for local models.
    """
    def __init__(self, model_type="local", model_name="llama2-7b-chat.Q4_K_M.gguf", api_key=None, base_url=None):
        self.model_type = model_type
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.llm_instance = None
        self._initialize_llm()

    def _initialize_llm(self):
        """Initializes the LLM instance based on model_type."""
        if self.model_type == "remote":
            if not self.api_key:
                raise ValueError("API key is required for remote LLM.")
            # If base_url is provided, it can be a custom endpoint (e.g., local OpenAI-compatible server)
            self.llm_instance = OpenAI(api_key=self.api_key, base_url=self.base_url)
            print(f"Initialized remote LLM: {self.model_name}")
        elif self.model_type == "local":
            model_path = os.path.join("models", self.model_name)
            if not os.path.exists(model_path):
                raise FileNotFoundError(f"Local model not found at {model_path}. Please download it and place it in the 'models' directory.")

            # Decide how many layers to offload. The torch checks below only show which
            # hardware is present; llama-cpp-python can use a GPU only if it was built
            # with the matching backend (e.g., CUDA, HIP/ROCm, or Metal). Best effort.
            if torch.cuda.is_available():
                # ROCm builds of PyTorch also report AMD GPUs here (torch.version.hip is set).
                backend = "ROCm" if getattr(torch.version, "hip", None) else "CUDA"
                print(f"{backend} GPU detected. Using all GPU layers.")
                n_gpu_layers = -1 # Offload all layers to the GPU
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                print("Apple MPS detected. Using all GPU layers.")
                n_gpu_layers = -1 # Offload all layers to the GPU
            elif os.getenv("ROCM_PATH") or (hasattr(torch, 'xpu') and torch.xpu.is_available()):
                # ROCm installed without a ROCm build of PyTorch, or an Intel XPU (oneAPI)
                print("ROCm or Intel XPU detected. Attempting to use all GPU layers.")
                n_gpu_layers = -1 # Offload all layers to the GPU
            else:
                print("No suitable GPU detected or configured for local LLM. Running on CPU.")
                n_gpu_layers = 0 # Run on CPU

            try:
                self.llm_instance = Llama(
                    model_path=model_path,
                    n_ctx=4096, # Context window size, adjust as needed
                    n_gpu_layers=n_gpu_layers, # Number of layers to offload to GPU
                    verbose=False # Suppress llama.cpp verbose output
                )
                print(f"Initialized local LLM: {self.model_name} with {n_gpu_layers} GPU layers.")
            except Exception as e:
                if n_gpu_layers == 0:
                    raise # Already on CPU, so there is nothing to fall back to
                print(f"Error initializing local LLM with GPU support: {e}. Falling back to CPU.")
                self.llm_instance = Llama(
                    model_path=model_path,
                    n_ctx=4096,
                    n_gpu_layers=0, # Force CPU
                    verbose=False
                )
        else:
            raise ValueError(f"Unsupported LLM model type: {self.model_type}. Choose 'local' or 'remote'.")

    def invoke(self, prompt, max_tokens=1024, temperature=0.7):
        """
        Invokes the LLM with the given prompt.
        """
        messages = [{"role": "user", "content": prompt}]
        if self.model_type == "remote":
            try:
                response = self.llm_instance.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error invoking remote LLM: {e}")
                raise
        elif self.model_type == "local":
            try:
                response = self.llm_instance.create_chat_completion(
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature
                )
                return response["choices"][0]["message"]["content"]
            except Exception as e:
                print(f"Error invoking local LLM: {e}")
                raise
        return "" # Should not be reached

# --- 2. Tooling Layer: Code Interpreter ---
class CodeInterpreter:
    """
    Executes Python code in a controlled environment.
    For production, this should be a sandboxed subprocess or Docker container.
    """
    def __init__(self):
        # One shared namespace dict serves as both globals and locals, so state
        # persists across calls. Passing separate globals/locals dicts to exec()
        # makes the snippet run like a class body, where functions and
        # comprehensions defined in it may fail to see its top-level names.
        self.namespace = {'pd': pd, 'plt': plt} # Pre-import common libraries

    def execute(self, code_string):
        """
        Executes the given Python code string and captures its output.
        """
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        redirected_output = io.StringIO()
        redirected_error = io.StringIO()
        sys.stdout = redirected_output
        sys.stderr = redirected_error

        try:
            # Executes in the current process (no real sandboxing)
            exec(code_string, self.namespace)
            output = redirected_output.getvalue()
            error = redirected_error.getvalue()
            if error:
                # stderr also receives warnings, so this is not necessarily a failure
                return f"EXECUTION COMPLETED WITH STDERR OUTPUT:\n{error}\nOUTPUT (stdout):\n{output}"
            return f"EXECUTION SUCCESS:\n{output}"
        except Exception:
            error_traceback = traceback.format_exc()
            return f"EXECUTION FAILED (exception):\n{error_traceback}\nOUTPUT (stdout):\n{redirected_output.getvalue()}"
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

# --- 3. Notebook Generation Logic ---
class NotebookAssembler:
    """
    Assembles a list of cells into a Jupyter Notebook (.ipynb) file.
    """
    def __init__(self):
        pass

    def assemble_notebook(self, cells, filename="generated_notebook.ipynb"):
        """
        Assembles a list of cell data into a Jupyter Notebook file.

        Args:
            cells (list): A list of dictionaries, each representing a cell.
                          Example: [{"cell_type": "code", "source": "print('Hello')"},
                                    {"cell_type": "markdown", "source": "# Introduction"}]
            filename (str): The name of the output .ipynb file.
        Returns:
            str: The path to the generated notebook file.
        """
        notebook = new_notebook()
        for cell_data in cells:
            if cell_data["cell_type"] == "code":
                cell = new_code_cell(cell_data["source"])
                # In a more advanced system, outputs from CodeInterpreter could be added here
                if "outputs" in cell_data:
                    cell.outputs = cell_data["outputs"]
            elif cell_data["cell_type"] == "markdown":
                cell = new_markdown_cell(cell_data["source"])
            else:
                print(f"Warning: Unknown cell type '{cell_data['cell_type']}', skipping.")
                continue
            notebook.cells.append(cell)

        try:
            with open(filename, 'w', encoding='utf-8') as f:
                nbformat.write(notebook, f)
            print(f"Notebook successfully saved to {filename}")
            return filename
        except Exception as e:
            print(f"Error saving notebook to {filename}: {e}")
            raise

# --- 4. Agent Orchestration Layer ---
class NotebookAgent:
    """
    Orchestrates the LLM, tools, and notebook assembly to generate Jupyter Notebooks.
    """
    def __init__(self, llm_connector, code_interpreter, notebook_assembler):
        self.llm = llm_connector
        self.code_interpreter = code_interpreter
        self.notebook_assembler = notebook_assembler
        self.notebook_cells = [] # Stores generated cells
        self.conversation_history = [] # For maintaining context with the LLM

    def _add_to_history(self, role, content):
        """Adds a message to the conversation history."""
        self.conversation_history.append({"role": role, "content": content})

    def _get_full_prompt(self, current_instruction):
        """Constructs the full prompt including history and current instruction."""
        # This is a simplified approach; for production, a more sophisticated
        # prompt engineering strategy (e.g., few-shot examples, specific tool descriptions)
        # would be used.
        base_prompt = """
        You are an expert Python programmer and data scientist. Your goal is to generate a
        Jupyter Notebook based on the user's request. You have access to a CodeInterpreter
        tool to execute Python code and observe its output. You must generate code cells
        and markdown cells.

        Instructions:
        1.  Start with a markdown introduction.
        2.  For each step, generate the necessary Python code.
        3.  If you need to verify code or get data, use the CodeInterpreter tool by
            outputting "TOOL_CODE_EXEC:<your python code here>". The output of the tool
            will be provided to you.
        4.  If you want to output a code cell for the notebook, use "NOTEBOOK_CODE:<your python code here>".
        5.  If you want to output a markdown cell for the notebook, use "NOTEBOOK_MARKDOWN:<your markdown content here>".
        6.  Ensure all necessary imports are at the beginning of relevant code cells.
        7.  Provide explanations in markdown cells for each code block.
        8.  Do not include any `!pip install` commands in the generated code, assume libraries are available.
        9.  After completing the task, indicate completion with "TASK_COMPLETE".

        Current Notebook Cells (so far):
        """
        current_cells_str = "\n".join([f"  - {c['cell_type'].upper()}: {c['source'][:50]}..." for c in self.notebook_cells])
        if not current_cells_str:
            current_cells_str = "  (No cells yet)"

        history_str = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in self.conversation_history])

        return f"{base_prompt}\n{current_cells_str}\n\n{history_str}\n\nUSER_INSTRUCTION: {current_instruction}\n\nYOUR_RESPONSE:"

    def generate_notebook_from_prompt(self, user_prompt, output_filename="generated_notebook.ipynb"):
        """
        Generates a Jupyter Notebook based on the user's natural language prompt.
        """
        print(f"Agent received prompt: '{user_prompt}'")
        self._add_to_history("user", user_prompt)

        max_iterations = 15 # Prevent infinite loops
        iteration = 0
        task_completed = False

        while iteration < max_iterations and not task_completed:
            iteration += 1
            print(f"\n--- Agent Iteration {iteration} ---")
            current_instruction = f"Continue generating the notebook based on the user's request: '{user_prompt}'. " \
                                  f"Current state: {len(self.notebook_cells)} cells generated. " \
                                  f"If the task is complete, output 'TASK_COMPLETE'."

            full_llm_prompt = self._get_full_prompt(current_instruction)
            llm_response = self.llm.invoke(full_llm_prompt, max_tokens=2048, temperature=0.2)
            self._add_to_history("assistant", llm_response)
            print(f"LLM Response:\n{llm_response}")

            # Parse the response into tagged blocks. Each tag starts a new block and the
            # untagged lines that follow belong to it, so multi-line code and markdown
            # survive (splitting per line would keep only the first line of each cell).
            tags = ("NOTEBOOK_MARKDOWN:", "NOTEBOOK_CODE:", "TOOL_CODE_EXEC:")
            blocks = [] # (tag, lines) pairs in the order they appear
            for line in llm_response.strip().split('\n'):
                stripped = line.lstrip()
                tag = next((t for t in tags if stripped.startswith(t)), None)
                if tag:
                    blocks.append((tag, [stripped[len(tag):]]))
                elif blocks and not stripped.startswith("TASK_COMPLETE"):
                    blocks[-1][1].append(line)
                # Untagged text before the first tag is commentary; it is already in history.

            action_taken = False
            for tag, block_lines in blocks:
                content = "\n".join(block_lines).strip()
                fence = "`" * 3
                if tag != "NOTEBOOK_MARKDOWN:" and content.startswith(fence):
                    # Drop the markdown code fences models often wrap around code
                    content = "\n".join(ln for ln in content.split("\n") if not ln.strip().startswith(fence)).strip()
                if not content:
                    continue
                if tag == "NOTEBOOK_MARKDOWN:":
                    self.notebook_cells.append({"cell_type": "markdown", "source": content})
                    print(f"Added MARKDOWN cell: {content[:50]}...")
                elif tag == "NOTEBOOK_CODE:":
                    self.notebook_cells.append({"cell_type": "code", "source": content})
                    print(f"Added CODE cell: {content[:50]}...")
                else: # TOOL_CODE_EXEC:
                    print(f"Executing code with CodeInterpreter: {content[:100]}...")
                    execution_result = self.code_interpreter.execute(content)
                    self._add_to_history("tool_output", execution_result)
                    print(f"CodeInterpreter Output:\n{execution_result[:200]}...") # Limit output for console
                action_taken = True

            # Check for completion only after processing, so cells emitted in the
            # same response as TASK_COMPLETE are kept.
            if "TASK_COMPLETE" in llm_response:
                task_completed = True
                print("LLM indicated task completion.")
                break

            if not action_taken:
                print("LLM did not provide a recognized action. Will re-prompt.")
                # This might indicate the LLM is stuck or needs more specific guidance.
                # In a real system, this might trigger an error or a more direct prompt to the LLM.

        if not task_completed:
            print("Agent reached maximum iterations without completing the task.")

        # Final assembly and saving
        if self.notebook_cells:
            final_notebook_path = self.notebook_assembler.assemble_notebook(self.notebook_cells, output_filename)
            print(f"Notebook generation complete. Saved to: {final_notebook_path}")
            return final_notebook_path
        else:
            print("No cells were generated for the notebook.")
            return None

# --- Main Execution Block ---
if __name__ == "__main__":
    # --- Configuration ---
    # Choose 'local' or 'remote'
    # For 'remote', provide your OpenAI API key and model name
    # For 'local', ensure your GGUF model is in the 'models' directory
    LLM_CONFIG = {
        "type": "local",
        "model_name": "llama2-7b-chat.Q4_K_M.gguf", # Replace with your model if different
        "api_key": os.getenv("OPENAI_API_KEY"), # Only needed for remote
        "base_url": None # For custom OpenAI-compatible endpoints
    }

    # The 'models' directory is needed only for a local LLM; a remote setup skips this check
    if LLM_CONFIG["type"] == "local" and not os.path.exists("models"):
        os.makedirs("models")
        print("Created 'models' directory. Please place your GGUF model file (e.g., llama2-7b-chat.Q4_K_M.gguf) inside it.")
        sys.exit(1) # Exit because the model cannot be present yet

    # Create a dummy sales_data.csv for the example
    sales_data_content = """Date,Product,Sales
2023-01-01,Product A,100
2023-01-01,Product B,150
2023-01-02,Product A,120
2023-01-02,Product C,200
2023-01-03,Product B,180
2023-01-03,Product A,110
2023-02-01,Product A,90
2023-02-01,Product C,250
2023-02-02,Product B,160
2023-02-02,Product A,130
2023-03-01,Product A,110
2023-03-01,Product B,170
"""
    with open("sales_data.csv", "w") as f:
        f.write(sales_data_content)
    print("Created 'sales_data.csv' for the example.")

    print("\nInitializing LLM Connector...")
    llm_connector = LLMConnector(
        model_type=LLM_CONFIG["type"],
        model_name=LLM_CONFIG["model_name"],
        api_key=LLM_CONFIG["api_key"],
        base_url=LLM_CONFIG["base_url"]
    )

    print("\nInitializing Code Interpreter and Notebook Assembler...")
    code_interpreter = CodeInterpreter()
    notebook_assembler = NotebookAssembler()

    print("\nInitializing Notebook Agent...")
    agent = NotebookAgent(llm_connector, code_interpreter, notebook_assembler)

    user_request = "Analyze 'sales_data.csv'. Show the top 5 products by total sales. Create a line plot of monthly sales trends. Save the notebook as 'sales_analysis.ipynb'."
    print(f"\nUser Request: {user_request}")

    generated_notebook_path = agent.generate_notebook_from_prompt(user_request, "sales_analysis.ipynb")

    if generated_notebook_path:
        print(f"\nSuccessfully generated notebook: {generated_notebook_path}")
        print("You can now open 'sales_analysis.ipynb' with Jupyter Lab or Jupyter Notebook.")
    else:
        print("\nNotebook generation failed or no cells were produced.")
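The CodeInterpreter docstring notes that production use calls for a sandboxed subprocess or Docker container. As a minimal sketch of the subprocess half of that advice (the helper name `run_in_subprocess` and its 30-second default timeout are hypothetical choices, not part of the code above), each snippet can run in a fresh Python interpreter via `subprocess.run`, so a crash or an infinite loop in generated code cannot take down the agent. This is process isolation only, not a security boundary: the child keeps the parent's file-system and network access, and because every call starts a new interpreter, no variables persist between calls the way they do in CodeInterpreter.

```python
import subprocess
import sys


def run_in_subprocess(code_string: str, timeout: float = 30.0) -> str:
    """Run code in a fresh Python interpreter and report the result.

    Hypothetical helper: a sketch of process isolation only. It is NOT a
    security sandbox, and no state persists between calls.
    """
    try:
        completed = subprocess.run(
            [sys.executable, "-c", code_string],  # same interpreter as the agent
            capture_output=True,                  # collect stdout and stderr
            text=True,                            # decode bytes to str
            timeout=timeout,                      # child is killed when exceeded
        )
    except subprocess.TimeoutExpired:
        return f"EXECUTION FAILED (timeout after {timeout}s)"

    if completed.returncode != 0:
        # A non-zero exit code means an uncaught exception; its traceback is on stderr
        return (
            f"EXECUTION FAILED (exit code {completed.returncode}):\n"
            f"{completed.stderr}\nOUTPUT (stdout):\n{completed.stdout}"
        )
    return f"EXECUTION SUCCESS:\n{completed.stdout}"
```

For example, `run_in_subprocess("print(1 + 1)")` returns a string starting with `EXECUTION SUCCESS`. Keeping state across calls would require a long-lived process, such as a Jupyter kernel driven through `jupyter_client`, and real isolation would add a container or OS-level sandbox on top.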

Monday, June 01, 2026

RecursiveMAS: Teaching AI Agents to Think Together in Secret


INTRODUCTION: THE PROBLEM WITH CHATTY AGENTS

If you have spent any time building multi-agent AI systems, you have probably run into the same frustrating pattern. You wire up a Planner agent, a Critic agent, and a Solver agent. The Planner writes a plan in plain English. The Critic reads that plan and writes a critique in plain English. The Solver reads the critique and writes an answer in plain English. The whole thing feels elegant on a whiteboard. Then you run it and discover that your system is spending enormous amounts of time and compute budget just generating, tokenizing, and re-reading intermediate text that nobody outside the system ever sees.

It is a bit like watching a relay race where each runner, instead of simply passing the baton, stops to write a detailed memo about the baton, hands the memo to the next runner, who then reads the memo, writes their own memo summarizing the first memo, and only then starts running. The overhead is absurd. The information is there — it is just wrapped in an expensive, lossy, text-shaped package.

This is the problem that the paper "RecursiveMAS: Scaling Agent Collaboration through Unified Latent-Space Recursive Computation" (see also the additional documentation and the official GitHub repository) sets out to solve. The authors, from Stanford and UIUC, ask a beautifully simple question: what if agents could skip the memo entirely and just pass the raw thought?

The answer they arrive at is RecursiveMAS, a framework that lets heterogeneous LLM agents collaborate entirely in the continuous vector space — the latent space — that lives inside the models themselves, without converting intermediate reasoning into text until the very last moment. The results are striking: an average accuracy improvement of 8.3% over strong baselines, inference speedups of 1.2x to 2.4x, and token usage reductions of 34.6% to 75.6% across nine benchmarks covering mathematics, science, medicine, search, and code generation.

This tutorial will walk you through every piece of the system, from the conceptual foundations to the mathematical details to working code that you can run against both local models (via Ollama or Hugging Face Transformers) and remote LLM APIs (via OpenAI-compatible endpoints). By the end, you will understand not just what RecursiveMAS does, but why it works, and how you can start building systems inspired by its ideas today.


INSTALLATION AND PROJECT SETUP

Before diving into the code, here is everything you need to install and how to structure your project.

Requirements

# requirements.txt
torch>=2.1.0
transformers>=4.40.0
accelerate>=0.27.0
requests>=2.31.0
sentencepiece>=0.1.99
protobuf>=3.20.0

Install with:

pip install -r requirements.txt

Optional: Ollama (for local model serving without writing HF loading code)

# Install Ollama from https://ollama.ai, then:
ollama serve                          # start the server (keep this running)
ollama pull qwen2.5:1.5b              # ~1 GB, fast on CPU
ollama pull llama3.2:1b               # ~1.3 GB, fast on CPU
ollama pull qwen2.5:7b                # ~4.7 GB, good quality, best with a GPU
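Once `ollama serve` is running, a quick way to confirm the server is reachable and see which models have been pulled is to query Ollama's `/api/tags` endpoint, which lists local models. The sketch below is an optional standalone check, not part of recursive_mas.py; it assumes the default address `http://localhost:11434` and uses only the standard library so it works before any requirements are installed. The function name `list_ollama_models` is hypothetical.

```python
import json
import urllib.request


def list_ollama_models(base_url: str = "http://localhost:11434", timeout: float = 2.0) -> list:
    """Return the names of locally pulled Ollama models, or [] if unreachable."""
    try:
        with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as resp:
            payload = json.load(resp)
    except (OSError, ValueError):
        # URLError, connection refused and timeouts are OSError subclasses;
        # ValueError covers a response that is not valid JSON.
        return []
    if not isinstance(payload, dict):
        return []
    # Expected shape: {"models": [{"name": "qwen2.5:1.5b", ...}, ...]}
    return [m.get("name", "") for m in payload.get("models", [])]
```

An empty list means the server is not reachable or nothing has been pulled yet; after the pulls above you would expect names such as `qwen2.5:1.5b` in the result.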

Project File Structure

recursive_mas/
├── requirements.txt
├── recursive_mas.py          # All core classes (the full combined file)
├── demo.py                   # Demo script that imports from recursive_mas.py
└── README.md

All code blocks in this tutorial belong in recursive_mas.py in the order they appear, followed by the demo code in demo.py. A complete, self-contained combined file is provided at the end of Part Six.


CHAPTER ONE: UNDERSTANDING THE LANDSCAPE

Before we dive into RecursiveMAS itself, we need to make sure we share a common vocabulary. If you have built agentic systems before, some of this will be review, but the framing matters for what comes later.

What Is a Multi-Agent System, Really?

A multi-agent system (MAS) is a collection of individual language model agents, each assigned a distinct role or area of expertise, that collaborate to solve a task that would be difficult or impossible for any single agent alone. The intuition is that specialization helps. A model fine-tuned for mathematical reasoning will typically outperform a generalist model of similar size on math problems. A model trained on biomedical literature will usually do better on medical questions. By combining specialists, the aim is a system that is smarter than any of its parts.

The paper formalizes this nicely. You have a system S composed of N agents A₁, A₂, ..., Aₙ. Each agent Aᵢ has its own parameters and its own last-layer hidden representations. The system maintains a collective latent state, which is the combined internal representation of what all agents currently "know" about the problem. Given an input question x with a ground truth answer y, the system orchestrates interactions among agents to collaboratively produce a final prediction.

The key insight that motivates the whole paper is captured in what the authors call Recursive Multi-Agent Evolution: a recursive evolution is the progressive refinement of the collective latent state, where each agent adjusts its latent representation and its own reasoning state so that the updated system is better aligned for the given problem. In other words, the system should get smarter with each round of interaction, not just produce one answer and stop.

The Four Collaboration Patterns

The paper identifies four archetypal ways that agents can collaborate, and RecursiveMAS is designed to work with all of them. Understanding these patterns is essential because the framework is deliberately structure-agnostic — it does not care how you arrange your agents, it just makes whatever arrangement you choose work better.

The Sequential Style arranges agents in a chain, where each agent builds on the work of the previous one. The paper uses a Planner–Critic–Solver arrangement: the Planner decomposes the problem into a step-by-step plan, the Critic evaluates that plan and identifies weaknesses, and the Solver uses the refined plan to produce a final answer. This is the most common pattern in practice and the one the paper uses for its primary experiments.

The Mixture Style runs multiple domain-specialized agents in parallel, then aggregates their outputs. The paper uses Math, Code, and Science specialists whose outputs are combined by a Summarizer agent. This pattern is powerful when you do not know in advance which domain a question belongs to, or when a question genuinely spans multiple domains.

The Distillation Style pairs a large, capable Expert model with a smaller, faster Learner model. The Expert provides rich guidance; the Learner absorbs it and produces the final answer more efficiently. This is essentially knowledge distillation happening at inference time, not just training time. The paper shows that RecursiveMAS can improve the Learner by 8.0% while retaining a 1.5x speed advantage over the Expert alone.

The Deliberation Style pairs an inner-thinking Reflector with a Tool-Caller that can invoke external tools like Python interpreters or search APIs. The two agents iteratively exchange, critique, and refine candidate solutions until they reach consensus, after which the Tool-Caller produces the final answer. This is the most complex pattern and the most interesting for building real-world agentic systems.

The Problem with Text-Based Communication

Here is where things get interesting. In all four patterns above, the traditional approach has agents communicate by generating text. Agent A₁ produces a text response, which is fed as a prompt to Agent A₂, which produces another text response, and so on. This seems natural — text is the universal interface of language models, after all.

But this approach has two serious problems that the paper addresses with mathematical rigor.

The first problem is computational efficiency. When an intermediate agent generates text, it must run the full vocabulary projection layer (which maps from the hidden dimension to a vocabulary of tens of thousands of tokens or more), sample a token, and then the next agent must re-embed that token back into the hidden space. This decode-then-re-encode cycle is expensive. The paper proves formally that text-based recursive MAS has higher runtime complexity than latent-space-based RecursiveMAS. The reason is that RecursiveMAS replaces the vocabulary projection, whose per-step cost scales with $d_h \cdot |V|$, with a small linear transformation over the hidden dimension, whose cost scales with $d_h^2$; since the hidden dimension $d_h$ is much smaller than the vocabulary size $|V|$, each latent step is cheaper.
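To get a feel for the gap, here is a back-of-the-envelope count of multiply-adds for the two operations that differ between the paths. The dimensions are illustrative assumptions (a 4096-dimensional hidden state and a 128,000-token vocabulary, roughly the scale of current 7B to 8B models), not figures from the paper, and the count ignores the Transformer layers that both paths share. The helper name `per_step_multiply_adds` is hypothetical, and this is a standalone sketch rather than part of recursive_mas.py.

```python
def per_step_multiply_adds(hidden_dim: int, vocab_size: int) -> dict:
    """Rough multiply-add counts for one step at the agent boundary.

    Counts only the two matrix products discussed in the text; it ignores
    biases, sampling, the embedding lookup, and the shared Transformer layers.
    """
    # Text path: score the hidden state against every vocabulary entry
    vocab_projection = hidden_dim * vocab_size
    # Latent path: inner RecursiveLink = two (d_h x d_h) matrix-vector products
    recursive_link = 2 * hidden_dim * hidden_dim
    return {
        "vocab_projection": vocab_projection,
        "recursive_link": recursive_link,
        "ratio": vocab_projection / recursive_link,
    }


costs = per_step_multiply_adds(hidden_dim=4096, vocab_size=128_000)
print(costs)  # {'vocab_projection': 524288000, 'recursive_link': 33554432, 'ratio': 15.625}
```

The ratio simplifies to |V| / (2 d_h), about 15.6x here. Because the shared Transformer layers still make up most of each step's cost, this ratio describes only one component of the work, not the end-to-end speedup.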

The second problem is gradient vanishing during training. When you try to train a text-based multi-agent system end-to-end, the gradients have to flow backward through the discrete token sampling operation. Discrete sampling is not differentiable, so in practice the gradient must pass through the softmax distribution over the vocabulary. The paper proves that when tokens are generated with high confidence (which is exactly what you want from a good model), the softmax distribution becomes very peaked, its covariance matrix becomes nearly singular, and the gradient norm collapses toward zero. The RecursiveLink, by contrast, maintains gradient norms that are bounded away from zero by a quantity that depends on the hidden dimension, not on token confidence. This means training actually works.
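The collapse is easy to see numerically. For logits $z$ and probabilities $p = \mathrm{softmax}(z)$, the Jacobian $\partial p / \partial z = \mathrm{diag}(p) - p p^\top$ is exactly the covariance matrix of a one-hot sample drawn from $p$, the matrix the argument above refers to. The sketch below is a standalone illustration in pure Python, not the paper's proof: it gives one token an increasing logit margin over 99 others (an arbitrary toy vocabulary) and prints the Frobenius norm of that Jacobian. Once the top token dominates, the norm shrinks roughly in proportion to $1 - p_{\text{top}}$.

```python
import math


def softmax(logits):
    # Subtract the max logit for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_jacobian_norm(logits):
    """Frobenius norm of dp/dz = diag(p) - p p^T (the categorical covariance)."""
    p = softmax(logits)
    sq = 0.0
    for i in range(len(p)):
        for j in range(len(p)):
            entry = (p[i] if i == j else 0.0) - p[i] * p[j]
            sq += entry * entry
    return math.sqrt(sq)


# One favored token among 100; a larger margin means a more confident model
for margin in (5.0, 10.0, 15.0, 20.0):
    logits = [margin] + [0.0] * 99
    p_top = softmax(logits)[0]
    norm = softmax_jacobian_norm(logits)
    print(f"margin={margin:4.1f}  p_top={p_top:.7f}  jacobian_norm={norm:.2e}")
```

Any gradient that reaches the logits through a peaked softmax is multiplied by this shrinking matrix, which is the intuition behind the vanishing-gradient result.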

These two theoretical results — the complexity proposition and the gradient stability theorem — are not just academic decoration. They are the mathematical justification for why the whole system is designed the way it is.


CHAPTER TWO: THE ARCHITECTURE OF RECURSIVEMAS

Now let us get into the actual machinery. RecursiveMAS has three main components: the RecursiveLink module (which comes in two flavors, inner and outer), the latent thoughts generation process, and the recursive loop that chains everything together.

The RecursiveLink: The Heart of the System

The RecursiveLink is a small, lightweight neural module, just two linear layers (three in the outer variant) with a residual connection and a GELU activation, that serves as the bridge between agents. It is the only part of the system that gets trained. All the large LLM agent parameters are frozen. This is a crucial design choice: you do not need to retrain your expensive 7B or 13B parameter models. You just train a tiny adapter that knows how to translate between their hidden spaces.

There are two variants of the RecursiveLink, and understanding the difference between them is key to understanding the whole architecture.

The Inner RecursiveLink operates within a single agent. Its job is to take the last-layer hidden state that the agent produces at one autoregressive step and transform it into an input embedding for the next step, so that the agent can continue reasoning in the latent space without ever decoding to text. The formula is:

$$R_{\text{inner}}(h) = h + W_2 \cdot \sigma(W_1 \cdot h)$$

where $h$ is the current last-layer hidden state vector, $W_1$ and $W_2$ are learned linear layers, $\sigma$ is the GELU activation function, and the addition of $h$ on the left is the residual connection. The residual connection is not just a nice-to-have; it is architecturally important. By adding the original $h$ back, the module is forced to learn only the distributional shift (the difference between where the hidden state lives and where the input embedding space lives), rather than learning the entire transformation from scratch. This makes training more stable and faster to converge.

The Outer RecursiveLink operates between agents. Its job is to take the hidden states produced by one agent and transform them into input embeddings for a different agent, which may have a completely different hidden dimension. This is the "heterogeneous" part of the framework — you can connect a 1.7B parameter model to a 7B parameter model, and the outer link handles the dimensional mismatch. The formula adds one more linear layer:

$$R_{\text{outer}}(h) = W_3 \cdot h + W_2 \cdot \sigma(W_1 \cdot h)$$

The difference from the inner link is the $W_3 \cdot h$ term in the residual branch. In the inner link, the residual is just $h$ (the identity), because the source and target spaces have the same dimension. In the outer link, $W_3$ is a linear projection that maps from the source agent's hidden dimension to the target agent's hidden dimension, so the residual branch also performs the dimensional alignment. The nonlinear branch $W_2 \cdot \sigma(W_1 \cdot h)$ then learns the fine-grained distributional correction on top of that linear alignment.
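Before the PyTorch version, the two formulas can be checked directly in NumPy on toy dimensions. This standalone sketch (the dimensions 8 and 12 and the helper names `gelu`, `r_inner`, and `r_outer` are illustrative) confirms two properties described above: with $W_2 = 0$ the inner link is exactly the identity, and the outer link then reduces to the linear alignment $W_3 \cdot h$ while changing the vector's dimension. Biases are omitted to match the formulas, and it uses the tanh approximation of GELU for brevity, whereas PyTorch's `F.gelu` uses the exact erf form by default.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def r_inner(h, W1, W2):
    # R_inner(h) = h + W2 * sigma(W1 * h); maps (d,) -> (d,)
    return h + W2 @ gelu(W1 @ h)


def r_outer(h, W1, W2, W3):
    # R_outer(h) = W3 * h + W2 * sigma(W1 * h); maps (d_src,) -> (d_tgt,)
    return W3 @ h + W2 @ gelu(W1 @ h)


rng = np.random.default_rng(0)
d_src, d_tgt = 8, 12                      # toy sizes, far smaller than a real LLM
h = rng.normal(size=d_src)
W1 = 0.1 * rng.normal(size=(d_src, d_src))

# With W2 = 0 the nonlinear branch vanishes and R_inner is the identity
print(np.allclose(r_inner(h, W1, np.zeros((d_src, d_src))), h))  # True

# R_outer then reduces to W3 @ h, now living in the target dimension
W3 = rng.normal(size=(d_tgt, d_src))
out = r_outer(h, W1, np.zeros((d_tgt, d_src)), W3)
print(out.shape, np.allclose(out, W3 @ h))  # (12,) True
```

The weight shapes follow the `(out_features, in_features)` convention of `nn.Linear`, so they line up with the PyTorch modules below.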

Let us look at how this translates into code. The following implementation works with PyTorch and is designed to be clean, readable, and easy to extend:

import torch
import torch.nn as nn
import torch.nn.functional as F


class InnerRecursiveLink(nn.Module):
    """
    The Inner RecursiveLink operates within a single LLM agent.

    It transforms the agent's last-layer hidden state at step t into
    an input embedding for step t+1, enabling the agent to reason
    in the continuous latent space without decoding to text.

    The residual connection is critical: it forces the module to learn
    only the distributional shift, not the full transformation.
    This leads to more stable gradients and faster convergence.

    Architecture:
        R_inner(h) = h + W2 * GELU(W1 * h)

    At initialization, W2 is set to all zeros so the module starts as
    a pure identity transformation (output == input). Training then
    learns the residual correction on top of this stable baseline.

    Args:
        hidden_dim: The hidden dimension of the LLM agent this link
                    is paired with. Must match the model's d_model.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim

        # W1: first linear layer, maps hidden_dim -> hidden_dim
        self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=True)

        # W2: second linear layer, maps hidden_dim -> hidden_dim
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize weights so the module starts as a near-identity
        transformation.

        W1 is initialized with small random values (gain=0.1) so
        GELU(W1*h) produces small activations initially.
        W2 is initialized to zero so the entire nonlinear branch
        outputs zero at the start, making R_inner(h) = h + 0 = h.
        This identity-at-init property ensures stable early training.
        """
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Transform a last-layer hidden state into the next input embedding.

        Args:
            h: Hidden state tensor.
               Shape: (batch_size, hidden_dim) for a single step, or
                      (batch_size, seq_len, hidden_dim) for a sequence.

        Returns:
            Transformed embedding of the same shape as h, ready to be
            used as input to the next autoregressive step.
        """
        # Nonlinear branch learns the distributional correction
        correction = self.W2(F.gelu(self.W1(h)))
        # Residual connection preserves the original latent semantics
        return h + correction


class OuterRecursiveLink(nn.Module):
    """
    The Outer RecursiveLink bridges two heterogeneous LLM agents.

    It transforms the last-layer hidden states of a source agent into
    input embeddings aligned with the target agent's embedding space.
    This enables seamless cross-agent latent state transfer even when
    the two agents have different hidden dimensions (e.g., a 1.7B model
    talking to a 7B model).

    Architecture:
        R_outer(h) = W3 * h + W2 * GELU(W1 * h)

    The W3 term in the residual branch handles the dimensional alignment
    (a learned linear projection from source_dim to target_dim), while
    the nonlinear branch learns fine-grained distributional correction.

    Args:
        source_dim: Hidden dimension of the source (sending) agent.
        target_dim: Hidden dimension of the target (receiving) agent.
    """

    def __init__(self, source_dim: int, target_dim: int):
        super().__init__()
        self.source_dim = source_dim
        self.target_dim = target_dim

        # W1: projects within source space before the nonlinear activation
        self.W1 = nn.Linear(source_dim, source_dim, bias=True)

        # W2: projects from source space to target space (nonlinear branch)
        self.W2 = nn.Linear(source_dim, target_dim, bias=True)

        # W3: the residual projection that handles dimensional alignment.
        # This is the key structural difference from the inner link.
        # No bias: the bias in W2 already handles the offset.
        self.W3 = nn.Linear(source_dim, target_dim, bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """
        Initialize so that W3 provides a reasonable linear baseline
        (Xavier uniform) and the nonlinear branch starts near zero,
        similar to the inner link's identity-at-init strategy.
        """
        nn.init.xavier_uniform_(self.W3.weight, gain=1.0)
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Transform source agent's hidden states into target agent's
        input embedding space.

        Args:
            h: Hidden state tensor from the source agent.
               Shape: (batch_size, seq_len, source_dim) or
                      (batch_size, source_dim).

        Returns:
            Transformed embedding aligned with the target agent's space.
            Shape: (batch_size, seq_len, target_dim) or
                   (batch_size, target_dim), matching the input rank.
        """
        # Linear residual branch: handles dimensional alignment
        linear_branch = self.W3(h)
        # Nonlinear branch: learns fine-grained distributional correction
        nonlinear_branch = self.W2(F.gelu(self.W1(h)))
        return linear_branch + nonlinear_branch

Notice how clean and small this is. The entire RecursiveLink — the module that makes the whole system work — is fewer than 120 lines of code including comments. This is one of the most elegant aspects of the paper: the key innovation is architecturally tiny, even though its effects are large.
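To make "architecturally tiny" concrete, you can count the trainable parameters of one outer link directly from the three nn.Linear layers in OuterRecursiveLink. The helpers below are a hypothetical illustration, not part of the paper's code. They use the standard nn.Linear count: in_features × out_features weights, plus out_features biases when bias=True.

```python
def linear_params(in_features: int, out_features: int, bias: bool) -> int:
    """Parameter count of an nn.Linear layer with the given shape."""
    return in_features * out_features + (out_features if bias else 0)


def outer_link_param_count(source_dim: int, target_dim: int) -> int:
    """Trainable parameters in OuterRecursiveLink(source_dim, target_dim)."""
    w1 = linear_params(source_dim, source_dim, bias=True)   # W1: source -> source
    w2 = linear_params(source_dim, target_dim, bias=True)   # W2: source -> target
    w3 = linear_params(source_dim, target_dim, bias=False)  # W3: residual projection
    return w1 + w2 + w3


# Tiny dimensions keep the arithmetic easy to check by hand:
# W1 = 4*4 + 4 = 20, W2 = 4*6 + 6 = 30, W3 = 4*6 = 24, total 74.
print(outer_link_param_count(4, 6))  # 74
```

The count grows as source_dim² + 2·source_dim·target_dim plus the biases. It is quadratic in the hidden sizes but does not depend on how many Transformer layers the agents have, which is why it stays small compared with the agents themselves.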

Latent Thoughts Generation: Thinking Without Words

The next concept to understand is latent thoughts generation, which is the process by which an agent reasons in the continuous latent space rather than producing text.

In normal autoregressive generation, an LLM takes an input sequence of tokens, embeds them into vectors, runs them through the Transformer, and at each step produces a probability distribution over the vocabulary from which it samples the next token. The next token is then embedded and fed back in as input, and the cycle continues.

In latent thoughts generation, the process is different. After the Transformer produces the last-layer hidden state $h_t$ at step $t$, we do not project $h_t$ to the vocabulary and sample a token. Instead, we pass $h_t$ through the Inner RecursiveLink to get $e_{t+1} = R_{\text{inner}}(h_t)$. This embedding $e_{t+1}$ is then fed directly as the input embedding for step $t+1$, so token sampling is bypassed entirely. The agent runs for $m$ steps this way and produces a sequence of $m$ hidden states $H = [h_t, h_{t+1}, \ldots, h_{t+m-1}]$. These are its "latent thoughts": its reasoning, encoded as vectors rather than words.

The paper shows that $m = 80$ is a sweet spot: performance improves steadily as $m$ increases from 0 to about 80, then plateaus. This is a practically useful finding. It means you do not need to run hundreds of latent steps to get the benefit; a moderate budget of 80 steps is enough for effective collaboration.
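The feedback loop itself is easy to see in a toy form. The sketch below is plain Python with hypothetical stand-ins. mean_model plays the role of the Transformer: it returns the average of its context as the "last-layer hidden state". double_link plays the role of R_inner. Neither resembles a real model. The sketch only shows the control flow, in which each hidden state becomes the next input embedding instead of being turned into a sampled token.

```python
from typing import Callable, List

Vector = List[float]


def toy_latent_loop(
    model: Callable[[List[Vector]], Vector],
    inner_link: Callable[[Vector], Vector],
    context: List[Vector],
    m: int,
) -> List[Vector]:
    """Run m latent steps: hidden state -> inner link -> next input embedding."""
    context = list(context)  # copy so the caller's context is not mutated
    thoughts: List[Vector] = []
    for _ in range(m):
        h_t = model(context)             # "last-layer hidden state" at this step
        thoughts.append(h_t)             # keep h_t as a latent thought
        context.append(inner_link(h_t))  # e_{t+1} = R_inner(h_t); no token sampled
    return thoughts


# Hypothetical stand-ins, chosen only so the numbers are easy to follow.
def mean_model(ctx: List[Vector]) -> Vector:
    n = len(ctx)
    return [sum(v[i] for v in ctx) / n for i in range(len(ctx[0]))]


def double_link(h: Vector) -> Vector:
    return [2.0 * x for x in h]


start = [[1.0, 0.0]]
thoughts = toy_latent_loop(mean_model, double_link, start, m=3)
print(thoughts)  # [[1.0, 0.0], [1.5, 0.0], [2.0, 0.0]]
```

Each call to the stand-in model sees one more vector than the previous call. This mirrors how the real context grows by one embedding per latent step.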

The following function illustrates how latent thoughts generation works. Pay attention to how gradient tracking is handled. The model's parameters are frozen (requires_grad=False), so they never receive gradients. The inner link's parameters, by contrast, must receive gradients during training. Every inner link output is fed back into the frozen model as the next input embedding, so the only gradient path from a downstream loss to the link's weights runs through the frozen model's activations. Wrapping the model forward pass in torch.no_grad() would cut that path at every step. The function therefore enables autograd for the whole loop when training=True and disables it entirely at inference.

from typing import Optional


def generate_latent_thoughts(
    model: nn.Module,
    input_embeddings: torch.Tensor,
    inner_link: InnerRecursiveLink,
    num_latent_steps: int = 80,
    prior_latent_state: Optional[torch.Tensor] = None,
    training: bool = False,
) -> torch.Tensor:
    """
    Run an agent in latent-thoughts generation mode.

    Instead of decoding tokens at each step, we use the Inner RecursiveLink
    to feed the last-layer hidden state back as the next input embedding.
    This allows the agent to reason in continuous latent space for
    `num_latent_steps` steps before any text is produced.

    IMPORTANT (gradient flow design):
        The model's parameters are frozen, so they never accumulate
        gradients. When training=True, the frozen forward pass still runs
        with autograd enabled. Each inner_link output becomes part of the
        model's input at the next step, so gradients can only reach the
        link weights through the frozen model's activations. Running the
        model under torch.no_grad() would sever that path. The price is
        activation memory that grows with num_latent_steps.
        When training=False (inference), everything runs under no_grad.

    Args:
        model: A Hugging Face causal LM. Its parameters must be frozen
               before calling this function.
        input_embeddings: The embedded input context for this agent.
                          Shape: (batch_size, context_len, hidden_dim)
        inner_link: The InnerRecursiveLink for this agent.
        num_latent_steps: How many latent reasoning steps to perform.
                          The paper finds m=80 is a good default.
        prior_latent_state: Optional latent thoughts from a previous
                            recursion round, already projected into this
                            agent's embedding space by an outer link.
                            Shape: (batch_size, prior_len, hidden_dim).
                            Prepended to the input context so the agent
                            can condition on previous-round information.
        training: Set True during outer-loop training so gradients flow
                  through the inner_link (and into prior_latent_state, if
                  it came from a trainable outer link). Set False for
                  inference.

    Returns:
        latent_thoughts: The sequence of last-layer hidden states produced
                         during latent generation. Attached to the autograd
                         graph when training=True.
                         Shape: (batch_size, num_latent_steps, hidden_dim)
    """
    # If we have latent state from a previous recursion round,
    # prepend it to the input context so the agent conditions on it.
    if prior_latent_state is not None:
        current_embeddings = torch.cat(
            [prior_latent_state, input_embeddings], dim=1
        )
    else:
        current_embeddings = input_embeddings

    latent_thoughts = []

    # Record the autograd graph only when training; inference skips it.
    with torch.set_grad_enabled(training):
        for _ in range(num_latent_steps):
            # Frozen model forward pass. Its parameters get no gradients,
            # but when training, autograd records the pass so gradients
            # can flow back to earlier inner-link outputs in the context.
            outputs = model(
                inputs_embeds=current_embeddings,
                output_hidden_states=True,
                use_cache=False,
            )
            # Last-layer hidden state at the final sequence position.
            # Shape: (batch_size, hidden_dim)
            last_hidden = outputs.hidden_states[-1][:, -1, :]

            # Trainable inner link: e_{t+1} = R_inner(h_t)
            next_embedding = inner_link(last_hidden)

            # Store h_t as a latent thought. When training, it stays
            # attached to the graph, so a loss computed downstream (e.g.,
            # by the next agent) can reach this agent's links through it.
            latent_thoughts.append(last_hidden.unsqueeze(1))

            # Append the new embedding to the running context.
            # The context grows by one embedding per step.
            current_embeddings = torch.cat(
                [current_embeddings, next_embedding.unsqueeze(1)], dim=1
            )

    # Stack all latent thoughts into a single tensor.
    # Shape: (batch_size, num_latent_steps, hidden_dim)
    return torch.cat(latent_thoughts, dim=1)

The key thing to notice here is the loop structure. At each step, we run the full Transformer forward pass, grab the last-layer hidden state, pass it through the inner link, and append the resulting embedding to the context. The context window therefore grows by one embedding per step. The frozen model never receives gradients. During training, however, the whole loop runs with autograd enabled so that gradients can reach the inner link through the model; at inference, the loop runs under torch.no_grad(). After $m$ steps, we have a sequence of $m$ hidden state vectors that encode the agent's latent reasoning about the problem.
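The loop calls the model with use_cache=False, so every step re-runs the Transformer over the entire context, and that context grows by one embedding per step. A small helper makes the total work explicit. It is hypothetical and meant only for back-of-the-envelope estimates. Step k (counting from 0) processes L0 + k positions, so m steps process m·L0 + m(m − 1)/2 positions in total.

```python
def positions_processed(context_len: int, num_latent_steps: int) -> int:
    """Sum of sequence lengths fed to the model across the latent loop.

    Step k (0-based) sees the original context plus k appended
    embeddings, because each step re-runs the full forward pass.
    """
    return sum(context_len + k for k in range(num_latent_steps))


def positions_processed_closed_form(context_len: int, num_latent_steps: int) -> int:
    """Same quantity via m * L0 + m * (m - 1) / 2."""
    m = num_latent_steps
    return m * context_len + m * (m - 1) // 2


# Example: a 200-position context with the default m = 80.
print(positions_processed(200, 80))              # 19160
print(positions_processed_closed_form(200, 80))  # 19160
```

Attention cost per position also grows with sequence length, so this count understates the compute. Reusing a KV cache across steps is the usual way to avoid reprocessing the prefix, but the code above does not do that.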

This is conceptually similar to how Chain-of-Thought prompting works: you give the model space to reason before it commits to an answer. The difference is in what it generates. A text token collapses the model's state to a single vocabulary item that must be sampled. Here, the model instead generates continuous vectors that skip sampling and can carry richer information from one step to the next.

Chaining Agents into a Loop

Now we have the two building blocks: the RecursiveLink modules and the latent thoughts generation process. The third piece is how these are combined to form the recursive loop.

The process for a single recursion round goes like this. Agent A₁ receives the input question (as embeddings) and, if this is not the first round, the latent state from the previous round. It runs latent thoughts generation for $m$ steps, producing $H_{A_1}$. These latent thoughts are then passed through the Outer RecursiveLink to transform them into the embedding space of Agent A₂. Agent A₂ receives both its own input context embeddings and the transformed latent thoughts from A₁, concatenated together. Agent A₂ then runs its own latent thoughts generation, producing $H_{A_2}$. This continues through all N agents.

After the last agent Aₙ completes latent thoughts generation, its latent outputs $H_{A_N}$ are passed back to the first agent A₁ through another Outer RecursiveLink, closing the loop. This is the "recursive" part: the system's latent answer from round $r$ becomes additional context for round $r+1$. Each new round can condition on what the system collectively produced in all previous rounds, enabling iterative refinement.

Only after the final recursion round does any text get produced. The last agent Aₙ decodes its latent thoughts into a textual answer using the standard vocabulary projection. All intermediate rounds are entirely in the latent space.
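You can trace the order in which agents run, and which latent state each one conditions on, without any model at all. The plain-Python sketch below is a hypothetical helper, not taken from the paper. It follows the sequential rule described above: outer link i carries agent i's latent thoughts to agent (i + 1) mod N. Each agent therefore sees its predecessor's output from the same round, and A₁ sees Aₙ's output from the previous round.

```python
from typing import List, Optional, Tuple


def schedule(num_agents: int, num_rounds: int) -> List[Tuple[int, int, Optional[str]]]:
    """Return (round, agent, feedback_source) for every agent run (1-based labels)."""
    feedback: List[Optional[str]] = [None] * num_agents  # slot i feeds agent i
    events = []
    for r in range(1, num_rounds + 1):
        for i in range(num_agents):
            events.append((r, i + 1, feedback[i]))
            # Outer link i sends agent i's latent thoughts to agent (i + 1) % N.
            feedback[(i + 1) % num_agents] = f"H_A{i + 1} (round {r})"
    return events


for event in schedule(num_agents=2, num_rounds=2):
    print(event)
# (1, 1, None)
# (1, 2, 'H_A1 (round 1)')
# (2, 1, 'H_A2 (round 1)')
# (2, 2, 'H_A1 (round 2)')
```

For two agents and two rounds, this reproduces the information flow in the diagram that follows.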

The following diagram represents the information flow for a two-agent system over two recursion rounds:

Round 1:
=========
Question (text) --> [Embed] --> E_A1
                                  |
                                  v
                          [Agent A1 + Inner Link]
                          generates H_A1 (latent)
                                  |
                          [Outer Link A1->A2]
                                  |
                                  v
                E_A2 + R_outer(H_A1) --> [Agent A2 + Inner Link]
                                         generates H_A2 (latent)
                                                |
                                        [Outer Link A2->A1]
                                                |
                          +---------------------+
                          |
                          v
Round 2:
=========
Question (text) --> [Embed] --> E_A1
R_outer(H_A2 from Round 1) ------+
                                  |
                                  v
                          [Agent A1 + Inner Link]
                          generates H_A1' (latent)
                                  |
                          [Outer Link A1->A2]
                                  |
                                  v
                E_A2 + R_outer(H_A1') --> [Agent A2 + Inner Link]
                                          generates H_A2' (latent)
                                                |
                                        [Decode to text]
                                                |
                                                v
                                        FINAL ANSWER

This is a beautiful structure. The question embeddings are always fed in fresh at the start of each round, so the agents never lose sight of what they are trying to answer. But they also receive the accumulated latent wisdom of all previous rounds, allowing them to iteratively refine their reasoning.


CHAPTER THREE: TRAINING THE SYSTEM

One of the most practically important aspects of RecursiveMAS is how it is trained. The answer is elegant: you do not train the large LLM agents at all. You freeze all their parameters and only train the tiny RecursiveLink modules. This means the training cost is dramatically lower than fine-tuning the agents themselves.

The paper reports that RecursiveMAS uses only 13.12 million trainable parameters (0.31% of the total parameter count), compared to 15.29 million for LoRA and 4.21 billion for full supervised fine-tuning. Despite having fewer trainable parameters than LoRA, RecursiveMAS achieves higher accuracy (74.9% average vs. 66.9% for LoRA and 68.6% for full SFT) and uses less GPU memory (15.29 GB vs. 21.67 GB for LoRA and 41.40 GB for full SFT). This is a remarkable result that demonstrates the power of training the right thing rather than training everything.

The training happens in two stages, which the paper calls the Inner-Outer Loop training paradigm.

Stage One: Inner-Loop Training

The inner loop trains each agent's Inner RecursiveLink independently. The goal is to warm-start each agent so that it can generate useful latent thoughts — so that the hidden states it produces in latent mode are semantically meaningful and aligned with what the model would produce if it were generating text.

The training objective for the inner loop is a cosine similarity loss. For each training example $(x, y)$, you run the agent in latent mode to get an initial hidden state $H_0$ (the last-layer hidden state after processing the input context). You also take the ground-truth answer $y$, pass it through the model's standard input embedding layer to get $\text{Emb}(y)$, and compute the cosine similarity between $R_{\text{inner}}(H_0)$ and the mean of $\text{Emb}(y)$. The loss is:

$$L_{\text{inner}} = 1 - \cos\!\left(R_{\text{inner}}(H_0),\; \overline{\text{Emb}(y)}\right)$$

where $\overline{\text{Emb}(y)}$ denotes the mean of the answer token embeddings, representing the semantic "direction" the latent thoughts should point toward. Minimizing this loss encourages the inner link to transform the agent's latent thoughts so that they point in the same direction as the embeddings of the correct answer. The supervision comes from the answer, but the target lives entirely in the model's own input embedding space. In effect, you are teaching the agent to think in a way that is consistent with its own text generation capabilities.
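You can check the loss by hand on toy vectors. The sketch below recomputes L_inner in plain Python for a single example. It averages the answer-token embeddings, then takes one minus the cosine similarity with the transformed latent. The two-dimensional vectors are made-up values chosen so the results are exact. eps is a small guard against division by zero, similar in purpose to the eps argument of F.cosine_similarity.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def mean_embedding(token_embeddings: List[Vector]) -> List[float]:
    """Average of the answer-token embeddings (the overline Emb(y) term)."""
    n = len(token_embeddings)
    dim = len(token_embeddings[0])
    return [sum(e[i] for e in token_embeddings) / n for i in range(dim)]


def cosine_similarity(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def inner_loss(transformed_latent: Vector, answer_embeddings: List[Vector]) -> float:
    """L_inner = 1 - cos(R_inner(H_0), mean(Emb(y)))."""
    return 1.0 - cosine_similarity(transformed_latent, mean_embedding(answer_embeddings))


answer = [[1.0, 0.0], [3.0, 0.0]]      # mean direction is [2.0, 0.0]
print(inner_loss([5.0, 0.0], answer))   # 0.0 (same direction)
print(inner_loss([0.0, 4.0], answer))   # 1.0 (orthogonal)
print(inner_loss([-1.0, 0.0], answer))  # 2.0 (opposite direction)
```

Only the direction matters. Scaling transformed_latent by any positive factor leaves the loss unchanged, which is why the target is a mean direction rather than a specific vector magnitude.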

The following code shows how to compute this inner-loop training loss and run the training stage:

from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import Tuple


def compute_inner_loop_loss(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    input_ids: torch.Tensor,
    target_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the inner-loop training loss for one agent.

    The loss encourages the inner link to produce latent thoughts
    that are semantically aligned with the ground-truth answer
    embeddings, measured by cosine similarity.

    Gradient flow:
        - model parameters: frozen, no gradients computed.
        - inner_link parameters: receive gradients through the
          cosine similarity loss. The key is that inner_link(H_0)
          is called OUTSIDE torch.no_grad(), so PyTorch builds a
          computation graph through W1 and W2 of the inner link.
          H_0 itself is a detached tensor (produced under no_grad),
          which is correct since we do not want gradients to flow
          into the frozen model.

    Args:
        model: The frozen LLM agent.
        inner_link: The Inner RecursiveLink to train.
        input_ids: Tokenized input question. Shape: (batch, seq_len).
        target_ids: Tokenized ground-truth answer. Shape: (batch, ans_len).

    Returns:
        loss: Scalar tensor. Value is in [0, 2]; minimizing it aligns
              the inner link's output with the answer embedding direction.
    """
    embedding_layer = model.get_input_embeddings()

    # Embed the input question context (no grad needed here)
    with torch.no_grad():
        # Shape: (batch, seq_len, hidden_dim)
        input_embeddings = embedding_layer(input_ids)

        # Embed the ground-truth answer — this is our alignment target.
        # Shape: (batch, ans_len, hidden_dim)
        target_embeddings = embedding_layer(target_ids)

        # Run the frozen model on the input context to get the initial
        # last-layer hidden state H_0.
        initial_output = model(
            inputs_embeds=input_embeddings,
            output_hidden_states=True,
            use_cache=False,
        )
        # Shape: (batch, hidden_dim)
        H_0 = initial_output.hidden_states[-1][:, -1, :]

    # Apply the inner link WITH gradient tracking.
    # H_0 is detached (produced under no_grad), but gradients still
    # flow through inner_link's own weight matrices W1 and W2.
    # Shape: (batch, hidden_dim)
    transformed_latent = inner_link(H_0)

    # Target direction: mean of the ground-truth answer embeddings.
    # This is a detached tensor (produced under no_grad above).
    # Shape: (batch, hidden_dim)
    target_direction = target_embeddings.mean(dim=1)

    # Cosine similarity: values in [-1, 1]. We want to maximize it,
    # so we minimize 1 - similarity.
    similarity = F.cosine_similarity(
        transformed_latent, target_direction, dim=-1
    )

    # Average over the batch
    loss = (1.0 - similarity).mean()
    return loss


def train_inner_loop(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    device: str = "cuda",
) -> None:
    """
    Run the inner-loop training stage for one agent.

    Only the inner_link parameters are updated. The model is frozen.
    Call this independently for each agent before running outer-loop
    training.

    Args:
        model: The frozen LLM agent. This function sets requires_grad=False
               on all of its parameters if that has not been done already.
        inner_link: The Inner RecursiveLink to train.
        dataloader: DataLoader yielding (input_ids, target_ids) batches.
        num_epochs: Number of training epochs.
        learning_rate: Initial learning rate for AdamW.
        device: Device each batch is moved to before the forward pass
                (should match the device of model and inner_link).
    """
    # Freeze the model (a no-op if it is already frozen)
    for param in model.parameters():
        param.requires_grad = False

    # Only optimize the inner link parameters
    optimizer = AdamW(inner_link.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )

    inner_link.train()
    model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            # Move the batch to the training device
            input_ids = input_ids.to(device)
            target_ids = target_ids.to(device)

            optimizer.zero_grad()

            loss = compute_inner_loop_loss(
                model=model,
                inner_link=inner_link,
                input_ids=input_ids,
                target_ids=target_ids,
            )

            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(
                inner_link.parameters(), max_norm=1.0
            )

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Inner Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )

The inner-loop training is done independently for each agent. You train Agent A₁'s inner link, then Agent A₂'s inner link, and so on. Each agent learns to generate latent thoughts that are semantically coherent on its own, before the agents start collaborating.
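train_inner_loop calls scheduler.step() once per batch with T_max = num_epochs * len(dataloader). The learning rate therefore traces a single half-cosine from learning_rate down to zero over the whole run. The sketch below evaluates the closed form given in the PyTorch documentation for CosineAnnealingLR (eta_min defaults to 0). The 3-epoch, 100-batch run is a hypothetical example.

```python
import math


def cosine_annealing_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Learning rate after `step` scheduler steps (CosineAnnealingLR closed form)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


num_epochs, batches_per_epoch = 3, 100   # hypothetical run
t_max = num_epochs * batches_per_epoch   # 300, as computed in train_inner_loop
for step in (0, 150, 300):
    print(step, cosine_annealing_lr(step, base_lr=1e-4, t_max=t_max))
```

This prints the full rate at step 0, half of it midway through training, and (up to floating-point rounding) zero at the final step.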

Stage Two: Outer-Loop Training

After the inner loop has warmed up each agent's latent thoughts generation capability, the outer loop trains the Outer RecursiveLink modules that connect agents to each other. This is where the real magic happens: the entire system is optimized end-to-end as a unified entity.

The outer-loop training objective is a standard cross-entropy loss on the final textual output:

$$L_{\text{outer}} = \text{CrossEntropy}\!\left(S^{(n)}\!\left(S^{(n-1)}\!\left(\cdots S^{(1)}(x)\cdots\right)\right),\; y\right)$$

where $S^{(r)}(\cdot)$ denotes the system state after recursion round $r$. The computation graph is preserved through all recursion rounds, so gradients can flow backward from the final prediction all the way through every outer link in every round. Each outer link receives a gradient signal that reflects its global contribution to the final answer, not just its local behavior.
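Concretely, the CrossEntropy term is the usual token-level loss built in train_outer_loop. nn.CrossEntropyLoss(ignore_index=-100) with its default mean reduction averages −log softmax(logits)[target] over the positions whose label is not -100. The plain-Python sketch below reproduces that computation on made-up logits. It subtracts the row maximum to keep the log-sum-exp numerically stable.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int], ignore_index: int = -100) -> float:
    """Mean token-level cross-entropy, skipping positions labelled ignore_index."""
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padding / non-answer positions contribute nothing
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[target])  # -log softmax(row)[target]
    if not losses:
        return float("nan")  # every position ignored: nothing to average
    return sum(losses) / len(losses)


# Two positions; the second is ignored, so only the first counts.
print(cross_entropy([[0.0, 0.0], [5.0, 1.0]], [1, -100]))  # log(2), about 0.6931
```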

This is the "shared credit assignment" that the paper refers to. Instead of training each agent in isolation and hoping they work well together, the outer loop trains the connections between agents with full knowledge of how the whole system performs. It is the difference between training individual musicians in isolation and rehearsing the whole orchestra together.
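Scalars are enough to show the effect of keeping the graph across rounds. In the hypothetical toy below, the two "links" are plain multipliers: a (A₁ to A₂) and b (A₂ to A₁). Every round applies both, so after n rounds the output is (ab)ⁿ·x. Because a is reused in every round, its gradient sums one contribution per round: ∂y/∂a = n·aⁿ⁻¹·bⁿ·x. If the earlier rounds were cut from the graph, only the last round's contribution, aⁿ⁻¹·bⁿ·x, would remain. A central finite difference confirms the full-graph value.

```python
from typing import Callable


def unroll(a: float, b: float, x: float, rounds: int) -> float:
    """Apply link a (A1 -> A2) then link b (A2 -> A1) once per round."""
    state = x
    for _ in range(rounds):
        state = a * state  # hypothetical outer link A1 -> A2
        state = b * state  # hypothetical outer link A2 -> A1
    return state


def grad_a_full_graph(a: float, b: float, x: float, rounds: int) -> float:
    # d/da of (a*b)**rounds * x: one term for each round in which `a` is used.
    return rounds * a ** (rounds - 1) * b ** rounds * x


def grad_a_last_round_only(a: float, b: float, x: float, rounds: int) -> float:
    # Earlier rounds treated as constants (graph cut): only the last use counts.
    return a ** (rounds - 1) * b ** rounds * x


def central_difference(f: Callable[[float], float], a: float, h: float = 1e-6) -> float:
    return (f(a + h) - f(a - h)) / (2 * h)


a, b, x, n = 1.1, 0.9, 2.0, 3
numeric = central_difference(lambda a_: unroll(a_, b, x, n), a)
full = grad_a_full_graph(a, b, x, n)
last_only = grad_a_last_round_only(a, b, x, n)
print(numeric, full, last_only)
```

The numeric derivative matches the full-graph value, which is n = 3 times the last-round-only value. Every round in which a link participates contributes to its gradient, and this is the scalar analogue of shared credit assignment.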

The following code implements the full RecursiveMASSystem as a PyTorch nn.Module and the outer-loop training procedure. Note the careful use of nn.ModuleList to properly register the RecursiveLink submodules (so they appear in system.parameters() and system.state_dict()), and the separation of frozen agent models from trainable links:

class RecursiveMASSystem(nn.Module):
    """
    A RecursiveMAS system with N agents in sequential style.

    This class wires together N frozen LLM agents with their
    Inner RecursiveLinks and Outer RecursiveLinks, implementing
    the full recursive loop for training and inference.

    Architecture notes:
        - agent_models: stored as a plain Python list (NOT nn.ModuleList)
          because their parameters are frozen and we deliberately do not
          want them to appear in system.parameters() or system.state_dict().
          This keeps the optimizer and checkpoint focused on the links only.
        - inner_links, outer_links: stored as nn.ModuleList so they ARE
          registered as submodules, appear in system.parameters(), and are
          saved/loaded correctly via state_dict().

    Args:
        agent_models: List of N frozen Hugging Face causal LM models.
                      Parameters must be frozen before passing them here.
        inner_links: List of N InnerRecursiveLink modules, one per agent.
        outer_links: List of N OuterRecursiveLink modules. outer_links[i]
                     connects agent i to agent (i+1) % N, so the last
                     element connects the final agent back to the first,
                     closing the recursive loop.
        num_latent_steps: Number of latent reasoning steps per agent
                          per recursion round (default 80, per the paper).
    """

    def __init__(
        self,
        agent_models: list,
        inner_links: list,
        outer_links: list,
        num_latent_steps: int = 80,
    ):
        super().__init__()

        if len(agent_models) < 2:
            raise ValueError(
                "RecursiveMASSystem requires at least 2 agents."
            )
        if len(inner_links) != len(agent_models):
            raise ValueError(
                "Must provide exactly one inner_link per agent."
            )
        if len(outer_links) != len(agent_models):
            raise ValueError(
                "Must provide exactly one outer_link per agent "
                "(outer_links[i] connects agent i to agent (i+1) % N)."
            )

        self.num_agents = len(agent_models)
        self.num_latent_steps = num_latent_steps

        # Store frozen agent models as a plain list.
        # They are NOT registered as nn.Module submodules intentionally:
        # their parameters are frozen and should not appear in
        # self.parameters() to keep the optimizer clean.
        self._agent_models = agent_models

        # Register RecursiveLinks as proper nn.Module submodules so they
        # appear in self.parameters() and self.state_dict().
        self.inner_links = nn.ModuleList(inner_links)
        self.outer_links = nn.ModuleList(outer_links)

        # Verify all agent models are frozen
        for i, model in enumerate(self._agent_models):
            for param in model.parameters():
                if param.requires_grad:
                    raise ValueError(
                        f"Agent {i} has unfrozen parameters. "
                        f"Freeze all agent parameters before constructing "
                        f"RecursiveMASSystem."
                    )

    def _run_agent_latent(
        self,
        agent_idx: int,
        context_embeddings: torch.Tensor,
        training: bool,
    ) -> torch.Tensor:
        """
        Run one agent in latent-thoughts generation mode.

        The agent's parameters are frozen, so they never receive gradients.
        When training=True, the loop runs with autograd enabled, so the
        final loss can backpropagate through the frozen model's activations
        into two places: the inner link outputs fed back into the context,
        and any outer-link projection already in context_embeddings. When
        training=False, everything runs under torch.no_grad().

        Args:
            agent_idx: Index of the agent to run (0-based).
            context_embeddings: Input context for this agent, already
                                including any transferred latent state
                                from other agents.
                                Shape: (batch, context_len, hidden_dim)
            training: True during outer-loop training; False at inference.

        Returns:
            latent_thoughts: Shape (batch, num_latent_steps, hidden_dim)
        """
        model = self._agent_models[agent_idx]
        inner_link = self.inner_links[agent_idx]

        current_embeddings = context_embeddings
        latent_thoughts = []

        # Build the autograd graph only when training
        with torch.set_grad_enabled(training):
            for _ in range(self.num_latent_steps):
                # Frozen model forward pass: no parameter gradients, but
                # recorded when training so gradients can pass through it
                outputs = model(
                    inputs_embeds=current_embeddings,
                    output_hidden_states=True,
                    use_cache=False,
                )
                # Shape: (batch, hidden_dim)
                last_hidden = outputs.hidden_states[-1][:, -1, :]

                # Trainable inner link produces the next input embedding
                next_emb = inner_link(last_hidden)

                # Store the hidden state as a latent thought
                latent_thoughts.append(last_hidden.unsqueeze(1))

                # Grow the context by one embedding
                current_embeddings = torch.cat(
                    [current_embeddings, next_emb.unsqueeze(1)], dim=1
                )

        # Shape: (batch, num_latent_steps, hidden_dim)
        return torch.cat(latent_thoughts, dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        num_recursion_rounds: int = 3,
    ) -> torch.Tensor:
        """
        Run the full RecursiveMAS forward pass.

        Performs `num_recursion_rounds` rounds of latent collaboration
        among all agents, then decodes the final answer as text logits
        from the last agent in the final round.

        Args:
            input_ids: Tokenized input question.
                       Shape: (batch, seq_len)
            num_recursion_rounds: Number of recursive collaboration rounds.
                                  The paper uses n=3 as the default.

        Returns:
            logits: Output logits from the final agent in the last round.
                    Shape: (batch, final_context_len, vocab_size)
        """
        training = self.training  # True if system.train() was called

        # Pre-embed the input question for each agent.
        # Each agent may have a different hidden dimension, so we embed
        # separately using each agent's own embedding layer. Because the
        # same input_ids go to every embedding layer, this assumes all
        # agents share one tokenizer.
        input_embeddings = []
        for model in self._agent_models:
            with torch.no_grad():
                emb = model.get_input_embeddings()(input_ids)
            input_embeddings.append(emb)

        # feedback_latent[i] holds the outer-link-projected latent state
        # that agent i conditions on the next time it runs. Agent i > 0
        # receives it from agent i - 1 in the same round; agent 0 receives
        # it from the last agent in the previous round.
        # Initialized to None (no feedback before the first agent runs).
        feedback_latent = [None] * self.num_agents

        final_logits = None

        for round_idx in range(num_recursion_rounds):
            for agent_idx in range(self.num_agents):
                # Build context: feedback (if any) followed by the
                # agent's own input
                if feedback_latent[agent_idx] is not None:
                    context = torch.cat(
                        [feedback_latent[agent_idx],
                         input_embeddings[agent_idx]],
                        dim=1,
                    )
                else:
                    context = input_embeddings[agent_idx]

                # Run this agent in latent mode
                latent = self._run_agent_latent(
                    agent_idx=agent_idx,
                    context_embeddings=context,
                    training=training,
                )

                # Project this agent's latent thoughts into the next
                # agent's embedding space via the outer link.
                next_agent_idx = (agent_idx + 1) % self.num_agents
                outer_link = self.outer_links[agent_idx]

                with torch.set_grad_enabled(training):
                    projected = outer_link(latent)

                # Write the projection straight into feedback_latent so the
                # next agent reads it later in this same round (sequential
                # style). The last agent's projection lands in slot 0 and
                # is read by agent 0 at the start of the next round, which
                # closes the recursive loop.
                feedback_latent[next_agent_idx] = projected

            # After the final round, decode text from the last agent.
            # We run one more forward pass on the last agent with its
            # full context (input + feedback it received this round)
            # to get vocabulary logits.
            if round_idx == num_recursion_rounds - 1:
                last_agent_idx = self.num_agents - 1
                last_model = self._agent_models[last_agent_idx]

                # Reconstruct the last agent's full context for decoding
                if feedback_latent[last_agent_idx] is not None:
                    decode_context = torch.cat(
                        [feedback_latent[last_agent_idx],
                         input_embeddings[last_agent_idx]],
                        dim=1,
                    )
                else:
                    decode_context = input_embeddings[last_agent_idx]

                # Autograd is enabled only when training
                with torch.set_grad_enabled(training):
                    decode_output = last_model(
                        inputs_embeds=decode_context,
                        output_hidden_states=False,
                        use_cache=False,
                    )
                final_logits = decode_output.logits

        return final_logits


def train_outer_loop(
    system: RecursiveMASSystem,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    num_recursion_rounds: int = 3,
) -> None:
    """
    Run the outer-loop training stage for the full RecursiveMAS system.

    Only the RecursiveLink parameters (inner_links + outer_links) are
    updated. All LLM agent models remain frozen throughout.

    The computation graph is preserved across all recursion rounds so
    that gradients flow through every outer link and inner link in
    every round — this is the "shared credit assignment" described in
    the paper.

    Args:
        system: The RecursiveMASSystem to train. Must have been
                constructed with frozen agent models.
        dataloader: DataLoader yielding (input_ids, target_ids) batches.
                    target_ids should use -100 for positions that should
                    not contribute to the loss (standard HF convention).
        num_epochs: Number of training epochs.
        learning_rate: Initial learning rate for AdamW.
        num_recursion_rounds: Number of recursive rounds during training.
    """
    # system.parameters() returns only the RecursiveLink parameters
    # because the agent models are stored as a plain list (not ModuleList)
    # and thus not registered as submodules.
    optimizer = AdamW(system.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    system.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()

            # Full forward pass through all recursion rounds.
            # The computation graph is preserved for backpropagation
            # through all RecursiveLink modules in all rounds.
            logits = system(
                input_ids=input_ids,
                num_recursion_rounds=num_recursion_rounds,
            )

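            # logits shape: (batch, context_len, vocab_size)
            # target_ids shape: (batch, seq_len)
            # We align by taking only the last seq_len positions of logits
            # (the positions corresponding to the answer tokens). This
            # assumes target_ids were built already aligned with those
            # positions, i.e. any causal one-token shift was applied when
            # preparing the data.
            seq_len = target_ids.shape[1]
            # Compute the loss in float32: fp16/bf16 logits are prone to
            # precision loss in the softmax.
            logits_for_loss = logits[:, -seq_len:, :].float()
            # Targets must live on the same device as the logits.
            target_ids = target_ids.to(logits_for_loss.device)

            batch_size, ans_len, vocab_size = logits_for_loss.shape
            loss = criterion(
                logits_for_loss.reshape(batch_size * ans_len, vocab_size),
                target_ids.reshape(batch_size * ans_len),
            )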

            # Backpropagate through ALL recursion rounds.
            # Gradients flow through every outer link and inner link.
            loss.backward()

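            # Gradient clipping for stability. Clipping is an addition in
            # this implementation; the paper's stated optimizer setup is
            # AdamW with a cosine learning-rate schedule.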
            torch.nn.utils.clip_grad_norm_(
                system.parameters(), max_norm=1.0
            )

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Outer Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )
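The loss computation above relies on two conventions: only the last seq_len logit positions are scored, and targets equal to -100 are skipped. The sketch below reproduces both in plain Python for a single example so the arithmetic can be checked by hand. It assumes, as the training loop implicitly does, that target_ids were built already aligned with those final logit positions (any one-token causal shift happens during data preparation); the helper names are hypothetical.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # PyTorch / Hugging Face convention for "no loss here"


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def masked_cross_entropy(
    logits: List[List[float]],  # (context_len, vocab_size) for one example
    target_ids: List[int],      # (seq_len,), IGNORE_INDEX marks skipped slots
) -> float:
    """Mean cross-entropy over non-ignored targets, scored against the
    last seq_len logit rows (the same slice as logits[:, -seq_len:, :])."""
    seq_len = len(target_ids)
    aligned = logits[-seq_len:]
    losses = [
        -log_softmax(row)[t]
        for row, t in zip(aligned, target_ids)
        if t != IGNORE_INDEX
    ]
    # Guard against an all-ignored example (a choice of this sketch).
    if not losses:
        return 0.0
    return sum(losses) / len(losses)
```

As with nn.CrossEntropyLoss under its default mean reduction, ignored positions contribute neither to the sum nor to the count.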

CHAPTER FOUR: WORKING WITH LOCAL AND REMOTE LLMS

Now that we understand the theory and the training procedure, let us look at how to use RecursiveMAS in practice. The paper uses models from the Hugging Face ecosystem (Qwen, LLaMA, Gemma, Mistral), but the principles carry over to remote, API-based models. We will build a practical wrapper that supports local models loaded directly through Hugging Face Transformers, models served locally by Ollama, and remote models behind OpenAI-compatible APIs. Note that an Ollama model, although it runs on your own machine, is reached through an HTTP API and therefore behaves like a remote agent.

The key challenge with remote APIs is that they do not give you access to hidden states: you cannot intercept the last-layer hidden states of GPT-4 or Claude. This means you cannot implement the full RecursiveMAS architecture with remote APIs, only a text-based approximation. Understanding this limitation is still valuable, and for local models you can implement the full system.

We will build a unified interface that handles both cases gracefully.
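The all-or-nothing nature of that limitation fits in a few lines. In this minimal sketch (the helper name is hypothetical), latent transfer is possible only when every agent exposes hidden states and there is one inner and one outer RecursiveLink per agent; a single text-only agent anywhere in the chain forces the text-based approximation for the whole pipeline.

```python
from typing import Optional, Sequence


def can_use_latent_mode(
    supports_latent: Sequence[bool],
    num_inner_links: Optional[int],
    num_outer_links: Optional[int],
) -> bool:
    """True only if every agent exposes hidden states and one inner plus
    one outer RecursiveLink is available per agent (None = not provided)."""
    n = len(supports_latent)
    return (
        all(supports_latent)
        and num_inner_links == n
        and num_outer_links == n
    )
```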

The Agent Abstraction

The first thing we need is a clean abstraction for an "agent" that works regardless of whether the underlying model is local or remote:

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AgentResponse:
    """
    Unified response object returned by any agent, local or remote.

    For local agents, hidden_states contains the actual latent vectors
    that RecursiveMAS uses for cross-agent communication. For remote
    agents, hidden_states is None and only the text response is available.

    Attributes:
        text: The decoded text response from the agent.
        hidden_states: Last-layer hidden states if available (local only).
                       LocalHuggingFaceAgent returns the state from its
                       final generation step, shape (1, 1, hidden_dim).
                       None for remote agents.
        token_count: Number of tokens generated (for efficiency tracking).
    """
    text: str
    hidden_states: Optional[torch.Tensor]
    token_count: int


class BaseAgent(ABC):
    """
    Abstract base class for all RecursiveMAS agents.

    Concrete implementations handle local Hugging Face models,
    Ollama-served models, and remote OpenAI-compatible API models.
    All agents expose the same interface so the RecursiveMAS orchestrator
    can work with any combination of local and remote agents.
    """

    def __init__(self, name: str, role: str):
        """
        Args:
            name: A human-readable name for this agent (e.g., "Planner").
            role: The agent's role description, used in system prompts.
        """
        self.name = name
        self.role = role

    @abstractmethod
    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response to the given prompt.

        This is the universal interface used by all agents.
        Local agents additionally populate hidden_states in the response.

        Args:
            prompt: The user prompt to respond to.
            system_prompt: Optional system prompt for role conditioning.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature (lower = more deterministic).

        Returns:
            AgentResponse with text, optional hidden_states, and token_count.
        """
        pass

    @property
    @abstractmethod
    def supports_latent_transfer(self) -> bool:
        """
        Whether this agent supports latent-space state transfer.

        Returns True for local Hugging Face agents (which expose hidden
        states), and False for remote API agents (which do not).
        """
        pass

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(name={self.name!r}, role={self.role!r})"
        )
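Because the orchestrator only relies on name, role, supports_latent_transfer, and generate_text(), a scripted stand-in is handy for exercising the pipeline offline, without loading a model or calling an API. The sketch below is hypothetical (ScriptedAgent and StubResponse are not part of the source). It duck-types the interface so it runs on its own; in the real file you would subclass BaseAgent and return an AgentResponse.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubResponse:
    """Mirrors AgentResponse's fields so orchestrator code can consume it."""
    text: str
    hidden_states: Optional[object]
    token_count: int


class ScriptedAgent:
    """Hypothetical offline agent: replays canned replies, records prompts."""

    def __init__(self, name: str, role: str, replies: List[str]):
        if not replies:
            raise ValueError("ScriptedAgent needs at least one reply.")
        self.name = name
        self.role = role
        self._replies = list(replies)
        self.prompts_seen: List[str] = []

    @property
    def supports_latent_transfer(self) -> bool:
        return False  # behaves like a text-only (remote-style) agent

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> StubResponse:
        self.prompts_seen.append(prompt)
        # Cycle through the canned replies so multi-round runs never run out
        idx = (len(self.prompts_seen) - 1) % len(self._replies)
        # Crude word-level cap standing in for a token limit
        words = self._replies[idx].split()[:max_new_tokens]
        return StubResponse(
            text=" ".join(words), hidden_states=None, token_count=len(words)
        )
```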

Local Agent: Hugging Face Transformers

The local agent implementation uses Hugging Face Transformers and gives us full access to hidden states, enabling the complete RecursiveMAS architecture:

from transformers import AutoTokenizer, AutoModelForCausalLM


class LocalHuggingFaceAgent(BaseAgent):
    """
    A RecursiveMAS agent backed by a local Hugging Face model.

    This agent supports full latent-space state transfer because it
    has direct access to the model's internal hidden states. It is
    the agent type used in the original RecursiveMAS paper with
    models like Qwen3, LLaMA-3, Gemma3, and Mistral.

    The model parameters are always frozen — only the RecursiveLink
    parameters are trained.

    Args:
        model_name_or_path: Hugging Face model identifier or local path.
                            Examples:
                              "Qwen/Qwen2.5-1.5B-Instruct"
                              "meta-llama/Llama-3.2-1B-Instruct"
                              "/path/to/local/model"
        name: Human-readable agent name.
        role: Agent role description for system prompts.
        device: Torch device string. Use "cuda" for NVIDIA GPU,
                "mps" for Apple Silicon, "cpu" for CPU-only.
                Note: "mps" support varies by model; test before deploying.
        torch_dtype: Data type for model weights. torch.float16 is
                     recommended for GPU to save memory. Use torch.float32
                     for CPU or if you encounter numerical issues.
    """

    def __init__(
        self,
        model_name_or_path: str,
        name: str,
        role: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        super().__init__(name=name, role=role)
        self.device = device
        self.model_name = model_name_or_path

        print(f"Loading {name} from {model_name_or_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
        )

        # Ensure the tokenizer has a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load model and immediately freeze all parameters
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )

        # Freeze all model parameters — only RecursiveLinks are trained
        for param in self.model.parameters():
            param.requires_grad = False

        self.model.eval()

        # Cache the hidden dimension for RecursiveLink sizing
        self.hidden_dim = self.model.config.hidden_size

        total_params = sum(p.numel() for p in self.model.parameters())
        print(
            f"  Loaded {name}: hidden_dim={self.hidden_dim}, "
            f"params={total_params:,} (all frozen)"
        )

    @property
    def supports_latent_transfer(self) -> bool:
        """Local models always support latent transfer."""
        return True

    def get_hidden_dim(self) -> int:
        """Return the model's hidden dimension for RecursiveLink sizing."""
        return self.hidden_dim

    def get_raw_model(self) -> nn.Module:
        """
        Return the underlying frozen Hugging Face model.

        Used by RecursiveMASSystem to access the model directly for
        latent-space operations.
        """
        return self.model

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Get the input embeddings for a piece of text.

        Used to embed the question context before latent thoughts generation.

        Args:
            text: The text to embed.

        Returns:
            Embeddings tensor of shape (1, seq_len, hidden_dim).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            embeddings = self.model.get_input_embeddings()(
                inputs["input_ids"]
            )
        return embeddings

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response using standard autoregressive decoding.

        This is used for the final output in the last recursion round,
        or for text-based approximation mode.

        Args:
            prompt: The user message.
            system_prompt: Optional system message for role conditioning.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            AgentResponse with text output. hidden_states contains the
            last-layer hidden states from the final generation step,
            shape (1, 1, hidden_dim) — one position, last layer.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        formatted_prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

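        # The chat template already inserts any BOS token the model
        # expects, so skip the tokenizer's own special tokens to avoid
        # a duplicated BOS.
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.device)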

        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                top_p=0.95,
                do_sample=(temperature > 0),
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        # Decode the generated text (excluding the input prompt tokens)
        generated_ids = outputs.sequences[:, input_length:]
        generated_text = self.tokenizer.decode(
            generated_ids[0],
            skip_special_tokens=True,
        )

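        # Extract hidden states from the last generation step.
        # outputs.hidden_states is a tuple with one entry per generation
        # step; each entry is a tuple with one tensor per layer (embedding
        # output first, last Transformer layer last). So
        # outputs.hidden_states[-1][-1] is the last-layer state from the
        # forward pass that produced the final generated token. With KV
        # caching that tensor covers one position, except when only one
        # token was generated and the last step is the full prompt pass;
        # slicing the final position keeps the shape (1, 1, hidden_dim).
        last_step_hidden = None
        if outputs.hidden_states:
            last_step_hidden = outputs.hidden_states[-1][-1][:, -1:, :]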

        token_count = generated_ids.shape[1]

        return AgentResponse(
            text=generated_text,
            hidden_states=last_step_hidden,
            token_count=token_count,
        )
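The nesting of outputs.hidden_states is easy to get wrong, so here is a shape-only sketch of what generate() returns for a decoder-only model with return_dict_in_generate=True and output_hidden_states=True. It assumes batch size 1, standard greedy or sampled decoding, and KV caching enabled (the default); the helper is hypothetical and computes shapes rather than running a model.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]


def generate_hidden_state_shapes(
    prompt_len: int, new_tokens: int, num_layers: int, hidden_dim: int
) -> List[List[Shape]]:
    """Expected shapes of outputs.hidden_states[step][layer] (batch size 1,
    KV cache on). One outer entry per generated token."""
    steps = []
    for step in range(new_tokens):
        # Step 0 runs the whole prompt; later steps feed one new token each
        positions = prompt_len if step == 0 else 1
        # One tensor for the embedding output plus one per Transformer layer
        steps.append([(1, positions, hidden_dim)] * (num_layers + 1))
    return steps
```

With new_tokens=1 the only step is the prompt pass, so outputs.hidden_states[-1][-1] has shape (1, prompt_len, hidden_dim) rather than (1, 1, hidden_dim); code that needs exactly one position should slice [:, -1:, :].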

Remote Agent: OpenAI-Compatible API

For remote models, we implement a text-only agent that works with any OpenAI-compatible API. This includes OpenAI itself, Azure OpenAI, local Ollama servers, vLLM servers, and many other providers:

import time
import requests


class RemoteAPIAgent(BaseAgent):
    """
    A RecursiveMAS agent backed by a remote OpenAI-compatible API.

    This agent does NOT support latent-space transfer because remote
    APIs do not expose internal hidden states. It participates in the
    system through text-based communication only, making it suitable
    for the text-based approximation of RecursiveMAS or for the final
    output step where text generation is required anyway.

    Compatible with: OpenAI API, Azure OpenAI, Ollama (REST API),
                     vLLM, LM Studio, and any OpenAI-compatible server.

    Args:
        api_base_url: The base URL of the API endpoint.
                      Examples:
                        "https://api.openai.com/v1"
                        "http://localhost:11434/v1"   (Ollama)
                        "http://localhost:8000/v1"    (vLLM)
        model_id: The model identifier used in API calls.
                  Examples: "gpt-4o-mini", "llama3.2", "qwen2.5:7b"
        api_key: API key for authentication. Use "ollama" for Ollama
                 (it does not require a real key).
        name: Human-readable agent name.
        role: Agent role description.
        timeout: HTTP request timeout in seconds.
        max_retries: Number of retries on transient failures, with
                     exponential backoff between attempts.
    """

    def __init__(
        self,
        api_base_url: str,
        model_id: str,
        api_key: str,
        name: str,
        role: str,
        timeout: int = 60,
        max_retries: int = 3,
    ):
        super().__init__(name=name, role=role)
        self.api_base_url = api_base_url.rstrip("/")
        self.model_id = model_id
        self.timeout = timeout
        self.max_retries = max_retries

        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @property
    def supports_latent_transfer(self) -> bool:
        """Remote API agents do not support latent transfer."""
        return False

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate a text response via the remote API.

        Implements exponential backoff retry logic for robustness
        against transient network failures and rate limiting.

        Args:
            prompt: The user message.
            system_prompt: Optional system message.
            max_new_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.

        Returns:
            AgentResponse with text output. hidden_states is always None
            for remote agents since we cannot access internal model state.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_id,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.95,
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    f"{self.api_base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=self.timeout,
                )
                response.raise_for_status()

                data = response.json()
                generated_text = data["choices"][0]["message"]["content"]
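                # Prefer the server-reported count; fall back to a rough
                # whitespace word count if "usage" is absent or null.
                usage = data.get("usage") or {}
                token_count = usage.get(
                    "completion_tokens",
                    len(generated_text.split()),
                )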

                return AgentResponse(
                    text=generated_text,
                    hidden_states=None,
                    token_count=token_count,
                )

            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"  [{self.name}] Request failed "
                        f"(attempt {attempt + 1}/{self.max_retries}), "
                        f"retrying in {wait_time}s: {e}"
                    )
                    time.sleep(wait_time)

        raise RuntimeError(
            f"[{self.name}] All {self.max_retries} API attempts failed. "
            f"Last error: {last_error}"
        )


class OllamaAgent(RemoteAPIAgent):
    """
    Convenience subclass for agents backed by a local Ollama server.

    Ollama provides an OpenAI-compatible REST API at localhost:11434,
    making it easy to run models like Llama, Mistral, Qwen, and Gemma
    locally without writing GPU-intensive model loading code.

    Usage:
        # 1. Install Ollama: https://ollama.ai
        # 2. Start the server:
        #      ollama serve
        # 3. Pull a model:
        #      ollama pull qwen2.5:1.5b
        # 4. Create an agent:
        agent = OllamaAgent(
            model_id="qwen2.5:1.5b",
            name="Planner",
            role="You plan step-by-step solutions.",
        )

    Args:
        model_id: The Ollama model name (e.g., "qwen2.5:1.5b",
                  "llama3.2:1b", "qwen2.5:7b").
        name: Human-readable agent name.
        role: Agent role description.
        host: Ollama server hostname (default: "localhost").
        port: Ollama server port (default: 11434).
        timeout: HTTP request timeout in seconds (default: 120,
                 longer than RemoteAPIAgent default because local
                 models can be slow on CPU).
        max_retries: Number of retries on failure (default: 3).
    """

    def __init__(
        self,
        model_id: str,
        name: str,
        role: str,
        host: str = "localhost",
        port: int = 11434,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        super().__init__(
            api_base_url=f"http://{host}:{port}/v1",
            model_id=model_id,
            api_key="ollama",
            name=name,
            role=role,
            timeout=timeout,
            max_retries=max_retries,
        )
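Retries trade robustness for latency. The helpers below (hypothetical names) reproduce the backoff arithmetic of RemoteAPIAgent.generate_text(), which sleeps 2**attempt seconds after every failed attempt except the last. The worst-case figure treats timeout as a per-attempt ceiling, which is an approximation: requests applies a single timeout value separately to the connect and read phases, so a pathological attempt can take longer.

```python
from typing import List


def backoff_waits(max_retries: int) -> List[int]:
    """Sleep durations between attempts: 2**attempt seconds, with no
    sleep after the final attempt (mirrors the retry loop above)."""
    return [2 ** attempt for attempt in range(max_retries - 1)]


def approx_worst_case_seconds(timeout: int, max_retries: int) -> int:
    """Approximate blocking time if every attempt runs to its timeout."""
    return timeout * max_retries + sum(backoff_waits(max_retries))
```

With the defaults, a RemoteAPIAgent (timeout=60, max_retries=3) can block for roughly 183 seconds before raising, and an OllamaAgent (timeout=120) for roughly 363 seconds.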

CHAPTER FIVE: THE ORCHESTRATOR

With our agent abstractions in place, we can now build the orchestrator that implements the RecursiveMAS collaboration patterns. The orchestrator manages the recursion rounds, routes each agent's output to the next agent, and produces the final answer.

For the text-based approximation (which works with both local and remote agents), the orchestrator passes text between agents. The full latent-space version (which requires local agents) instead passes hidden-state tensors through the RecursiveLink modules.

The following orchestrator implements the Sequential Style pattern, the most common pattern and the one used in the paper's primary experiments. It detects whether latent-space mode is possible (all agents local and RecursiveLink modules supplied), but its solve() method always runs the text-based mode; full latent-space inference is driven through RecursiveMASSystem.forward() directly:

from typing import List, Optional, Tuple


class SequentialRecursiveMASOrchestrator:
    """
    Orchestrates a sequential-style RecursiveMAS collaboration.

    Manages a chain of agents (e.g., Planner -> Critic -> Solver) through
    multiple recursion rounds. It records in use_latent_mode whether
    latent-space transfer would be possible (all agents local and links
    provided), but solve() always runs the text-based mode; latent-space
    inference goes through RecursiveMASSystem.forward().

    This mirrors the "light" and "scaled" sequential configurations
    from the paper: a Planner, Critic, and Solver in a chain, iterated
    over multiple recursion rounds.

    Args:
        agents: List of BaseAgent instances in pipeline order.
                Minimum 2 agents required.
        inner_links: List of InnerRecursiveLink modules, one per agent.
                     Required only for latent-space mode (all local agents).
                     Pass None to force text-based mode.
        outer_links: List of OuterRecursiveLink modules, one per agent.
                     outer_links[i] connects agent i to agent (i+1) % N.
                     Required only for latent-space mode.
                     Pass None to force text-based mode.
        num_latent_steps: Number of latent reasoning steps per agent
                          per round (default 80, per the paper).
        num_recursion_rounds: Number of recursive collaboration rounds
                              (default 3, per the paper).
    """

    def __init__(
        self,
        agents: List[BaseAgent],
        inner_links: Optional[List[InnerRecursiveLink]] = None,
        outer_links: Optional[List[OuterRecursiveLink]] = None,
        num_latent_steps: int = 80,
        num_recursion_rounds: int = 3,
    ):
        if len(agents) < 2:
            raise ValueError(
                "Sequential RecursiveMAS requires at least 2 agents."
            )

        self.agents = agents
        self.inner_links = inner_links
        self.outer_links = outer_links
        self.num_latent_steps = num_latent_steps
        self.num_recursion_rounds = num_recursion_rounds

        # Determine whether we can use full latent-space mode.
        # Requires: all agents are local AND links are provided.
        self.use_latent_mode = (
            all(agent.supports_latent_transfer for agent in agents)
            and inner_links is not None
            and outer_links is not None
            and len(inner_links) == len(agents)
            and len(outer_links) == len(agents)
        )

        mode = "latent-space" if self.use_latent_mode else "text-based"
        print(
            f"SequentialRecursiveMAS initialized with {len(agents)} agents "
            f"in {mode} mode over {num_recursion_rounds} recursion rounds."
        )

    def _build_agent_prompt(
        self,
        agent: BaseAgent,
        question: str,
        prior_context: Optional[str],
        round_idx: int,
        agent_idx: int,
    ) -> Tuple[str, str]:
        """
        Build the system and user prompts for an agent in text mode.

        In text mode, prior context from previous agents or rounds is
        included in the prompt as explicit text. This is the text-based
        approximation of what RecursiveMAS does in latent space.

        Args:
            agent: The agent to build prompts for.
            question: The original question.
            prior_context: Text output from the previous agent or round.
                           None if this is the first agent in round 1.
            round_idx: Current recursion round index (0-based).
            agent_idx: This agent's position in the pipeline (0-based).

        Returns:
            Tuple of (system_prompt, user_prompt).
        """
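        # Phrase the role so it reads correctly whether it is a noun
        # phrase ("planner") or a full sentence ("You plan ...").
        system_prompt = (
            f"You are {agent.name}, agent {agent_idx + 1} of "
            f"{len(self.agents)} in a recursive multi-agent system, "
            f"participating in recursion round {round_idx + 1} "
            f"of {self.num_recursion_rounds}. "
            f"Your role: {agent.role.rstrip('.')}. "
            f"Collaborate carefully and build on the work of other agents."
        )

        if prior_context:
            # The first agent in later rounds receives feedback from the
            # last agent of the previous round, not from a peer earlier
            # in the same round.
            if agent_idx == 0:
                source = "the final agent of the previous round"
            else:
                source = "the previous agent in this round"
            user_prompt = (
                f"Here is context from {source}:\n\n"
                f"---\n{prior_context}\n---\n\n"
                f"Given this context, provide your response to the question:\n\n"
                f"{question}"
            )
        else:
            user_prompt = (
                f"Please respond to the following question:\n\n{question}"
            )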

        return system_prompt, user_prompt

    def solve_text_mode(self, question: str) -> str:
        """
        Solve a question using text-based recursive collaboration.

        This is the fallback mode used when any agent does not support
        latent transfer, or when no RecursiveLink modules are provided.
        It approximates RecursiveMAS by passing text between agents
        across multiple recursion rounds.

        While less efficient than latent-space mode (no token savings,
        no gradient stability benefits), this still benefits from the
        recursive multi-round structure and works with any combination
        of local and remote agents.

        Args:
            question: The question to answer.

        Returns:
            The final text answer from the last agent in the last round.
        """
        if not self.agents:
            return ""

        print(
            f"\nSolving in text mode over "
            f"{self.num_recursion_rounds} rounds..."
        )

        # The last agent's output from the previous round feeds back
        # into the first agent at the start of the next round.
        previous_round_output: Optional[str] = None
        final_answer: str = ""

        for round_idx in range(self.num_recursion_rounds):
            print(
                f"  Round {round_idx + 1}/{self.num_recursion_rounds}"
            )
            # Within a round, each agent sees the previous agent's output
            current_context: Optional[str] = previous_round_output

            for agent_idx, agent in enumerate(self.agents):
                system_prompt, user_prompt = self._build_agent_prompt(
                    agent=agent,
                    question=question,
                    prior_context=current_context,
                    round_idx=round_idx,
                    agent_idx=agent_idx,
                )

                response = agent.generate_text(
                    prompt=user_prompt,
                    system_prompt=system_prompt,
                    max_new_tokens=512,
                    temperature=0.6,
                )

                print(
                    f"    [{agent.name}] Generated "
                    f"{response.token_count} tokens"
                )

                # This agent's output becomes context for the next agent
                current_context = response.text

                # Track the final agent's output as the answer
                if agent_idx == len(self.agents) - 1:
                    final_answer = response.text

            # The last agent's output feeds back to the first agent
            # in the next round
            previous_round_output = final_answer

        return final_answer

    def solve(self, question: str) -> str:
        """
        Solve a question using RecursiveMAS.

        Checks whether full latent-space mode would be available (all
        agents local and links provided) and, if so, prints a pointer to
        RecursiveMASSystem.forward(). In every case this method runs the
        text-based mode.

        Args:
            question: The question to answer.

        Returns:
            The final answer string.
        """
        if self.use_latent_mode:
            # Full latent-space RecursiveMAS requires direct model access
            # via the RecursiveMASSystem nn.Module. For simplicity in this
            # orchestrator, we note that latent-space inference should be
            # driven through RecursiveMASSystem.forward() directly.
            # The orchestrator's text mode is provided for quick prototyping
            # with any agent type.
            print(
                "Note: Full latent-space inference is available via "
                "RecursiveMASSystem.forward(). The orchestrator uses "
                "text-based mode for broad compatibility."
            )
        return self.solve_text_mode(question)
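The data flow of solve_text_mode() is the part most worth checking: within a round each agent sees its predecessor's output, and the last agent's output wraps around to the first agent in the next round. The sketch below replays that flow with plain callables standing in for agents (a hypothetical simplification that drops prompts, system prompts, and token counts) and records what each agent received.

```python
from typing import Callable, List, Optional, Tuple

# Hypothetical agent signature: (question, prior_context) -> reply text
AgentFn = Callable[[str, Optional[str]], str]


def run_text_rounds(
    agents: List[AgentFn], question: str, num_rounds: int
) -> Tuple[str, List[Tuple[int, int, Optional[str]]]]:
    """Replay the context flow of solve_text_mode and log, for every
    (round, agent) pair, the context that agent received."""
    if not agents:
        raise ValueError("Need at least one agent.")
    trace: List[Tuple[int, int, Optional[str]]] = []
    previous_round_output: Optional[str] = None
    final_answer = ""
    for round_idx in range(num_rounds):
        # The last agent's output from the previous round wraps to agent 0
        context = previous_round_output
        for agent_idx, agent in enumerate(agents):
            trace.append((round_idx, agent_idx, context))
            context = agent(question, context)
        final_answer = context
        previous_round_output = final_answer
    return final_answer, trace
```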

CHAPTER SIX: THE COMPLETE COMBINED FILE

Here is the complete, self-contained recursive_mas.py file that combines all code from Chapters Two through Five in the correct dependency order. Copy this entire block into recursive_mas.py and you are ready to run.

"""
recursive_mas.py
================
Complete implementation of RecursiveMAS-inspired multi-agent collaboration.

Based on: "RecursiveMAS: Scaling Agent Collaboration through Unified
Latent-Space Recursive Computation" (2604.25917)

This file contains:
  - InnerRecursiveLink and OuterRecursiveLink modules
  - generate_latent_thoughts() utility function
  - compute_inner_loop_loss() and train_inner_loop() for Stage 1 training
  - RecursiveMASSystem nn.Module and train_outer_loop() for Stage 2 training
  - AgentResponse dataclass and BaseAgent abstract class
  - LocalHuggingFaceAgent for local Hugging Face models
  - RemoteAPIAgent and OllamaAgent for remote/Ollama-served models
  - SequentialRecursiveMASOrchestrator for running the full pipeline

Requirements:
  pip install "torch>=2.1.0" "transformers>=4.40.0" "accelerate>=0.27.0" "requests>=2.31.0"

Usage:
  See demo.py for a complete runnable example.
"""

# ============================================================
# Standard library imports
# ============================================================
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple

# ============================================================
# Third-party imports
# ============================================================
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Transformers imports are inside LocalHuggingFaceAgent.__init__
# to make the file importable even if transformers is not installed
# (useful when only using remote agents).


# ============================================================
# Section 1: RecursiveLink Modules
# ============================================================

class InnerRecursiveLink(nn.Module):
    """
    The Inner RecursiveLink operates within a single LLM agent.

    Transforms the agent's last-layer hidden state at step t into
    an input embedding for step t+1, enabling the agent to reason
    in continuous latent space without decoding to text.

    Architecture:
        R_inner(h) = h + W2 * GELU(W1 * h)

    At initialization, W2 is all zeros so the module starts as a
    pure identity (output == input). Training learns the residual
    correction on top of this stable baseline.

    Args:
        hidden_dim: Hidden dimension of the paired LLM agent.
    """

    def __init__(self, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.W1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: Shape (batch, hidden_dim) or (batch, seq, hidden_dim).
        Returns:
            Same shape as h.
        """
        return h + self.W2(F.gelu(self.W1(h)))

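# ------------------------------------------------------------
# Illustrative sketch (hypothetical helpers, not used by the pipeline):
# a minimal, dependency-free version of
#   R_inner(h) = h + W2 * GELU(W1 * h)
# on plain Python lists. It assumes exact (erf-based) GELU, which is
# the default of torch.nn.functional.gelu. It shows why zeroing W2 and
# its bias, as _initialize_weights does, makes the link an exact
# identity at initialization, whatever W1 contains.
# ------------------------------------------------------------
import math


def _sketch_gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), where Phi is the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def _sketch_matvec(W, b, h):
    # Computes W @ h + b, with W given as a list of rows
    return [sum(w * x for w, x in zip(row, h)) + bias for row, bias in zip(W, b)]


def sketch_inner_link(h, W1, b1, W2, b2):
    """Hypothetical reference for InnerRecursiveLink.forward on one vector."""
    hidden = [_sketch_gelu(v) for v in _sketch_matvec(W1, b1, h)]
    correction = _sketch_matvec(W2, b2, hidden)
    return [x + c for x, c in zip(h, correction)]


# Usage: with W2 and b2 all zeros the output equals the input; once
# training moves W2 away from zero, the residual branch adds a learned
# correction on top of that identity baseline.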

class OuterRecursiveLink(nn.Module):
    """
    The Outer RecursiveLink bridges two heterogeneous LLM agents.

    Transforms last-layer hidden states of a source agent into input
    embeddings aligned with the target agent's embedding space.

    Architecture:
        R_outer(h) = W3 * h + W2 * GELU(W1 * h)

    W3 handles dimensional alignment (source_dim -> target_dim).
    The nonlinear branch learns fine-grained distributional correction.

    Args:
        source_dim: Hidden dimension of the source agent.
        target_dim: Hidden dimension of the target agent.
    """

    def __init__(self, source_dim: int, target_dim: int):
        super().__init__()
        self.source_dim = source_dim
        self.target_dim = target_dim
        self.W1 = nn.Linear(source_dim, source_dim, bias=True)
        self.W2 = nn.Linear(source_dim, target_dim, bias=True)
        self.W3 = nn.Linear(source_dim, target_dim, bias=False)
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.xavier_uniform_(self.W3.weight, gain=1.0)
        nn.init.xavier_uniform_(self.W1.weight, gain=0.1)
        nn.init.zeros_(self.W1.bias)
        nn.init.zeros_(self.W2.weight)
        nn.init.zeros_(self.W2.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: Shape (batch, seq, source_dim) or (batch, source_dim).
        Returns:
            Shape (batch, seq, target_dim) or (batch, target_dim).
        """
        return self.W3(h) + self.W2(F.gelu(self.W1(h)))

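# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper, dependency-free): the outer
# link R_outer(h) = W3 * h + W2 * GELU(W1 * h) maps a source_dim vector
# to target_dim. Because W2 and its bias start at zero, the link begins
# as the plain linear projection W3 * h and training learns the
# nonlinear correction. The dimensions in the test are arbitrary.
# ------------------------------------------------------------
import math


def sketch_outer_link(h, W1, b1, W2, b2, W3):
    """Hypothetical reference for OuterRecursiveLink.forward on one vector.

    h: length source_dim. W1: source_dim x source_dim.
    W2 and W3: target_dim x source_dim (W3 has no bias).
    """
    def matvec(W, v):
        return [sum(w * x for w, x in zip(row, v)) for row in W]

    def gelu(x):
        # Exact (erf-based) GELU
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    hidden = [gelu(v + b) for v, b in zip(matvec(W1, h), b1)]
    nonlinear = [v + b for v, b in zip(matvec(W2, hidden), b2)]
    linear = matvec(W3, h)  # dimensional alignment source_dim -> target_dim
    return [a + c for a, c in zip(linear, nonlinear)]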

# ============================================================
# Section 2: Latent Thoughts Generation
# ============================================================

def generate_latent_thoughts(
    model: nn.Module,
    input_embeddings: torch.Tensor,
    inner_link: InnerRecursiveLink,
    num_latent_steps: int = 80,
    prior_latent_state: Optional[torch.Tensor] = None,
    training: bool = False,
) -> torch.Tensor:
    """
    Run an agent in latent-thoughts generation mode.

    The agent's parameters are frozen (requires_grad=False), so no
    parameter gradients are ever computed for it. When training=True,
    autograd still records the model's forward activations so that
    gradients can flow back to trainable inputs: the inner_link outputs
    appended to the context and any prior_latent_state produced by an
    outer link. Activation memory therefore grows with num_latent_steps
    during training. When training=False, everything runs without
    gradient tracking.

    Args:
        model: Frozen Hugging Face causal LM.
        input_embeddings: Shape (batch, context_len, hidden_dim).
        inner_link: The InnerRecursiveLink for this agent.
        num_latent_steps: Latent steps per agent (paper default: 80).
        prior_latent_state: Optional projected latent from previous round.
                            Shape (batch, prior_len, hidden_dim).
        training: True during outer-loop training; False at inference.

    Returns:
        latent_thoughts: Shape (batch, num_latent_steps, hidden_dim).
    """
    if prior_latent_state is not None:
        current_embeddings = torch.cat(
            [prior_latent_state, input_embeddings], dim=1
        )
    else:
        current_embeddings = input_embeddings

    latent_thoughts = []

    for _ in range(num_latent_steps):
        # Record activations only in training mode; the model's own
        # parameters stay frozen either way.
        with torch.set_grad_enabled(training):
            outputs = model(
                inputs_embeds=current_embeddings,
                output_hidden_states=True,
                use_cache=False,
            )
            # Shape: (batch, hidden_dim)
            last_hidden = outputs.hidden_states[-1][:, -1, :]

            # Trainable inner link produces the next input embedding
            next_emb = inner_link(last_hidden)

        # Store the last-layer hidden state as this step's latent thought
        latent_thoughts.append(last_hidden.unsqueeze(1))

        # Grow the context by one embedding per step
        current_embeddings = torch.cat(
            [current_embeddings, next_emb.unsqueeze(1)], dim=1
        )

    return torch.cat(latent_thoughts, dim=1)

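# ------------------------------------------------------------
# Worked cost estimate (illustrative, hypothetical helper): each latent
# step calls the model with use_cache=False on the full, growing
# context, so step k re-processes L0 + k positions, where L0 is the
# starting context length (prior_len plus input length). Over S steps:
#   sum_{k=0}^{S-1} (L0 + k) = S * L0 + S * (S - 1) / 2
# With the default S = 80 and, say, L0 = 100, that is
# 8000 + 3160 = 11160 positions per agent per round. A KV cache would
# avoid most of this recomputation; the loop above does not use one.
# ------------------------------------------------------------
def latent_loop_positions(context_len: int, num_steps: int) -> int:
    """Total sequence positions processed by one latent-thoughts loop."""
    return num_steps * context_len + num_steps * (num_steps - 1) // 2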

# ============================================================
# Section 3: Inner-Loop Training
# ============================================================

def compute_inner_loop_loss(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    input_ids: torch.Tensor,
    target_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the inner-loop training loss for one agent.

    Loss = 1 - cosine_similarity(R_inner(H_0), mean(Emb(y)))

    Gradients flow through inner_link.W1 and inner_link.W2 only.
    The model is frozen; H_0 is produced under torch.no_grad().

    Args:
        model: Frozen LLM agent.
        inner_link: InnerRecursiveLink to train.
        input_ids: Shape (batch, seq_len).
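        target_ids: Shape (batch, ans_len). Must be valid token ids (no
                    -100 sentinels), because they are passed through the
                    embedding layer; any padding tokens enter the mean.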

    Returns:
        Scalar loss tensor in [0, 2].
    """
    embedding_layer = model.get_input_embeddings()

    with torch.no_grad():
        input_embeddings = embedding_layer(input_ids)
        target_embeddings = embedding_layer(target_ids)
        initial_output = model(
            inputs_embeds=input_embeddings,
            output_hidden_states=True,
            use_cache=False,
        )
        # Shape: (batch, hidden_dim)
        H_0 = initial_output.hidden_states[-1][:, -1, :]

    # inner_link called OUTSIDE no_grad — gradients flow through W1, W2
    transformed_latent = inner_link(H_0)

    # Target: mean of answer token embeddings (detached)
    target_direction = target_embeddings.mean(dim=1)

    similarity = F.cosine_similarity(
        transformed_latent, target_direction, dim=-1
    )
    return (1.0 - similarity).mean()

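# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper, dependency-free): the
# inner-loop objective for a single example on plain lists,
#   loss = 1 - cos(R_inner(H_0), mean_j Emb(y_j)).
# The loss is 0 when the transformed latent points in the same
# direction as the mean answer embedding, 1 when orthogonal, and 2
# when opposite. Only direction matters, not vector length.
# ------------------------------------------------------------
import math


def sketch_inner_loss(transformed, answer_embeddings, eps=1e-8):
    n = len(answer_embeddings)
    # Mean over answer tokens, one value per hidden dimension
    target = [sum(col) / n for col in zip(*answer_embeddings)]
    dot = sum(a * b for a, b in zip(transformed, target))
    norm_a = math.sqrt(sum(a * a for a in transformed))
    norm_b = math.sqrt(sum(b * b for b in target))
    # A small eps guards against division by zero
    return 1.0 - dot / max(norm_a * norm_b, eps)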

def train_inner_loop(
    model: nn.Module,
    inner_link: InnerRecursiveLink,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
) -> None:
    """
    Run the inner-loop training stage for one agent.

    Call this independently for each agent before outer-loop training.

    Args:
        model: LLM agent. Its parameters are frozen here
               (requires_grad=False) before training starts.
        inner_link: InnerRecursiveLink to train.
        dataloader: Yields (input_ids, target_ids) batches.
        num_epochs: Training epochs.
        learning_rate: Initial learning rate for AdamW with cosine schedule.
    """
    for param in model.parameters():
        param.requires_grad = False

    optimizer = AdamW(inner_link.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )

    inner_link.train()
    model.eval()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()
            loss = compute_inner_loop_loss(model, inner_link,
                                           input_ids, target_ids)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                inner_link.parameters(), max_norm=1.0
            )
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Inner Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )

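# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper): the learning rate that
# CosineAnnealingLR yields after t scheduler steps when it alone sets
# the rate, using the closed form from the PyTorch documentation:
#   lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
# train_inner_loop steps the scheduler once per batch with
# T_max = num_epochs * len(dataloader) and the default eta_min = 0, so
# the rate decays from learning_rate to 0 at the end of training.
# ------------------------------------------------------------
import math


def cosine_annealed_lr(lr_0, t, t_max, eta_min=0.0):
    return eta_min + (lr_0 - eta_min) * (1.0 + math.cos(math.pi * t / t_max)) / 2.0


# Example: 3 epochs x 100 batches with lr_0 = 1e-4 gives 1e-4 at t = 0,
# about 5e-5 at t = 150 and 0 at t = 300.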

# ============================================================
# Section 4: RecursiveMASSystem and Outer-Loop Training
# ============================================================

class RecursiveMASSystem(nn.Module):
    """
    A RecursiveMAS system with N agents in sequential style.

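    Frozen agent models are stored as a plain Python list (NOT
    nn.ModuleList) so they do not appear in self.parameters() or
    self.state_dict(). This keeps the optimizer and checkpoint
    focused on the RecursiveLink modules only. As a consequence,
    system.to(...), system.train() and system.eval() do not touch
    the agent models; they keep the device and mode they were given.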

    Inner and outer links ARE stored as nn.ModuleList so they are
    properly registered as submodules and appear in self.parameters().

    Args:
        agent_models: List of N frozen Hugging Face causal LM models.
        inner_links: List of N InnerRecursiveLink modules.
        outer_links: List of N OuterRecursiveLink modules.
                     outer_links[i] connects agent i to agent (i+1) % N.
        num_latent_steps: Latent steps per agent per round (default 80).
    """

    def __init__(
        self,
        agent_models: list,
        inner_links: list,
        outer_links: list,
        num_latent_steps: int = 80,
    ):
        super().__init__()

        if len(agent_models) < 2:
            raise ValueError("RecursiveMASSystem requires at least 2 agents.")
        if len(inner_links) != len(agent_models):
            raise ValueError("Must provide exactly one inner_link per agent.")
        if len(outer_links) != len(agent_models):
            raise ValueError("Must provide exactly one outer_link per agent.")

        self.num_agents = len(agent_models)
        self.num_latent_steps = num_latent_steps

        # Frozen models: plain list, NOT registered as submodules
        self._agent_models = agent_models

        # Trainable links: registered as submodules via nn.ModuleList
        self.inner_links = nn.ModuleList(inner_links)
        self.outer_links = nn.ModuleList(outer_links)

        # Verify all agent models are frozen
        for i, model in enumerate(self._agent_models):
            for param in model.parameters():
                if param.requires_grad:
                    raise ValueError(
                        f"Agent {i} has unfrozen parameters. "
                        f"Call model.requires_grad_(False) before "
                        f"constructing RecursiveMASSystem."
                    )

    def _run_agent_latent(
        self,
        agent_idx: int,
        context_embeddings: torch.Tensor,
        training: bool,
    ) -> torch.Tensor:
        """
        Run one agent in latent-thoughts generation mode.

        Args:
            agent_idx: Index of the agent (0-based).
            context_embeddings: Shape (batch, context_len, hidden_dim).
            training: True during outer-loop training.

        Returns:
            Shape (batch, num_latent_steps, hidden_dim).
        """
        model = self._agent_models[agent_idx]
        inner_link = self.inner_links[agent_idx]
        current_embeddings = context_embeddings
        latent_thoughts = []

        for _ in range(self.num_latent_steps):
            # Agent parameters are frozen, but in training mode autograd
            # must record this forward pass so gradients can reach the
            # inner link and any outer-link feedback in the context.
            with torch.set_grad_enabled(training):
                outputs = model(
                    inputs_embeds=current_embeddings,
                    output_hidden_states=True,
                    use_cache=False,
                )
                last_hidden = outputs.hidden_states[-1][:, -1, :]
                next_emb = inner_link(last_hidden)

            latent_thoughts.append(last_hidden.unsqueeze(1))
            current_embeddings = torch.cat(
                [current_embeddings, next_emb.unsqueeze(1)], dim=1
            )

        return torch.cat(latent_thoughts, dim=1)

    def forward(
        self,
        input_ids: torch.Tensor,
        num_recursion_rounds: int = 3,
    ) -> torch.Tensor:
        """
        Run the full RecursiveMAS forward pass.

        Args:
            input_ids: Shape (batch, seq_len).
            num_recursion_rounds: Number of recursive rounds (default 3).

        Returns:
            logits: Shape (batch, decode_context_len, vocab_size).
        """
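        if num_recursion_rounds < 1:
            raise ValueError("num_recursion_rounds must be at least 1.")
        training = self.training

        # Pre-embed the input for each agent (hidden_dim may differ per
        # agent). The same input_ids are embedded by every agent, so all
        # agents must share one tokenizer/vocabulary.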
        input_embeddings = []
        for model in self._agent_models:
            with torch.no_grad():
                emb = model.get_input_embeddings()(input_ids)
            input_embeddings.append(emb)

        # feedback_latent[i]: outer-link-projected latent that agent i
        # receives as extra context. Agent i > 0 gets it from agent i-1
        # earlier in the same round; agent 0 gets it from the last agent
        # of the previous round (None in round 1).
        feedback_latent: List[Optional[torch.Tensor]] = [None] * self.num_agents
        final_logits = None

        for round_idx in range(num_recursion_rounds):
            for agent_idx in range(self.num_agents):
                # Build context: incoming feedback (if any) + fresh input
                if feedback_latent[agent_idx] is not None:
                    context = torch.cat(
                        [feedback_latent[agent_idx],
                         input_embeddings[agent_idx]],
                        dim=1,
                    )
                else:
                    context = input_embeddings[agent_idx]

                # Run this agent in latent mode
                latent = self._run_agent_latent(
                    agent_idx=agent_idx,
                    context_embeddings=context,
                    training=training,
                )

                # Project latent into the next agent's embedding space
                next_agent_idx = (agent_idx + 1) % self.num_agents
                outer_link = self.outer_links[agent_idx]

                if training:
                    projected = outer_link(latent)
                else:
                    with torch.no_grad():
                        projected = outer_link(latent)

                # Hand the projection over immediately so the next agent
                # consumes it in this same round (sequential style); the
                # last agent's projection wraps around to agent 0 for the
                # next round.
                feedback_latent[next_agent_idx] = projected

            # After the final round, decode text from the last agent
            if round_idx == num_recursion_rounds - 1:
                last_agent_idx = self.num_agents - 1
                last_model = self._agent_models[last_agent_idx]

                # Reconstruct the last agent's decode context
                if feedback_latent[last_agent_idx] is not None:
                    decode_context = torch.cat(
                        [feedback_latent[last_agent_idx],
                         input_embeddings[last_agent_idx]],
                        dim=1,
                    )
                else:
                    decode_context = input_embeddings[last_agent_idx]

                if training:
                    decode_output = last_model(
                        inputs_embeds=decode_context,
                        output_hidden_states=False,
                        use_cache=False,
                    )
                else:
                    with torch.no_grad():
                        decode_output = last_model(
                            inputs_embeds=decode_context,
                            output_hidden_states=False,
                            use_cache=False,
                        )
                final_logits = decode_output.logits

        return final_logits

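# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper, pure Python): the feedback
# routing of a sequential RecursiveMAS, matching the text-mode
# orchestrator below. Within a round, agent i > 0 consumes the output
# of agent i-1 from the same round; agent 0 consumes the output of the
# last agent from the previous round (nothing in round 1). For every
# (round, agent) pair it returns the (round, agent) it receives from.
# ------------------------------------------------------------
from typing import Dict, Optional, Tuple


def sequential_feedback_sources(
    num_agents: int, num_rounds: int
) -> Dict[Tuple[int, int], Optional[Tuple[int, int]]]:
    sources: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {}
    for r in range(num_rounds):
        for i in range(num_agents):
            if i > 0:
                sources[(r, i)] = (r, i - 1)  # previous agent, same round
            elif r > 0:
                sources[(r, i)] = (r - 1, num_agents - 1)  # wrap-around
            else:
                sources[(r, i)] = None  # first agent, first round
    return sources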

def train_outer_loop(
    system: RecursiveMASSystem,
    dataloader,
    num_epochs: int = 3,
    learning_rate: float = 1e-4,
    num_recursion_rounds: int = 3,
) -> None:
    """
    Run the outer-loop training stage for the full RecursiveMAS system.

    Only RecursiveLink parameters are updated. Agent models stay frozen.
    Gradients flow through every outer link and inner link in every round
    (shared credit assignment across the full recursive computation graph).

    Args:
        system: RecursiveMASSystem with frozen agent models.
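        dataloader: Yields (input_ids, target_ids) batches. The loss is
                    computed on the last target_ids.shape[1] logit
                    positions with no internal shift, so target_ids[:, t]
                    must already be the token that follows the aligned
                    input position (e.g. input_ids shifted left by one).
                    Use -100 in target_ids for positions to ignore.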
        num_epochs: Training epochs.
        learning_rate: Initial learning rate for AdamW with cosine schedule.
        num_recursion_rounds: Recursive rounds during training (default 3).
    """
    # system.parameters() returns only RecursiveLink params because
    # agent models are stored as a plain list (not nn.ModuleList)
    optimizer = AdamW(system.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer, T_max=num_epochs * len(dataloader)
    )
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

    system.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for input_ids, target_ids in dataloader:
            optimizer.zero_grad()

            logits = system(
                input_ids=input_ids,
                num_recursion_rounds=num_recursion_rounds,
            )

            # Align logits with target: take the last seq_len positions
            seq_len = target_ids.shape[1]
            logits_for_loss = logits[:, -seq_len:, :]

            batch_size, ans_len, vocab_size = logits_for_loss.shape
            loss = criterion(
                logits_for_loss.reshape(batch_size * ans_len, vocab_size),
                target_ids.reshape(batch_size * ans_len),
            )

            # Backpropagate through all recursion rounds
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                system.parameters(), max_norm=1.0
            )
            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        print(
            f"  [Outer Loop] Epoch {epoch + 1}/{num_epochs} "
            f"| Loss: {avg_loss:.4f}"
        )

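# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper, pure Python): one way to
# build an (input_ids, target_ids) pair for train_outer_loop, assuming
# standard teacher forcing on prompt + answer. Since the loss applies
# nn.CrossEntropyLoss to logits[:, -seq_len:, :] without shifting,
# target_ids[t] must hold the token that should follow position t.
# Prompt positions and the final position get -100 so the criterion
# ignores them; target_ids has the same length as input_ids.
# ------------------------------------------------------------
from typing import List, Tuple


def build_shifted_targets(
    prompt_ids: List[int], answer_ids: List[int], ignore_index: int = -100
) -> Tuple[List[int], List[int]]:
    input_ids = list(prompt_ids) + list(answer_ids)
    prompt_len = len(prompt_ids)
    target_ids = []
    for t in range(len(input_ids)):
        nxt = t + 1
        # Supervise only predictions of answer tokens
        if prompt_len <= nxt < len(input_ids):
            target_ids.append(input_ids[nxt])
        else:
            target_ids.append(ignore_index)
    return input_ids, target_ids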

# ============================================================
# Section 5: Agent Abstractions
# ============================================================

@dataclass
class AgentResponse:
    """
    Unified response from any agent, local or remote.

    Attributes:
        text: Decoded text response.
        hidden_states: Last-layer hidden state from the final decoding
                       step of a local agent, i.e. the state from which
                       the last generated token was predicted. Usually
                       shape (1, 1, hidden_dim). None for remote agents.
        token_count: Number of tokens generated.
    """
    text: str
    hidden_states: Optional[torch.Tensor]
    token_count: int


class BaseAgent(ABC):
    """Abstract base class for all RecursiveMAS agents."""

    def __init__(self, name: str, role: str):
        self.name = name
        self.role = role

    @abstractmethod
    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        pass

    @property
    @abstractmethod
    def supports_latent_transfer(self) -> bool:
        pass

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}"
            f"(name={self.name!r}, role={self.role!r})"
        )


class LocalHuggingFaceAgent(BaseAgent):
    """
    RecursiveMAS agent backed by a local Hugging Face model.

    Supports full latent-space state transfer via hidden states.
    All model parameters are frozen at construction time.

    Args:
        model_name_or_path: HF model ID or local path.
        name: Agent name.
        role: Role description for system prompts.
        device: "cuda", "mps", or "cpu". Note: MPS support varies by model.
        torch_dtype: torch.float16 recommended for GPU; torch.float32 for CPU.
    """

    def __init__(
        self,
        model_name_or_path: str,
        name: str,
        role: str,
        device: str = "cuda",
        torch_dtype: torch.dtype = torch.float16,
    ):
        super().__init__(name=name, role=role)
        self.device = device
        self.model_name = model_name_or_path

        # Import here so the file is importable without transformers installed
        from transformers import AutoTokenizer, AutoModelForCausalLM

        print(f"Loading {name} from {model_name_or_path}...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )

        # Freeze all parameters immediately
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()

        self.hidden_dim = self.model.config.hidden_size
        total_params = sum(p.numel() for p in self.model.parameters())
        print(
            f"  Loaded {name}: hidden_dim={self.hidden_dim}, "
            f"params={total_params:,} (all frozen)"
        )

    @property
    def supports_latent_transfer(self) -> bool:
        return True

    def get_hidden_dim(self) -> int:
        return self.hidden_dim

    def get_raw_model(self) -> nn.Module:
        """Return the underlying frozen HF model for use in RecursiveMASSystem."""
        return self.model

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Embed text using the model's input embedding layer.

        Returns:
            Shape (1, seq_len, hidden_dim).
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(self.device)
        with torch.no_grad():
            embeddings = self.model.get_input_embeddings()(inputs["input_ids"])
        return embeddings

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        """
        Generate text using standard autoregressive decoding.

        Returns:
            AgentResponse. hidden_states is the last-layer hidden state
            from the final decoding step, shape (1, 1, hidden_dim) when
            more than one token is generated.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        formatted_prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
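        # The chat template output already contains any special tokens
        # the model expects (e.g. BOS), so don't add them a second time.
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
            truncation=True,
            max_length=2048,
        ).to(self.device)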
        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature if temperature > 0 else 1.0,
                # top_p only applies when sampling; 1.0 disables it
                top_p=0.95 if temperature > 0 else 1.0,
                do_sample=(temperature > 0),
                output_hidden_states=True,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

        generated_ids = outputs.sequences[:, input_length:]
        generated_text = self.tokenizer.decode(
            generated_ids[0], skip_special_tokens=True
        )

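        # outputs.hidden_states: tuple[tuple[Tensor]]
        # Outer tuple: one entry per generation step
        # Inner tuple: embedding output followed by one entry per layer
        # [-1][-1]: last layer at the final step, i.e. the state from
        # which the last generated token was predicted.
        # Shape: (batch=1, 1, hidden_dim) with the KV cache, except when
        # only one token is generated (then it spans the whole prompt).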
        last_step_hidden = None
        if outputs.hidden_states:
            last_step_hidden = outputs.hidden_states[-1][-1]

        return AgentResponse(
            text=generated_text,
            hidden_states=last_step_hidden,
            token_count=generated_ids.shape[1],
        )


class RemoteAPIAgent(BaseAgent):
    """
    RecursiveMAS agent backed by a remote OpenAI-compatible API.

    Does NOT support latent-space transfer (no access to hidden states).
    Works with: OpenAI, Azure OpenAI, Ollama, vLLM, LM Studio, etc.

    Args:
        api_base_url: API base URL, e.g. "https://api.openai.com/v1".
        model_id: Model identifier, e.g. "gpt-4o-mini".
        api_key: API key. Use "ollama" for Ollama (no real key needed).
        name: Agent name.
        role: Role description.
        timeout: HTTP timeout in seconds.
        max_retries: Retry attempts with exponential backoff.
    """

    def __init__(
        self,
        api_base_url: str,
        model_id: str,
        api_key: str,
        name: str,
        role: str,
        timeout: int = 60,
        max_retries: int = 3,
    ):
        super().__init__(name=name, role=role)
        self.api_base_url = api_base_url.rstrip("/")
        self.model_id = model_id
        self.timeout = timeout
        self.max_retries = max_retries
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @property
    def supports_latent_transfer(self) -> bool:
        return False

    def generate_text(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_new_tokens: int = 512,
        temperature: float = 0.6,
    ) -> AgentResponse:
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        payload = {
            "model": self.model_id,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": 0.95,
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    f"{self.api_base_url}/chat/completions",
                    headers=self.headers,
                    json=payload,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                data = response.json()
                generated_text = data["choices"][0]["message"]["content"]
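                # Prefer the server-reported count; "usage" may be missing
                # or null, in which case fall back to a whitespace word
                # count, which only approximates the token count.
                token_count = (data.get("usage") or {}).get(
                    "completion_tokens", len(generated_text.split())
                )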
                return AgentResponse(
                    text=generated_text,
                    hidden_states=None,
                    token_count=token_count,
                )
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    wait_time = 2 ** attempt
                    print(
                        f"  [{self.name}] Attempt {attempt + 1}/"
                        f"{self.max_retries} failed, retrying in "
                        f"{wait_time}s: {e}"
                    )
                    time.sleep(wait_time)

        raise RuntimeError(
            f"[{self.name}] All {self.max_retries} attempts failed. "
            f"Last error: {last_error}"
        )

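# ------------------------------------------------------------
# Illustrative sketch (hypothetical helper): the retry schedule used by
# RemoteAPIAgent.generate_text. It sleeps 2 ** attempt seconds after
# each failed attempt except the last, so max_retries = 3 waits 1 s and
# then 2 s (3 s of sleeping in total, not counting time spent in the
# HTTP requests themselves) before RuntimeError is raised.
# ------------------------------------------------------------
from typing import List


def backoff_delays(max_retries: int) -> List[int]:
    return [2 ** attempt for attempt in range(max_retries - 1)]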

class OllamaAgent(RemoteAPIAgent):
    """
    Agent backed by a local Ollama server.

    Ollama exposes an OpenAI-compatible REST API at localhost:11434.

    Setup:
        ollama serve
        ollama pull qwen2.5:1.5b   # or any supported model

    Args:
        model_id: Ollama model name, e.g. "qwen2.5:1.5b", "llama3.2:1b".
        name: Agent name.
        role: Role description.
        host: Ollama server host (default "localhost").
        port: Ollama server port (default 11434).
        timeout: HTTP timeout in seconds (default 120 for CPU inference).
        max_retries: Retry attempts (default 3).
    """

    def __init__(
        self,
        model_id: str,
        name: str,
        role: str,
        host: str = "localhost",
        port: int = 11434,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        super().__init__(
            api_base_url=f"http://{host}:{port}/v1",
            model_id=model_id,
            api_key="ollama",
            name=name,
            role=role,
            timeout=timeout,
            max_retries=max_retries,
        )


# ============================================================
# Section 6: Orchestrator
# ============================================================

class SequentialRecursiveMASOrchestrator:
    """
    Orchestrates sequential-style RecursiveMAS collaboration.

    Manages a chain of agents (Planner -> Critic -> Solver -> ...) through
    multiple recursion rounds. Uses latent-space mode when all agents are
    local and links are provided; falls back to text-based mode otherwise.

    Args:
        agents: List of BaseAgent instances in pipeline order (min 2).
        inner_links: InnerRecursiveLink per agent (latent mode only).
        outer_links: OuterRecursiveLink per agent (latent mode only).
        num_latent_steps: Latent steps per agent per round (default 80).
        num_recursion_rounds: Recursive rounds (default 3).
    """

    def __init__(
        self,
        agents: List[BaseAgent],
        inner_links: Optional[List[InnerRecursiveLink]] = None,
        outer_links: Optional[List[OuterRecursiveLink]] = None,
        num_latent_steps: int = 80,
        num_recursion_rounds: int = 3,
    ):
        if len(agents) < 2:
            raise ValueError(
                "Sequential RecursiveMAS requires at least 2 agents."
            )
        self.agents = agents
        self.inner_links = inner_links
        self.outer_links = outer_links
        self.num_latent_steps = num_latent_steps
        self.num_recursion_rounds = num_recursion_rounds

        self.use_latent_mode = (
            all(agent.supports_latent_transfer for agent in agents)
            and inner_links is not None
            and outer_links is not None
            and len(inner_links) == len(agents)
            and len(outer_links) == len(agents)
        )

        mode = "latent-space" if self.use_latent_mode else "text-based"
        print(
            f"SequentialRecursiveMAS: {len(agents)} agents, {mode} mode, "
            f"{num_recursion_rounds} recursion rounds."
        )

    def _build_agent_prompt(
        self,
        agent: BaseAgent,
        question: str,
        prior_context: Optional[str],
        round_idx: int,
        agent_idx: int,
    ) -> Tuple[str, str]:
        system_prompt = (
            f"You are a {agent.role} in a recursive multi-agent system. "
            f"You are agent {agent_idx + 1} of {len(self.agents)}, "
            f"in recursion round {round_idx + 1} of {self.num_recursion_rounds}. "
            f"Build carefully on the work of other agents."
        )
        if prior_context:
            user_prompt = (
                f"Context from the previous agent:\n\n"
                f"---\n{prior_context}\n---\n\n"
                f"Your response to the question:\n\n{question}"
            )
        else:
            user_prompt = f"Please respond to:\n\n{question}"
        return system_prompt, user_prompt

    def solve_text_mode(self, question: str) -> str:
        """
        Solve using text-based recursive collaboration.

        Works with any combination of local and remote agents.
        Each agent sees the previous agent's text output within a round;
        the last agent's output feeds back to the first agent in the
        next round.

        Args:
            question: The question to answer.

        Returns:
            Final text answer from the last agent in the last round.
            Returns empty string if agents list is empty or rounds is 0.
        """
        if not self.agents or self.num_recursion_rounds == 0:
            return ""

        print(
            f"\nSolving in text mode: "
            f"{self.num_recursion_rounds} rounds, "
            f"{len(self.agents)} agents per round."
        )

        previous_round_output: Optional[str] = None
        final_answer: str = ""

        for round_idx in range(self.num_recursion_rounds):
            print(f"  Round {round_idx + 1}/{self.num_recursion_rounds}")
            current_context: Optional[str] = previous_round_output

            for agent_idx, agent in enumerate(self.agents):
                system_prompt, user_prompt = self._build_agent_prompt(
                    agent=agent,
                    question=question,
                    prior_context=current_context,
                    round_idx=round_idx,
                    agent_idx=agent_idx,
                )
                response = agent.generate_text(
                    prompt=user_prompt,
                    system_prompt=system_prompt,
                    max_new_tokens=512,
                    temperature=0.6,
                )
                print(
                    f"    [{agent.name}] {response.token_count} tokens"
                )
                current_context = response.text
                if agent_idx == len(self.agents) - 1:
                    final_answer = response.text

            previous_round_output = final_answer

        return final_answer

    def solve(self, question: str) -> str:
        """
        Solve a question using RecursiveMAS.

        Selects latent-space mode if available, otherwise text-based mode.

        Args:
            question: The question to answer.

        Returns:
            The final answer string.
        """
        if self.use_latent_mode:
            print(
                "Note: Full latent-space inference is available via "
                "RecursiveMASSystem.forward(). The orchestrator uses "
                "text-based mode for broad compatibility."
            )
        return self.solve_text_mode(question)
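It can help to trace exactly which context each agent receives in `solve_text_mode`. The sketch below is a standalone re-creation of the loop's data flow that you can run without Ollama or an API key. It does not import `recursive_mas.py`. `make_stub_agent` and `trace_text_recursion` are hypothetical names, and the stubs stand in for `generate_text` by wrapping whatever context they were handed in their own name.

```python
from typing import Callable, List, Optional

# Hypothetical stand-in for BaseAgent.generate_text: (question, prior_context) -> text.
StubAgent = Callable[[str, Optional[str]], str]


def make_stub_agent(name: str) -> StubAgent:
    """Return a fake agent that wraps the context it was given in its own name."""
    def agent(question: str, prior_context: Optional[str]) -> str:
        # The question is ignored; only the context flow matters for this trace.
        seen = prior_context if prior_context is not None else "<none>"
        return f"{name}({seen})"
    return agent


def trace_text_recursion(
    agents: List[StubAgent], question: str, num_rounds: int
) -> str:
    """Mirror the control flow of solve_text_mode without any model calls."""
    previous_round_output: Optional[str] = None
    final_answer = ""
    for _ in range(num_rounds):
        # The first agent of each round sees the last agent's output from the
        # previous round (nothing in round 1).
        context = previous_round_output
        for agent in agents:
            # Each agent sees only the immediately preceding agent's text.
            context = agent(question, context)
        final_answer = context or ""
        previous_round_output = final_answer
    return final_answer


agents = [make_stub_agent(n) for n in ("Planner", "Critic", "Solver")]
print(trace_text_recursion(agents, "q", num_rounds=1))
# Solver(Critic(Planner(<none>)))
print(trace_text_recursion(agents, "q", num_rounds=2))
# Solver(Critic(Planner(Solver(Critic(Planner(<none>))))))
```

The nesting makes the text-mode bottleneck visible. Anything the Planner wrote in round 1 reaches the Solver in round 2 only if the Critic and Solver carried it forward in their own text, because that text is the only channel between agents.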

CHAPTER SEVEN: THE DEMO SCRIPT

Save the following as demo.py in the same directory as recursive_mas.py:

"""
demo.py
=======
Demonstration of RecursiveMAS sequential collaboration.

Run with:
    python demo.py --mode ollama
    python demo.py --mode openai
    python demo.py --mode ollama --question "Your question here"

Prerequisites:
    Ollama mode:
        ollama serve
        ollama pull qwen2.5:1.5b
        ollama pull llama3.2:1b

    OpenAI mode:
        export OPENAI_API_KEY="sk-..."
"""

import os
import argparse
from typing import List

# Import everything from the combined module
from recursive_mas import (
    BaseAgent,
    OllamaAgent,
    RemoteAPIAgent,
    SequentialRecursiveMASOrchestrator,
)


def create_ollama_agents() -> List[BaseAgent]:
    """
    Create a three-agent sequential system using local Ollama models.

    Uses small, fast models that run on a laptop (CPU is fine):
      - Planner:  qwen2.5:1.5b  (~1 GB)
      - Critic:   llama3.2:1b   (~700 MB)
      - Solver:   qwen2.5:1.5b  (~1 GB, same model, different role)
    """
    planner = OllamaAgent(
        model_id="qwen2.5:1.5b",
        name="Planner",
        role=(
            "expert problem decomposer who breaks complex questions into "
            "clear, structured step-by-step plans"
        ),
    )
    critic = OllamaAgent(
        model_id="llama3.2:1b",
        name="Critic",
        role=(
            "rigorous evaluator who identifies flaws, gaps, and improvements "
            "in proposed plans and solutions"
        ),
    )
    solver = OllamaAgent(
        model_id="qwen2.5:1.5b",
        name="Solver",
        role=(
            "precise problem solver who produces final, well-reasoned answers "
            "based on the plan and critique provided"
        ),
    )
    return [planner, critic, solver]


def create_openai_agents(api_key: str) -> List[BaseAgent]:
    """
    Create a three-agent sequential system using the OpenAI API.

    Uses gpt-4o-mini for all three roles for cost efficiency.
    """
    base_url = "https://api.openai.com/v1"
    model = "gpt-4o-mini"

    planner = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Planner",
        role=(
            "expert problem decomposer who breaks complex questions into "
            "clear, structured step-by-step plans"
        ),
    )
    critic = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Critic",
        role=(
            "rigorous evaluator who identifies flaws, gaps, and improvements "
            "in proposed plans and solutions"
        ),
    )
    solver = RemoteAPIAgent(
        api_base_url=base_url,
        model_id=model,
        api_key=api_key,
        name="Solver",
        role=(
            "precise problem solver who produces final, well-reasoned answers "
            "based on the plan and critique provided"
        ),
    )
    return [planner, critic, solver]


def run_demo(agents: List[BaseAgent], question: str, num_rounds: int = 3) -> None:
    """Run a RecursiveMAS demo with the given agents and question."""
    print("\n" + "=" * 60)
    print("RecursiveMAS Sequential Collaboration Demo")
    print("=" * 60)
    print(f"Question : {question}")
    print(f"Agents   : {[a.name for a in agents]}")
    print(f"Rounds   : {num_rounds}")
    print("-" * 60)

    # Use the same round count for the banner and the orchestrator.
    orchestrator = SequentialRecursiveMASOrchestrator(
        agents=agents,
        num_recursion_rounds=num_rounds,
    )

    answer = orchestrator.solve(question)

    print("\n" + "=" * 60)
    print("FINAL ANSWER:")
    print("-" * 60)
    print(answer)
    print("=" * 60)


def main() -> None:
    parser = argparse.ArgumentParser(
        description="RecursiveMAS Sequential Collaboration Demo",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        "--mode",
        choices=["ollama", "openai"],
        default="ollama",
        help="Backend to use for agents (default: ollama).",
    )
    parser.add_argument(
        "--question",
        type=str,
        default=(
            "A train travels from City A to City B at 60 km/h and returns "
            "at 40 km/h. What is the average speed for the entire journey?"
        ),
        help="The question to answer.",
    )
    args = parser.parse_args()

    if args.mode == "ollama":
        print("Using local Ollama agents.")
        print("Ensure Ollama is running ('ollama serve') and models are pulled:")
        print("  ollama pull qwen2.5:1.5b")
        print("  ollama pull llama3.2:1b")
        agents = create_ollama_agents()

    elif args.mode == "openai":
        api_key = os.environ.get("OPENAI_API_KEY", "")
        if not api_key:
            raise ValueError(
                "Set the OPENAI_API_KEY environment variable to use OpenAI. "
                "Example: export OPENAI_API_KEY='sk-...'"
            )
        print("Using OpenAI API agents (gpt-4o-mini).")
        agents = create_openai_agents(api_key)

    else:
        raise ValueError(f"Unknown mode: {args.mode}")

    run_demo(agents=agents, question=args.question)


if __name__ == "__main__":
    main()
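The default demo question has a closed-form answer, so you can use it to check whatever the agents produce. Over equal distances, the average speed is the harmonic mean of the two speeds, which gives 48 km/h. It is not the arithmetic mean of 50 km/h that a careless solver might report. The snippet below computes the expected value. `mentions_expected_speed` is a hypothetical, deliberately crude string check and is not part of `demo.py`.

```python
import re


def round_trip_average_speed(speed_out: float, speed_back: float) -> float:
    """Average speed over equal distances: total distance / total time.

    With distance d each way, total time is d/speed_out + d/speed_back, so the
    average is 2d / (d/speed_out + d/speed_back); d cancels out.
    """
    return 2 * speed_out * speed_back / (speed_out + speed_back)


def mentions_expected_speed(answer: str, expected: float) -> bool:
    """Hypothetical helper: does the answer text contain the expected number?

    Assumes the expected value is a whole number (accepts "48" or "48.0").
    """
    pattern = rf"\b{int(expected)}(\.0+)?\b"
    return re.search(pattern, answer) is not None


expected = round_trip_average_speed(60, 40)
print(expected)  # 48.0
print(mentions_expected_speed("The average speed is 48 km/h.", expected))  # True
print(mentions_expected_speed("The average speed is 50 km/h.", expected))  # False
```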

CHAPTER EIGHT: PERFORMANCE RESULTS AND WHAT THEY MEAN

Let us take a step back and look at what the paper actually reports in terms of numbers, because the results are genuinely impressive and worth understanding in detail.

The paper evaluates RecursiveMAS across nine benchmarks. On MATH500, a collection of 500 competition-style math problems, RecursiveMAS achieves 88.0% accuracy. To put that in context, a single agent with LoRA fine-tuning achieves 83.1%, and the strongest text-based recursive baseline (Recursive-TextMAS) achieves 85.8%. RecursiveMAS beats the text-based recursive baseline by 2.2 percentage points while being faster and using fewer tokens.

On GPQA-Diamond, a graduate-level science and engineering benchmark that is genuinely difficult even for expert humans, RecursiveMAS achieves 66.2% compared to 62.5% for TextGrad and 61.6% for Recursive-TextMAS. On LiveCodeBench, a competitive programming benchmark, RecursiveMAS achieves 42.9% compared to 39.8% for TextGrad and 38.7% for Recursive-TextMAS.

On AIME 2025 and AIME 2026, problem sets from the American Invitational Mathematics Examination (a notoriously difficult high-school competition), RecursiveMAS achieves 86.7% on both. On AIME 2025, the best single-agent baseline and Recursive-TextMAS each reach 73.3%, so RecursiveMAS improves on the text-based recursive baseline by 13.4 percentage points. An AIME year has only 30 problems, so each problem is worth about 3.3 points and the gap corresponds to four additional problems solved. That is a large margin on one of the hardest benchmarks in the evaluation.

The efficiency gains are equally striking. As the number of recursion rounds increases from 1 to 3, RecursiveMAS becomes progressively faster relative to the text-based baseline, because each additional round of latent collaboration adds very little overhead (just the RecursiveLink transformations) compared to the text-based approach which must decode and re-encode all intermediate outputs. By round 3, RecursiveMAS is generating 34.6% to 75.6% fewer tokens than the text-based baseline while achieving higher accuracy.

The following table summarizes the key comparison from the paper:

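| Method | MATH500 | AIME 2025 | GPQA-Diamond | LiveCodeBench | MedQA |
|---|---|---|---|---|---|
| Single Agent (LoRA) | 83.1% | 70.0% | 62.0% | 37.4% | 76.1% |
| Single Agent (Full-SFT) | 83.2% | 73.3% | 62.8% | 38.6% | 77.0% |
| Mixture-of-Agents (MoA) | 79.8% | 60.0% | 47.6% | 27.0% | 57.5% |
| TextGrad | 84.9% | 73.3% | 62.5% | 39.8% | 77.2% |
| LoopLM | 84.6% | 66.7% | 48.1% | 24.9% | 56.4% |
| Recursive-TextMAS | 85.8% | 73.3% | 61.6% | 38.7% | 77.0% |
| RecursiveMAS | 88.0% | 86.7% | 66.2% | 42.9% | 79.3% |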

RecursiveMAS comes out ahead on every benchmark in this table, often by a substantial margin. The improvement over Recursive-TextMAS is particularly telling. Recursive-TextMAS has the same structure but passes text instead of latent states between agents, so the comparison isolates the contribution of the latent-space communication itself.
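You can recompute the per-benchmark margins behind these claims directly from the table. The snippet below transcribes the table exactly as given, with no additional data. It then computes RecursiveMAS's gap over Recursive-TextMAS and over the strongest other method in each column. The function names are illustrative.

```python
# Accuracy (%) transcribed from the comparison table above.
BENCHMARKS = ["MATH500", "AIME 2025", "GPQA-Diamond", "LiveCodeBench", "MedQA"]
SCORES = {
    "Single Agent (LoRA)":     [83.1, 70.0, 62.0, 37.4, 76.1],
    "Single Agent (Full-SFT)": [83.2, 73.3, 62.8, 38.6, 77.0],
    "Mixture-of-Agents (MoA)": [79.8, 60.0, 47.6, 27.0, 57.5],
    "TextGrad":                [84.9, 73.3, 62.5, 39.8, 77.2],
    "LoopLM":                  [84.6, 66.7, 48.1, 24.9, 56.4],
    "Recursive-TextMAS":       [85.8, 73.3, 61.6, 38.7, 77.0],
    "RecursiveMAS":            [88.0, 86.7, 66.2, 42.9, 79.3],
}


def margins(target: str, reference: str) -> dict:
    """Percentage-point gap of `target` over `reference`, per benchmark."""
    return {
        b: round(t - r, 1)
        for b, t, r in zip(BENCHMARKS, SCORES[target], SCORES[reference])
    }


def margin_over_best_baseline(target: str) -> dict:
    """Gap of `target` over the strongest *other* method on each benchmark."""
    result = {}
    for i, b in enumerate(BENCHMARKS):
        best_other = max(s[i] for name, s in SCORES.items() if name != target)
        result[b] = round(SCORES[target][i] - best_other, 1)
    return result


print(margins("RecursiveMAS", "Recursive-TextMAS"))
# {'MATH500': 2.2, 'AIME 2025': 13.4, 'GPQA-Diamond': 4.6, 'LiveCodeBench': 4.2, 'MedQA': 2.3}
print(margin_over_best_baseline("RecursiveMAS"))
# {'MATH500': 2.2, 'AIME 2025': 13.4, 'GPQA-Diamond': 3.4, 'LiveCodeBench': 3.1, 'MedQA': 2.1}
```

Every margin is positive. Outside AIME 2025, however, the lead over the strongest competing method in each column is 2.1 to 3.4 points, which is smaller than the gap to Recursive-TextMAS alone.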

The RecursiveLink design ablation is also worth examining. The paper compares four architectural variants:

| RecursiveLink design | MATH500 | GPQA-Diamond | LiveCodeBench |
|---|---|---|---|
| 1-Layer (no residual) | 84.4% | 63.2% | 40.1% |
| 1-Layer + Residual | 86.7% | 65.3% | 41.4% |
| 2-Layer (no residual) | 85.6% | 64.5% | 40.5% |
| 2-Layer + Residual (RecursiveMAS) | 88.0% | 66.2% | 42.9% |

The residual connection is clearly important. Adding it to a 1-layer design raises MATH500 from 84.4% to 86.7% (+2.3 points), a bigger boost than adding a second layer without the residual, which raises it from 84.4% to 85.6% (+1.2 points). The full 2-layer residual design is best on all three benchmarks. This is consistent with the design intuition: with a residual connection, the module only has to learn the distributional shift rather than the full mapping, and that shift is easier to learn and more stable to train.
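The comparison above uses MATH500 only. Recomputing the same differences for all three benchmarks in the ablation table checks whether the pattern holds more broadly. This is plain arithmetic on the reported numbers, and the variable names are illustrative.

```python
# (depth, has_residual) -> [MATH500, GPQA-Diamond, LiveCodeBench], from the ablation table.
ABLATION = {
    ("1-layer", False): [84.4, 63.2, 40.1],
    ("1-layer", True):  [86.7, 65.3, 41.4],
    ("2-layer", False): [85.6, 64.5, 40.5],
    ("2-layer", True):  [88.0, 66.2, 42.9],
}


def gain(variant, baseline):
    """Per-benchmark improvement (percentage points) of `variant` over `baseline`."""
    return [round(x - y, 1) for x, y in zip(ABLATION[variant], ABLATION[baseline])]


residual_gain_1layer = gain(("1-layer", True), ("1-layer", False))
depth_gain_no_residual = gain(("2-layer", False), ("1-layer", False))
residual_gain_2layer = gain(("2-layer", True), ("2-layer", False))

print(residual_gain_1layer)    # [2.3, 2.1, 1.3]
print(depth_gain_no_residual)  # [1.2, 1.3, 0.4]
print(residual_gain_2layer)    # [2.4, 1.7, 2.4]
```

On every benchmark, adding the residual helps more than adding a layer does. The residual also still adds 1.7 to 2.4 points on top of a 2-layer design.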

The latent thoughts length ablation is also instructive:

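| Latent steps (m) | MATH500 | GPQA-Diamond | LiveCodeBench |
|---|---|---|---|
| 0 (no latent) | 83.3% | 61.4% | 38.1% |
| 16 | 84.9% | 62.0% | 40.3% |
| 32 | 85.2% | 62.8% | 40.7% |
| 48 | 85.6% | 63.6% | 41.4% |
| 64 | 86.8% | 64.1% | 42.0% |
| 80 | 86.8% | 64.2% | 42.5% |
| 96 | 86.5% | 64.5% | 42.2% |
| 112 | 86.9% | 64.3% | 42.6% |
| 128 | 86.7% | 64.4% | 42.6% |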

Performance improves steadily from m=0 up to roughly m=64–80 and then plateaus. Beyond m=80, no score moves by more than 0.3 points in either direction. In practice you do not need hundreds of latent steps; about 80 is enough. Even m=16 gives a meaningful improvement over no latent thoughts at all (+1.6 points on MATH500 and +2.2 on LiveCodeBench), which is encouraging for resource-constrained deployments.
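A simple rule of thumb makes the plateau concrete: for each benchmark, find the smallest m whose score is within a chosen tolerance of the best score in the table. The 0.5-point default tolerance is an assumption of this sketch, not a criterion from the paper, and `smallest_sufficient_m` is a hypothetical helper.

```python
# Latent-step ablation from the table above: m -> (MATH500, GPQA-Diamond, LiveCodeBench).
LATENT_STEPS = {
    0:   (83.3, 61.4, 38.1),
    16:  (84.9, 62.0, 40.3),
    32:  (85.2, 62.8, 40.7),
    48:  (85.6, 63.6, 41.4),
    64:  (86.8, 64.1, 42.0),
    80:  (86.8, 64.2, 42.5),
    96:  (86.5, 64.5, 42.2),
    112: (86.9, 64.3, 42.6),
    128: (86.7, 64.4, 42.6),
}
BENCHMARKS = ("MATH500", "GPQA-Diamond", "LiveCodeBench")


def smallest_sufficient_m(tolerance: float = 0.5) -> dict:
    """Per benchmark, the smallest m scoring within `tolerance` points of the best m."""
    result = {}
    for i, name in enumerate(BENCHMARKS):
        best = max(scores[i] for scores in LATENT_STEPS.values())
        result[name] = min(
            m for m, scores in LATENT_STEPS.items() if scores[i] >= best - tolerance
        )
    return result


print(smallest_sufficient_m())
# {'MATH500': 64, 'GPQA-Diamond': 64, 'LiveCodeBench': 80}
```

Under this rule, MATH500 and GPQA-Diamond are already saturated at m=64 and LiveCodeBench at m=80, which matches the ~80 figure in the appendix. A looser tolerance trades a little accuracy for fewer latent steps.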


CHAPTER NINE: THE BIGGER PICTURE AND WHAT COMES NEXT

RecursiveMAS represents a genuinely new way of thinking about multi-agent AI systems. Instead of treating agents as black boxes that communicate through text, it treats the entire multi-agent system as a single unified computation that happens to be distributed across multiple models. The RecursiveLink is the key enabler: a tiny, trainable module that knows how to translate between the hidden spaces of different models.

This has several implications that are worth thinking about carefully.

The first implication is about the nature of agent communication. When we build text-based multi-agent systems, we implicitly assume that text is the right medium for agents to share information. RecursiveMAS challenges this assumption. Text is great for communicating with humans, but between AI agents, continuous vectors are richer, faster, and more gradient-friendly. The paper's theoretical results make this precise: text-based communication introduces gradient vanishing and computational overhead that latent-space communication avoids.

The second implication is about the scalability of multi-agent systems. The paper shows a clear scaling law: more recursion rounds means better performance, and this improvement is consistent across all benchmarks and all collaboration patterns. This is exciting because it means you can trade compute for accuracy in a predictable way, just as you can with larger models or longer context windows. The "recursion depth" becomes a new axis of scaling.

The third implication is about the cost of intelligence. RecursiveMAS achieves its best results with only 13.12 million trainable parameters, a tiny fraction (0.31%) of the total parameter count of the agents it connects. This suggests that much of a multi-agent system's capability depends not only on the individual agents but also on how they communicate. Training better communication channels (the RecursiveLinks) can therefore be a far cheaper lever than training better agents.
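The two reported numbers, 13.12M trainable parameters and 0.31% of the total, let you estimate how large the connected agents are in aggregate. The snippet below is back-of-envelope arithmetic on those figures, not a number stated in the paper. Because 0.31% is itself rounded, the snippet also brackets the estimate.

```python
trainable = 13.12e6      # RecursiveLink parameters reported in the paper
fraction = 0.31 / 100    # reported share of total parameters (rounded to 0.01%)

implied_total = trainable / fraction
# A share that rounds to 0.31% lies between 0.305% and 0.315%.
low, high = trainable / 0.00315, trainable / 0.00305

print(f"{implied_total / 1e9:.2f}B total parameters "
      f"(range {low / 1e9:.2f}B-{high / 1e9:.2f}B)")
# 4.23B total parameters (range 4.17B-4.30B)
```

In other words, training touches roughly one parameter in 320, against about 4.2 billion frozen agent parameters.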

The fourth implication is about heterogeneity. RecursiveMAS works with agents of different sizes and from different model families. The Outer RecursiveLink handles the dimensional mismatch between, say, a 1.7B parameter Qwen model and a 7B parameter Gemma model. You can therefore build systems that mix and match models based on their strengths, without requiring them to share a hidden size or architecture.

For developers building agentic systems today, the most immediately applicable lessons from this paper are these. First, consider whether your agents really need to communicate through text, or whether a more direct form of information transfer would be better. Second, think about your multi-agent system as a single unified entity that should be optimized as a whole, not as a collection of independent agents that happen to talk to each other. Third, recognize that recursive refinement — having agents iterate on their answers across multiple rounds — is a powerful and underutilized technique. And fourth, remember that the connections between agents (the "links") may be just as important as the agents themselves.

The code in this tutorial gives you a foundation to start experimenting with these ideas. The text-based orchestrator works today with any combination of local and remote models. The RecursiveLink modules and the latent-space generation code give you the building blocks for the full system when you have access to local models. And the agent abstraction makes it easy to swap in different models as you experiment.

The field of agentic AI is moving fast, and RecursiveMAS points toward a future where the boundaries between individual models and multi-agent systems become increasingly blurred. The "system" is the model, and the "model" is the system. That is a genuinely exciting direction, and developers with a solid understanding of the fundamentals, which you now have, are well positioned to explore it.

Happy building.


APPENDIX: QUICK REFERENCE

RecursiveMAS Key Numbers (from the paper)

| Metric | Value |
|---|---|
| Average accuracy improvement over baselines | 8.3% |
| Inference speedup range | 1.2x – 2.4x |
| Token usage reduction range | 34.6% – 75.6% |
| Optimal latent thought steps (m) | ~80 |
| Trainable parameters (RecursiveLinks only) | 13.12M (0.31% of total) |
| GPU memory for training | 15.29 GB |
| Benchmarks evaluated | 9 |
| Collaboration patterns supported | 4 |
| Model families tested | 4 (Qwen, LLaMA, Gemma, Mistral) |

Formulas

$$R_{\text{inner}}(h) = h + W_2 \cdot \sigma(W_1 \cdot h)$$

$$R_{\text{outer}}(h) = W_3 \cdot h + W_2 \cdot \sigma(W_1 \cdot h)$$

$$L_{\text{inner}} = 1 - \cos\!\left(R_{\text{inner}}(H_0),\; \overline{\text{Emb}(y)}\right)$$

$$L_{\text{outer}} = \text{CrossEntropy}\!\left(S^{(n)}\!\left(\cdots S^{(1)}(x)\cdots\right),\; y\right)$$
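The two link formulas and the inner alignment loss are small enough to write out directly. The sketch below implements them in plain Python, using lists instead of tensors, so the shapes stay explicit. It makes two assumptions. First, σ is GELU; the formulas do not name the nonlinearity. Second, it treats $R_{\text{inner}}(H_0)$ and the averaged target embedding simply as two vectors of equal size. Names such as `r_inner`, `r_outer`, and `inner_alignment_loss` are illustrative. For real models, use the RecursiveLink modules from the tutorial code.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major, shape (out_dim, in_dim)


def matvec(W: Matrix, h: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def gelu(x: float) -> float:
    # Assumed choice for sigma; the formulas do not name the nonlinearity.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def mlp_branch(h: Vector, W1: Matrix, W2: Matrix) -> Vector:
    # W2 * sigma(W1 * h): the term shared by both links.
    return matvec(W2, [gelu(z) for z in matvec(W1, h)])


def r_inner(h: Vector, W1: Matrix, W2: Matrix) -> Vector:
    # R_inner(h) = h + W2 sigma(W1 h): same dimension in and out (within one model).
    return [a + b for a, b in zip(h, mlp_branch(h, W1, W2))]


def r_outer(h: Vector, W1: Matrix, W2: Matrix, W3: Matrix) -> Vector:
    # R_outer(h) = W3 h + W2 sigma(W1 h): W3 projects d_in -> d_out between models.
    return [a + b for a, b in zip(matvec(W3, h), mlp_branch(h, W1, W2))]


def inner_alignment_loss(pred: Vector, target: Vector) -> float:
    # L_inner = 1 - cos(pred, target)
    dot = sum(p * t for p, t in zip(pred, target))
    norm = math.sqrt(sum(p * p for p in pred)) * math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / norm


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d_hidden, d_out = 4, 8, 6  # toy sizes standing in for two models' hidden sizes
h = [rng.gauss(0.0, 1.0) for _ in range(d_in)]

# With W2 = 0, the residual form reduces R_inner to the identity map.
zero_W2 = [[0.0] * d_hidden for _ in range(d_in)]
print(r_inner(h, rand_matrix(d_hidden, d_in, rng), zero_W2) == h)  # True

out = r_outer(h, rand_matrix(d_hidden, d_in, rng), rand_matrix(d_out, d_hidden, rng),
              rand_matrix(d_out, d_in, rng))
print(len(out))  # 6
print(math.isclose(inner_alignment_loss(h, h), 0.0, abs_tol=1e-9))  # True
```

The only structural difference between the two links is the skip path. $R_{\text{inner}}$ can add $h$ directly because input and output live in the same hidden space. $R_{\text{outer}}$ needs the learned projection $W_3$ because it connects models with different hidden sizes.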

Recommended Hyperparameters (from the paper)

| Hyperparameter | Value |
|---|---|
| Optimizer | AdamW |
| Learning rate | 1e-4 with cosine schedule |
| Batch size | 4 |
| Temperature (reasoning tasks) | 0.6 |
| Temperature (code generation) | 0.2 |
| Top-p | 0.95 |
| Training recursion rounds | 3 |
| Inference recursion rounds | 3 |
| Gradient clip norm | 1.0 |
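The sketch below captures these settings as a single config object and spells out what two of them mean: the cosine learning-rate schedule and gradient clipping by global norm. The dataclass and helper names are hypothetical. The schedule assumes a plain cosine decay to zero with no warmup, because the table does not specify warmup or a minimum learning rate.

```python
import math
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class RecursiveMASHyperparameters:
    # Values transcribed from the table above; the class itself is hypothetical.
    optimizer: str = "AdamW"
    learning_rate: float = 1e-4
    batch_size: int = 4
    temperature_reasoning: float = 0.6
    temperature_code: float = 0.2
    top_p: float = 0.95
    train_recursion_rounds: int = 3
    inference_recursion_rounds: int = 3
    grad_clip_norm: float = 1.0


def cosine_lr(step: int, total_steps: int, base_lr: float, min_lr: float = 0.0) -> float:
    # Cosine decay from base_lr to min_lr over total_steps; no warmup (assumed).
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    # Rescale the whole gradient vector if its L2 norm exceeds max_norm.
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


hp = RecursiveMASHyperparameters()
print(cosine_lr(0, 1000, hp.learning_rate))     # 0.0001
print(cosine_lr(1000, 1000, hp.learning_rate))  # 0.0
print(clip_by_global_norm([3.0, 4.0], hp.grad_clip_norm))  # roughly [0.6, 0.8]
```

In a PyTorch training loop, the same settings map onto `torch.optim.AdamW`, `torch.optim.lr_scheduler.CosineAnnealingLR`, and `torch.nn.utils.clip_grad_norm_`. The temperatures and top-p apply only at generation time; the text-mode orchestrator already calls `generate_text` with temperature 0.6, matching the reasoning-task setting.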