Retrieval-Augmented Generation (RAG)
Retrieval-Augmented Generation is one of the most significant advances in applying LLMs to large codebases. RAG combines the generative capabilities of LLMs with retrieval systems that can fetch relevant code snippets or documentation on demand, allowing the model to "look up" information rather than relying solely on its parametric knowledge.
In a RAG system for code analysis, the codebase is first indexed into a vector database, where code snippets are embedded as high-dimensional vectors that capture their semantic meaning. When a developer asks a question about the codebase, the system retrieves the most relevant code snippets and provides them as context to the LLM, which then generates an informed response.
Here's an example of implementing a simple RAG system for code analysis:
import os
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
# Step 1: Load and process the codebase
def load_codebase(directory):
documents = []
for root, _, files in os.walk(directory):
for file in files:
if file.endswith('.py'): # Filter for Python files
file_path = os.path.join(root, file)
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
documents.append({"content": content, "path": file_path})
return documents
# Step 2: Split the documents into chunks
def split_documents(documents):
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200,
separators=["\nclass ", "\ndef ", "\n\n", "\n", " ", ""]
)
chunks = []
for doc in documents:
texts = text_splitter.split_text(doc["content"])
for text in texts:
chunks.append({
"content": text,
"path": doc["path"]
})
return chunks
# Step 3: Create vector embeddings and index
def create_vector_index(chunks):
embeddings = OpenAIEmbeddings()
texts = [chunk["content"] for chunk in chunks]
metadatas = [{"path": chunk["path"]} for chunk in chunks]
vector_store = FAISS.from_texts(texts, embeddings, metadatas=metadatas)
return vector_store
# Step 4: Set up the RAG system
def setup_rag_system(vector_store):
llm = OpenAI(temperature=0)
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever()
)
return qa_chain
# Example usage
codebase_dir = "./my_project"
documents = load_codebase(codebase_dir)
chunks = split_documents(documents)
vector_store = create_vector_index(chunks)
qa_system = setup_rag_system(vector_store)
# Query the system
response = qa_system.run("How is error handling implemented in the authentication module?")
print(response)
This code demonstrates a basic RAG implementation for code analysis. It processes a codebase by loading Python files, splitting them into manageable chunks, creating vector embeddings using OpenAI's embedding model, and setting up a question-answering system. The system retrieves relevant code chunks when queried and uses them as context for generating responses.
RAG systems significantly improve the accuracy of LLM responses about large codebases by grounding the model's responses in actual code rather than relying on potentially outdated or incorrect information from its training data.
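Because each chunk is indexed with its file path as metadata, the same chain can also report which files grounded an answer, which helps developers verify the response. A small variation on the setup above (same-era LangChain API; the retrieval depth k=5 is an arbitrary choice):
qa_chain = RetrievalQA.from_chain_type(
    llm=OpenAI(temperature=0),
    chain_type="stuff",
    retriever=vector_store.as_retriever(search_kwargs={"k": 5}),
    return_source_documents=True  # Include the retrieved chunks in the result
)
result = qa_chain({"query": "How is error handling implemented in the authentication module?"})
print(result["result"])
for doc in result["source_documents"]:
    print("Source:", doc.metadata["path"])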
GraphRAG: Leveraging Code Structure
GraphRAG extends the RAG concept by incorporating structural information about the codebase in the form of a graph. Traditional RAG treats code snippets as independent chunks, potentially missing important relationships between components. GraphRAG addresses this by modeling the codebase as a graph where nodes represent code entities (functions, classes, modules) and edges represent relationships (calls, imports, inheritance).
When retrieving context for an LLM, GraphRAG can traverse the graph to find not just textually similar code but functionally related components. For example, if a developer asks about a function, GraphRAG can provide context about its callers, callees, and related classes.
Here's a simplified example of implementing a GraphRAG system:
import networkx as nx
import numpy as np
from langchain.embeddings import OpenAIEmbeddings
from langchain.llms import OpenAI
class CodebaseGraph:
def __init__(self, codebase_dir):
self.graph = nx.DiGraph()
self.embeddings = OpenAIEmbeddings()
self.llm = OpenAI(temperature=0)
self.build_graph(codebase_dir)
def build_graph(self, codebase_dir):
# Parse the codebase and extract entities and relationships
# This is a simplified placeholder - actual implementation would use
# static analysis tools like AST or more advanced parsers
# Example: Add nodes for functions and classes
self.graph.add_node("AuthManager", type="class",
file="auth/manager.py",
embedding=self.get_embedding("class AuthManager..."))
self.graph.add_node("validate_token", type="function",
file="auth/validation.py",
embedding=self.get_embedding("def validate_token(token)..."))
# Add edges for relationships
self.graph.add_edge("AuthManager", "validate_token", type="calls")
def get_embedding(self, text):
return self.embeddings.embed_query(text)
def query(self, question):
# 1. Convert question to embedding
question_embedding = self.get_embedding(question)
# 2. Find relevant nodes using embedding similarity
relevant_nodes = self.find_relevant_nodes(question_embedding)
# 3. Expand to related nodes using graph traversal
expanded_nodes = self.expand_context(relevant_nodes)
# 4. Construct context from selected nodes
context = self.build_context(expanded_nodes)
# 5. Query LLM with the constructed context
prompt = f"Based on the following code context:\n\n{context}\n\nQuestion: {question}\n\nAnswer:"
return self.llm(prompt)
    def find_relevant_nodes(self, query_embedding, top_k=3):
        # Rank nodes by cosine similarity between the query embedding and each node's embedding
        def cosine(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
        scored = [(n, cosine(query_embedding, d["embedding"]))
                  for n, d in self.graph.nodes(data=True) if "embedding" in d]
        scored.sort(key=lambda pair: pair[1], reverse=True)
        return [n for n, _ in scored[:top_k]]
def expand_context(self, nodes, max_distance=2):
# Expand from initial nodes using graph traversal
expanded = set(nodes)
frontier = set(nodes)
for _ in range(max_distance):
new_frontier = set()
for node in frontier:
neighbors = list(self.graph.successors(node)) + list(self.graph.predecessors(node))
new_frontier.update(neighbors)
expanded.update(new_frontier)
frontier = new_frontier
return expanded
def build_context(self, nodes):
# Construct a coherent context from the selected nodes
context = []
for node in nodes:
node_data = self.graph.nodes[node]
context.append(f"File: {node_data['file']}\nType: {node_data['type']}\nName: {node}\n")
return "\n".join(context)
# Example usage
graph_rag = CodebaseGraph("./my_project")
response = graph_rag.query("How does token validation work in the authentication system?")
print(response)
This example demonstrates the core concepts of GraphRAG. It builds a graph representation of the codebase, with nodes for code entities and edges for their relationships. When queried, it finds relevant nodes based on embedding similarity, expands to related nodes through graph traversal, and constructs a context that preserves the structural relationships in the code.
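The build_graph method above uses hard-coded placeholder nodes. A minimal sketch of how it could be populated from real source files using Python's ast module, assuming simple name-based call resolution (build_graph_from_source is an illustrative helper, not part of any library):
import ast
import networkx as nx

def build_graph_from_source(graph, file_path, source):
    # First pass: one node per function or class definition
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
            kind = "function" if isinstance(node, ast.FunctionDef) else "class"
            graph.add_node(node.name, type=kind, file=file_path)
    # Second pass: an edge for every direct call to a known definition
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            for call in ast.walk(node):
                if (isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
                        and call.func.id in graph):
                    graph.add_edge(node.name, call.func.id, type="calls")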
GraphRAG is particularly valuable for understanding complex codebases where the relationships between components are as important as the components themselves. It helps developers answer questions that require understanding code flow, inheritance hierarchies, and module dependencies.
Code Summarization Techniques
Summarization techniques help developers understand large codebases by condensing verbose code into concise descriptions. LLMs can generate different types of summaries, from high-level overviews of entire modules to detailed explanations of specific functions.
Hierarchical summarization is particularly effective for large codebases. It works by summarizing individual functions first, then using those summaries to create higher-level summaries of classes, modules, and entire systems. This preserves important details while providing a manageable overview.
Here's an example of implementing hierarchical code summarization:
from langchain.llms import OpenAI
import ast
import os
class CodeSummarizer:
def __init__(self):
self.llm = OpenAI(temperature=0.2)
def summarize_function(self, func_code):
prompt = f"""
Summarize the following function in 1-2 sentences, focusing on its purpose and key functionality:
```
{func_code}
```
Summary:
"""
return self.llm(prompt)
def summarize_class(self, class_code, method_summaries):
methods_context = "\n".join([f"- {name}: {summary}" for name, summary in method_summaries.items()])
prompt = f"""
Summarize the following class based on its structure and method summaries:
Class code:
```
{class_code}
```
Method summaries:
{methods_context}
Class summary:
"""
return self.llm(prompt)
def summarize_module(self, file_path):
with open(file_path, 'r') as file:
content = file.read()
try:
module = ast.parse(content)
# Extract and summarize functions
function_summaries = {}
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef):
func_code = ast.get_source_segment(content, node)
function_summaries[node.name] = self.summarize_function(func_code)
            # Extract and summarize classes
            class_summaries = {}
            method_names = set()  # Track method names so they are not double-counted as module-level functions
            for node in ast.walk(module):
                if isinstance(node, ast.ClassDef):
                    class_code = ast.get_source_segment(content, node)
                    # Get method summaries for this class
                    method_summaries = {}
                    for method in [n for n in ast.walk(node) if isinstance(n, ast.FunctionDef)]:
                        method_code = ast.get_source_segment(content, method)
                        method_summaries[method.name] = self.summarize_function(method_code)
                        method_names.add(method.name)
                    class_summaries[node.name] = self.summarize_class(class_code, method_summaries)
            # Create module summary
            module_name = os.path.basename(file_path)
            functions_context = "\n".join([f"- {name}: {summary}" for name, summary in function_summaries.items()
                                           if name not in method_names])
classes_context = "\n".join([f"- {name}: {summary}" for name, summary in class_summaries.items()])
prompt = f"""
Create a comprehensive summary of the module {module_name} based on its components:
Functions:
{functions_context}
Classes:
{classes_context}
Module summary:
"""
return self.llm(prompt)
except SyntaxError:
return "Could not parse the file due to syntax errors."
# Example usage
summarizer = CodeSummarizer()
module_summary = summarizer.summarize_module("./my_project/auth/manager.py")
print(module_summary)
This code demonstrates a hierarchical approach to code summarization. It first parses the Python code using the AST module, then summarizes individual functions, uses those summaries to create class summaries, and finally combines everything into a module summary. This preserves the hierarchical structure of the code while making it more digestible.
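The hierarchy need not stop at the module level: module summaries can in turn be rolled up into a system-level overview, as described earlier. A short sketch building on the CodeSummarizer above (summarize_codebase is a hypothetical helper and the prompt wording is illustrative):
def summarize_codebase(summarizer, module_paths):
    # Roll module-level summaries up into a single system overview
    module_summaries = {path: summarizer.summarize_module(path) for path in module_paths}
    modules_context = "\n".join(f"- {path}: {summary}" for path, summary in module_summaries.items())
    prompt = f"""
    Create a high-level overview of the system based on these module summaries:
    {modules_context}
    System overview:
    """
    return summarizer.llm(prompt)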
Effective summarization helps developers quickly understand unfamiliar codebases, identify relevant components, and navigate complex systems. It's particularly valuable for onboarding new team members and for maintaining legacy codebases where documentation may be lacking.
Code Compression Approaches
Code compression techniques aim to reduce the size of code representations while preserving their semantic meaning. This is crucial for working with LLMs that have limited context windows. Unlike summarization, which creates natural language descriptions, compression maintains the code's functional content in a more compact form.
Several approaches to code compression exist:
- Semantic compression removes redundant or non-essential code elements like comments, formatting, and unused imports.
- Abstract syntax tree (AST) compression represents code in a more compact tree structure rather than as raw text.
- Symbolic compression replaces common patterns with shorter symbols or tokens (a sketch appears after the example below).
Here's an example of implementing semantic code compression:
import re
import ast
class CodeCompressor:
def __init__(self):
pass
def remove_comments(self, code):
# Remove single-line comments
code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)
# Remove multi-line docstrings
code = re.sub(r'"""[\s\S]*?"""', '', code)
code = re.sub(r"'''[\s\S]*?'''", '', code)
return code
def remove_blank_lines(self, code):
return "\n".join([line for line in code.split("\n") if line.strip()])
    def compress_whitespace(self, code):
        # Compress whitespace line by line so Python's significant indentation survives
        compressed_lines = []
        for line in code.split('\n'):
            stripped = line.lstrip(' ')
            indent = line[:len(line) - len(stripped)]
            # Collapse internal runs of spaces, then tighten operators and punctuation
            stripped = re.sub(r' {2,}', ' ', stripped)
            stripped = re.sub(r' ?([=+\-*/,:;()]) ?', r'\1', stripped)
            compressed_lines.append(indent + stripped)
        return '\n'.join(compressed_lines)
def semantic_compression(self, code):
# First level: remove comments and blank lines
code = self.remove_comments(code)
code = self.remove_blank_lines(code)
try:
# Parse the code to AST
tree = ast.parse(code)
# Remove docstrings from functions, classes, and modules
for node in ast.walk(tree):
if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
if ast.get_docstring(node):
                        if node.body and isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Constant) and isinstance(node.body[0].value.value, str):
node.body = node.body[1:]
            # Convert back to code (ast.unparse requires Python 3.9+)
            compressed_code = ast.unparse(tree)
# Apply whitespace compression
compressed_code = self.compress_whitespace(compressed_code)
return compressed_code
except SyntaxError:
# If parsing fails, fall back to basic compression
return self.compress_whitespace(code)
def ast_based_compression(self, code):
try:
# Parse the code to AST
tree = ast.parse(code)
# Create a simplified AST representation
simplified_tree = self.simplify_ast(tree)
# Convert the simplified AST to a compact string representation
return str(simplified_tree)
except SyntaxError:
return "Failed to parse code"
def simplify_ast(self, node):
if isinstance(node, ast.Module):
return {"type": "Module", "body": [self.simplify_ast(n) for n in node.body]}
elif isinstance(node, ast.FunctionDef):
return {
"type": "Function",
"name": node.name,
"args": [arg.arg for arg in node.args.args],
"body_size": len(node.body)
}
elif isinstance(node, ast.ClassDef):
return {
"type": "Class",
"name": node.name,
"bases": [base.id if isinstance(base, ast.Name) else "complex_base" for base in node.bases],
"methods": [self.simplify_ast(n) for n in node.body if isinstance(n, ast.FunctionDef)]
}
elif isinstance(node, ast.Assign):
return {"type": "Assignment"}
elif isinstance(node, ast.Call):
func_name = node.func.id if isinstance(node.func, ast.Name) else "complex_call"
return {"type": "Call", "func": func_name}
else:
return {"type": str(type(node).__name__)}
# Example usage
compressor = CodeCompressor()
code = """
# This is a sample class
class UserAuthentication:
\"\"\"
Handles user authentication and session management.
Supports multiple authentication methods.
\"\"\"
def __init__(self, database_connection):
# Initialize with database connection
self.db = database_connection
self.active_sessions = {} # Track active user sessions
def authenticate(self, username, password):
\"\"\"Authenticate a user with username and password\"\"\"
# Query the database
user_record = self.db.query("SELECT * FROM users WHERE username = %s", (username,))
if not user_record:
return False
# Check password hash
if self._verify_password(password, user_record['password_hash']):
# Create new session
session_id = self._generate_session_id()
self.active_sessions[session_id] = username
return session_id
return False
"""
compressed_code = compressor.semantic_compression(code)
print("Semantically compressed code:")
print(compressed_code)
ast_representation = compressor.ast_based_compression(code)
print("\nAST-based compression:")
print(ast_representation)
This example demonstrates two approaches to code compression. The semantic compression removes comments, docstrings, and unnecessary whitespace while preserving the functional code. The AST-based compression converts the code into a simplified abstract syntax tree representation, which can be much more compact than the original text.
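The third approach listed earlier, symbolic compression, can be approximated by aliasing long identifiers that repeat frequently and prepending a legend the LLM can use to decode them. A minimal sketch (the alias format and both thresholds are arbitrary choices, and the naive regex replacement does not distinguish identifiers inside string literals):
import re
from collections import Counter

def symbolic_compression(code, min_length=12, min_count=3):
    # Count long identifiers; only frequent ones are worth a legend entry
    identifiers = re.findall(r'\b[A-Za-z_][A-Za-z0-9_]*\b', code)
    counts = Counter(t for t in identifiers if len(t) >= min_length)
    legend = {}
    for i, (name, count) in enumerate(counts.most_common()):
        if count < min_count:
            break
        alias = f"_S{i}"
        legend[alias] = name
        code = re.sub(rf'\b{re.escape(name)}\b', alias, code)
    # Prepend a legend so the LLM can map aliases back to the original names
    header = "\n".join(f"# {alias} = {name}" for alias, name in legend.items())
    return (header + "\n" + code) if legend else code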
Compressed code representations are particularly useful when working with large codebases and LLMs with limited context windows. By removing non-essential elements and focusing on the code's structure and semantics, more code can fit within the context window, enabling more comprehensive analysis.
Multi-Prompt Strategies
When dealing with large codebases, a single prompt may not be sufficient to capture all relevant information or to perform complex analyses. Multi-prompt strategies break complex tasks into sequences of simpler prompts, with each prompt building on the results of the previous ones.
Chain-of-thought prompting is one such approach, where the LLM is guided through a step-by-step reasoning process. For code analysis, this might involve first identifying relevant components, then analyzing their relationships, and finally synthesizing insights.
Here's an example of implementing a multiple prompts strategy for analyzing a complex function:
from langchain.llms import OpenAI
import time
class MultiPromptCodeAnalyzer:
def __init__(self):
self.llm = OpenAI(temperature=0.2)
def analyze_complex_function(self, function_code):
# Step 1: Identify the function's inputs, outputs, and dependencies
structure_prompt = f"""
Analyze the following function and identify:
1. Input parameters and their types
2. Return value and its type
3. External dependencies (imported modules, global variables)
4. Internal helper functions
Function:
```
{function_code}
```
Analysis:
"""
structure_analysis = self.llm(structure_prompt)
time.sleep(1) # Avoid rate limiting
# Step 2: Analyze the algorithm and control flow
algorithm_prompt = f"""
Based on this structural analysis:
{structure_analysis}
Analyze the algorithm and control flow of the function:
```
{function_code}
```
Focus on:
1. The main algorithm steps
2. Control flow (conditionals, loops)
3. Error handling approach
4. Performance considerations
Algorithm analysis:
"""
algorithm_analysis = self.llm(algorithm_prompt)
time.sleep(1) # Avoid rate limiting
# Step 3: Identify potential issues and improvements
issues_prompt = f"""
Based on the previous analyses:
Structural analysis:
{structure_analysis}
Algorithm analysis:
{algorithm_analysis}
Identify potential issues and improvements for this function:
```
{function_code}
```
Consider:
1. Edge cases that might not be handled
2. Potential bugs or logical errors
3. Performance optimizations
4. Code readability and maintainability improvements
Issues and improvements:
"""
issues_analysis = self.llm(issues_prompt)
# Step 4: Synthesize a comprehensive analysis
final_prompt = f"""
Create a comprehensive analysis of the following function by synthesizing these analyses:
Structural analysis:
{structure_analysis}
Algorithm analysis:
{algorithm_analysis}
Issues and improvements:
{issues_analysis}
Function:
```
{function_code}
```
Comprehensive analysis:
"""
return self.llm(final_prompt)
# Example usage
analyzer = MultiPromptCodeAnalyzer()
complex_function = """
def process_transaction_batch(transactions, user_accounts, settings=None):
\"\"\"
Process a batch of financial transactions against user accounts.
Args:
transactions: List of transaction dictionaries with keys 'user_id', 'amount', 'type'
user_accounts: Dictionary mapping user_ids to account dictionaries
settings: Optional settings dictionary with processing parameters
Returns:
Tuple of (processed_transactions, failed_transactions, processing_summary)
\"\"\"
if settings is None:
settings = {
'max_daily_withdrawal': 5000,
'transaction_fee': 0.01,
'retry_failed': False
}
processed = []
failed = []
summary = {'total_processed': 0, 'total_amount': 0, 'fees_collected': 0}
# Track daily withdrawals per user
daily_withdrawals = {}
for transaction in transactions:
user_id = transaction.get('user_id')
amount = transaction.get('amount', 0)
tx_type = transaction.get('type', 'unknown')
# Validate transaction
if not user_id or user_id not in user_accounts:
failed.append({'transaction': transaction, 'reason': 'Invalid user_id'})
continue
if amount <= 0:
failed.append({'transaction': transaction, 'reason': 'Invalid amount'})
continue
account = user_accounts[user_id]
# Apply transaction type-specific logic
if tx_type == 'withdrawal':
# Check daily withdrawal limit
user_daily = daily_withdrawals.get(user_id, 0)
if user_daily + amount > settings['max_daily_withdrawal']:
failed.append({'transaction': transaction, 'reason': 'Daily withdrawal limit exceeded'})
continue
# Check sufficient balance
if account['balance'] < amount:
failed.append({'transaction': transaction, 'reason': 'Insufficient funds'})
continue
# Update daily withdrawal tracking
daily_withdrawals[user_id] = user_daily + amount
# Apply withdrawal
account['balance'] -= amount
elif tx_type == 'deposit':
# Apply deposit
account['balance'] += amount
else:
failed.append({'transaction': transaction, 'reason': 'Unknown transaction type'})
continue
# Calculate fee
fee = amount * settings['transaction_fee']
account['balance'] -= fee
# Record successful transaction
processed_tx = transaction.copy()
processed_tx['fee'] = fee
processed_tx['processed_at'] = time.time()
processed.append(processed_tx)
# Update summary
summary['total_processed'] += 1
summary['total_amount'] += amount
summary['fees_collected'] += fee
# Handle retry logic if enabled
if settings.get('retry_failed') and failed:
# Simplified retry logic for example purposes
retry_candidates = [tx for tx in failed if tx['reason'] == 'Insufficient funds']
if retry_candidates:
# Sort by user_id to process all transactions for the same user together
retry_candidates.sort(key=lambda x: x['transaction']['user_id'])
# Attempt to retry with reduced amounts
for failed_tx in retry_candidates:
original_tx = failed_tx['transaction']
user_id = original_tx['user_id']
# Try with 50% of the original amount
reduced_tx = original_tx.copy()
reduced_tx['amount'] = original_tx['amount'] * 0.5
# Check if this would succeed
account = user_accounts[user_id]
if account['balance'] >= reduced_tx['amount']:
account['balance'] -= reduced_tx['amount']
fee = reduced_tx['amount'] * settings['transaction_fee']
account['balance'] -= fee
reduced_tx['fee'] = fee
reduced_tx['processed_at'] = time.time()
reduced_tx['reduced_from_original'] = True
processed.append(reduced_tx)
failed.remove(failed_tx)
summary['total_processed'] += 1
summary['total_amount'] += reduced_tx['amount']
summary['fees_collected'] += fee
return (processed, failed, summary)
"""
analysis = analyzer.analyze_complex_function(complex_function)
print(analysis)
This example demonstrates a multi-prompt approach to analyzing a complex function. It breaks down the analysis into four steps: structural analysis, algorithm analysis, issues identification, and synthesis. Each step builds on the results of previous steps, allowing the LLM to focus on specific aspects of the code while maintaining context from earlier analyses.
Multi-prompt strategies are particularly valuable for complex code analysis tasks that require deep understanding and reasoning. By guiding the LLM through a structured analytical process, developers can obtain more thorough and accurate insights than a single prompt would allow.
Semantic Partitioning
Semantic partitioning involves dividing a codebase into meaningful segments based on their semantic relationships rather than arbitrary size limits. This approach ensures that related code stays together, making it easier for LLMs to understand the context and relationships between components.
Effective semantic partitioning considers several factors:
- Functional cohesion: Grouping code that works together to implement a specific feature or functionality
- Data dependencies: Keeping code that operates on the same data structures together
- Control flow: Preserving the flow of execution within partitions
- Module boundaries: Respecting the existing architectural boundaries in the codebase
Here's an example of implementing semantic partitioning:
import ast
import networkx as nx
from community import community_louvain
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
class SemanticCodePartitioner:
def __init__(self):
self.dependency_graph = nx.DiGraph()
self.similarity_matrix = None
self.functions = {}
self.classes = {}
def parse_file(self, file_path):
with open(file_path, 'r') as file:
content = file.read()
try:
tree = ast.parse(content)
file_name = file_path.split("/")[-1]
# Extract functions and classes
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef):
func_id = f"{file_name}:{node.name}"
self.functions[func_id] = {
'name': node.name,
'file': file_path,
'code': ast.get_source_segment(content, node),
'calls': [],
'imports': []
}
elif isinstance(node, ast.ClassDef):
class_id = f"{file_name}:{node.name}"
self.classes[class_id] = {
'name': node.name,
'file': file_path,
'code': ast.get_source_segment(content, node),
'methods': [],
'inherits': []
}
# Extract methods
for class_node in ast.walk(node):
if isinstance(class_node, ast.FunctionDef):
method_id = f"{class_id}.{class_node.name}"
self.functions[method_id] = {
'name': class_node.name,
'file': file_path,
'class': class_id,
'code': ast.get_source_segment(content, class_node),
'calls': []
}
self.classes[class_id]['methods'].append(method_id)
# Extract function calls and imports
for func_id, func_info in self.functions.items():
func_node = ast.parse(func_info['code'])
for node in ast.walk(func_node):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
func_info['calls'].append(node.func.id)
elif isinstance(node, ast.Import):
for name in node.names:
func_info['imports'].append(name.name)
elif isinstance(node, ast.ImportFrom):
if node.module:
func_info['imports'].append(node.module)
# Extract class inheritance
for class_id, class_info in self.classes.items():
class_node = ast.parse(class_info['code'])
for node in ast.walk(class_node):
if isinstance(node, ast.ClassDef):
for base in node.bases:
if isinstance(base, ast.Name):
class_info['inherits'].append(base.id)
except SyntaxError:
print(f"Syntax error in {file_path}")
def build_dependency_graph(self):
# Add nodes for functions and classes
for func_id in self.functions:
self.dependency_graph.add_node(func_id, type='function')
for class_id in self.classes:
self.dependency_graph.add_node(class_id, type='class')
# Add edges for function calls
for func_id, func_info in self.functions.items():
for called_func in func_info.get('calls', []):
# Find matching function
for target_id in self.functions:
if target_id.endswith(f":{called_func}") or target_id.endswith(f".{called_func}"):
self.dependency_graph.add_edge(func_id, target_id, type='calls')
# Add edges for class inheritance
for class_id, class_info in self.classes.items():
for base_class in class_info.get('inherits', []):
# Find matching class
for target_id in self.classes:
if target_id.endswith(f":{base_class}"):
self.dependency_graph.add_edge(class_id, target_id, type='inherits')
# Add edges for class-method relationships
for class_id, class_info in self.classes.items():
for method_id in class_info.get('methods', []):
self.dependency_graph.add_edge(class_id, method_id, type='has_method')
self.dependency_graph.add_edge(method_id, class_id, type='belongs_to')
def calculate_semantic_similarity(self):
# Extract code content for all nodes
node_ids = list(self.dependency_graph.nodes())
code_content = []
for node_id in node_ids:
if node_id in self.functions:
code_content.append(self.functions[node_id]['code'])
elif node_id in self.classes:
code_content.append(self.classes[node_id]['code'])
else:
code_content.append("")
# Calculate TF-IDF vectors
vectorizer = TfidfVectorizer(analyzer='word', token_pattern=r'\w+', max_features=1000)
tfidf_matrix = vectorizer.fit_transform(code_content)
# Calculate cosine similarity
self.similarity_matrix = cosine_similarity(tfidf_matrix)
# Add semantic similarity edges to the graph
for i in range(len(node_ids)):
for j in range(i+1, len(node_ids)):
if self.similarity_matrix[i, j] > 0.5: # Threshold for significant similarity
self.dependency_graph.add_edge(node_ids[i], node_ids[j],
type='semantic',
weight=self.similarity_matrix[i, j])
def partition_codebase(self, num_partitions=None):
# Ensure the graph is built
if self.dependency_graph.number_of_nodes() == 0:
raise ValueError("Dependency graph is empty. Call build_dependency_graph first.")
# Convert directed graph to undirected for community detection
undirected_graph = self.dependency_graph.to_undirected()
# Add weights to edges if not present
for u, v in undirected_graph.edges():
if 'weight' not in undirected_graph[u][v]:
# Default weights based on edge type
edge_type = self.dependency_graph[u][v].get('type', '')
if edge_type == 'calls' or edge_type == 'inherits':
undirected_graph[u][v]['weight'] = 0.8
elif edge_type == 'has_method' or edge_type == 'belongs_to':
undirected_graph[u][v]['weight'] = 0.9
else:
undirected_graph[u][v]['weight'] = 0.5
# Apply community detection algorithm
partition = community_louvain.best_partition(undirected_graph)
# Organize nodes by partition
partitions = {}
for node, part_id in partition.items():
if part_id not in partitions:
partitions[part_id] = []
partitions[part_id].append(node)
# Merge small partitions if needed
if num_partitions and len(partitions) > num_partitions:
self.merge_partitions(partitions, num_partitions)
return partitions
def merge_partitions(self, partitions, target_count):
# Calculate partition sizes
sizes = {part_id: len(nodes) for part_id, nodes in partitions.items()}
# While we have more partitions than target
while len(partitions) > target_count:
# Find the smallest partition
smallest_id = min(sizes.keys(), key=lambda k: sizes[k])
smallest_size = sizes[smallest_id]
# Find the most connected partition to merge with
best_merge_id = None
best_connection_strength = -1
for part_id in partitions:
if part_id == smallest_id:
continue
# Calculate connection strength between partitions
connection_strength = 0
for node1 in partitions[smallest_id]:
for node2 in partitions[part_id]:
if self.dependency_graph.has_edge(node1, node2) or self.dependency_graph.has_edge(node2, node1):
edge_data = self.dependency_graph.get_edge_data(node1, node2) or self.dependency_graph.get_edge_data(node2, node1)
connection_strength += edge_data.get('weight', 1)
if connection_strength > best_connection_strength:
best_connection_strength = connection_strength
best_merge_id = part_id
# Merge the partitions
if best_merge_id is not None:
partitions[best_merge_id].extend(partitions[smallest_id])
sizes[best_merge_id] += smallest_size
del partitions[smallest_id]
del sizes[smallest_id]
else:
# If no connections, just merge with the next smallest
next_smallest = min(sizes.keys(), key=lambda k: sizes[k] if k != smallest_id else float('inf'))
partitions[next_smallest].extend(partitions[smallest_id])
sizes[next_smallest] += smallest_size
del partitions[smallest_id]
del sizes[smallest_id]
def get_partition_code(self, partition_nodes):
code_blocks = []
for node in partition_nodes:
if node in self.functions:
code_blocks.append(f"# Function: {node}\n{self.functions[node]['code']}")
elif node in self.classes:
code_blocks.append(f"# Class: {node}\n{self.classes[node]['code']}")
return "\n\n".join(code_blocks)
# Example usage
partitioner = SemanticCodePartitioner()
partitioner.parse_file("./my_project/auth/manager.py")
partitioner.parse_file("./my_project/auth/validation.py")
partitioner.parse_file("./my_project/users/profile.py")
partitioner.build_dependency_graph()
partitioner.calculate_semantic_similarity()
partitions = partitioner.partition_codebase(num_partitions=3)
for part_id, nodes in partitions.items():
print(f"Partition {part_id} contains {len(nodes)} nodes:")
for node in nodes[:5]: # Show first 5 nodes
print(f" - {node}")
if len(nodes) > 5:
print(f" - ... and {len(nodes) - 5} more")
# Get the code for this partition
partition_code = partitioner.get_partition_code(nodes)
print(f"Partition size: {len(partition_code)} characters")
print()
This example demonstrates a sophisticated approach to semantic partitioning. It builds a dependency graph of the codebase, considering both structural relationships (function calls, class inheritance) and semantic similarity based on code content. It then applies community detection algorithms to identify cohesive groups of code that should be kept together.
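Once partitions are computed, each one can be handed to the LLM as a self-contained context that fits within the window. A brief sketch of that final step (analyze_partition is an illustrative helper and the prompt wording is arbitrary):
from langchain.llms import OpenAI

def analyze_partition(llm, partitioner, nodes, question):
    # Build a prompt from one semantically cohesive partition of the codebase
    partition_code = partitioner.get_partition_code(nodes)
    prompt = (
        "The following code forms one cohesive part of a larger codebase:\n\n"
        f"{partition_code}\n\n"
        f"Question: {question}\nAnswer:"
    )
    return llm(prompt)

llm = OpenAI(temperature=0)
print(analyze_partition(llm, partitioner, next(iter(partitions.values())), "What does this subsystem do?"))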
Semantic partitioning is particularly valuable for large codebases where related code may be spread across multiple files or modules. By grouping code based on its semantic relationships rather than arbitrary boundaries, LLMs can better understand the context and provide more accurate analyses and suggestions.
Agent-Based Approaches
Agent-based approaches involve creating autonomous or semi-autonomous systems that can navigate, analyze, and manipulate codebases with minimal human intervention. These agents use LLMs as their reasoning engines but add capabilities like memory, planning, and tool use.
A code analysis agent might perform tasks like:
- Exploring a codebase to find relevant components
- Analyzing dependencies and relationships
- Identifying bugs or performance issues
- Suggesting improvements or refactorings
Here's an example of implementing a simple code analysis agent:
import os
import ast
import re
from langchain.llms import OpenAI
from langchain.agents import Tool
from langchain.memory import ConversationBufferMemory
import networkx as nx
class CodeAnalysisAgent:
def __init__(self, codebase_path):
self.codebase_path = codebase_path
self.llm = OpenAI(temperature=0.2)
self.memory = ConversationBufferMemory(memory_key="chat_history")
self.file_cache = {}
self.dependency_graph = nx.DiGraph()
self.current_file = None
# Initialize tools
self.tools = [
Tool(
name="list_files",
func=self.list_files,
description="Lists files in a directory. Input should be a relative path within the codebase."
),
Tool(
name="read_file",
func=self.read_file,
description="Reads the content of a file. Input should be the file path."
),
Tool(
name="find_function",
func=self.find_function,
description="Finds a function definition in the codebase. Input should be the function name."
),
Tool(
name="find_class",
func=self.find_class,
description="Finds a class definition in the codebase. Input should be the class name."
),
Tool(
name="analyze_dependencies",
func=self.analyze_dependencies,
description="Analyzes dependencies of a file. Input should be the file path."
),
Tool(
name="search_code",
func=self.search_code,
description="Searches for a pattern in the codebase. Input should be a regex pattern."
)
]
def list_files(self, directory="."):
"""Lists files in the specified directory."""
full_path = os.path.join(self.codebase_path, directory)
if not os.path.exists(full_path):
return f"Directory not found: {directory}"
files = []
for item in os.listdir(full_path):
item_path = os.path.join(full_path, item)
if os.path.isfile(item_path):
files.append(item)
elif os.path.isdir(item_path):
files.append(f"{item}/")
return "\n".join(files)
def read_file(self, file_path):
"""Reads the content of the specified file."""
full_path = os.path.join(self.codebase_path, file_path)
if not os.path.exists(full_path):
return f"File not found: {file_path}"
if full_path in self.file_cache:
self.current_file = file_path
return self.file_cache[full_path]
try:
with open(full_path, 'r') as file:
content = file.read()
self.file_cache[full_path] = content
self.current_file = file_path
# If it's a Python file, parse it to build the dependency graph
if file_path.endswith('.py'):
self.parse_file_dependencies(file_path, content)
return content
except Exception as e:
return f"Error reading file: {str(e)}"
def find_function(self, function_name):
"""Finds a function definition in the codebase."""
results = []
for root, _, files in os.walk(self.codebase_path):
for file in files:
if file.endswith('.py'):
file_path = os.path.join(root, file)
relative_path = os.path.relpath(file_path, self.codebase_path)
content = self.read_file(relative_path)
# Simple pattern matching for function definition
pattern = rf"def\s+{re.escape(function_name)}\s*\("
if re.search(pattern, content):
# Extract the function definition
try:
tree = ast.parse(content)
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == function_name:
func_code = ast.get_source_segment(content, node)
results.append(f"Found in {relative_path}:\n{func_code}")
except:
results.append(f"Found in {relative_path} but couldn't extract the full definition.")
if results:
return "\n\n".join(results)
else:
return f"Function '{function_name}' not found in the codebase."
def find_class(self, class_name):
"""Finds a class definition in the codebase."""
results = []
for root, _, files in os.walk(self.codebase_path):
for file in files:
if file.endswith('.py'):
file_path = os.path.join(root, file)
relative_path = os.path.relpath(file_path, self.codebase_path)
content = self.read_file(relative_path)
# Simple pattern matching for class definition
pattern = rf"class\s+{re.escape(class_name)}\s*[:\(]"
if re.search(pattern, content):
# Extract the class definition
try:
tree = ast.parse(content)
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
class_code = ast.get_source_segment(content, node)
results.append(f"Found in {relative_path}:\n{class_code}")
except:
results.append(f"Found in {relative_path} but couldn't extract the full definition.")
if results:
return "\n\n".join(results)
else:
return f"Class '{class_name}' not found in the codebase."
def parse_file_dependencies(self, file_path, content):
"""Parses a Python file to extract dependencies."""
try:
tree = ast.parse(content)
file_node = file_path
# Add node for the file
if file_node not in self.dependency_graph:
self.dependency_graph.add_node(file_node, type='file')
# Extract imports
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for name in node.names:
import_name = name.name
self.dependency_graph.add_edge(file_node, import_name, type='imports')
elif isinstance(node, ast.ImportFrom):
if node.module:
import_name = node.module
self.dependency_graph.add_edge(file_node, import_name, type='imports_from')
# Extract function definitions
elif isinstance(node, ast.FunctionDef):
func_name = f"{file_path}:{node.name}"
self.dependency_graph.add_node(func_name, type='function')
self.dependency_graph.add_edge(file_node, func_name, type='defines')
# Extract class definitions
elif isinstance(node, ast.ClassDef):
class_name = f"{file_path}:{node.name}"
self.dependency_graph.add_node(class_name, type='class')
self.dependency_graph.add_edge(file_node, class_name, type='defines')
# Extract base classes
for base in node.bases:
if isinstance(base, ast.Name):
self.dependency_graph.add_edge(class_name, base.id, type='inherits')
except Exception as e:
print(f"Error parsing {file_path}: {str(e)}")
def analyze_dependencies(self, file_path):
"""Analyzes dependencies of a file."""
full_path = os.path.join(self.codebase_path, file_path)
if not os.path.exists(full_path):
return f"File not found: {file_path}"
        # Ensure the file has been read and parsed (the cache is keyed by full path)
        if full_path not in self.file_cache:
            self.read_file(file_path)
# Get imports
imports = []
for _, target in self.dependency_graph.out_edges(file_path):
edge_data = self.dependency_graph.get_edge_data(file_path, target)
if edge_data.get('type') in ('imports', 'imports_from'):
imports.append(target)
# Get defined functions and classes
definitions = []
for _, target in self.dependency_graph.out_edges(file_path):
edge_data = self.dependency_graph.get_edge_data(file_path, target)
if edge_data.get('type') == 'defines':
definitions.append(target.split(':')[1])
# Get files that import this file
imported_by = []
for source, target in self.dependency_graph.in_edges(file_path):
edge_data = self.dependency_graph.get_edge_data(source, target)
if edge_data.get('type') in ('imports', 'imports_from'):
imported_by.append(source)
result = f"Dependencies analysis for {file_path}:\n\n"
result += f"Imports ({len(imports)}):\n" + "\n".join(imports) + "\n\n"
result += f"Definitions ({len(definitions)}):\n" + "\n".join(definitions) + "\n\n"
result += f"Imported by ({len(imported_by)}):\n" + "\n".join(imported_by)
return result
def search_code(self, pattern):
"""Searches for a pattern in the codebase."""
results = []
try:
regex = re.compile(pattern)
for root, _, files in os.walk(self.codebase_path):
for file in files:
if file.endswith(('.py', '.js', '.java', '.c', '.cpp', '.h')):
file_path = os.path.join(root, file)
relative_path = os.path.relpath(file_path, self.codebase_path)
# Use cached content if available
if file_path in self.file_cache:
content = self.file_cache[file_path]
else:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
self.file_cache[file_path] = content
# Search for matches
matches = regex.finditer(content)
for match in matches:
# Get some context around the match
start = max(0, match.start() - 50)
end = min(len(content), match.end() + 50)
context = content[start:end].replace('\n', ' ')
# Add ellipsis if we truncated
if start > 0:
context = "..." + context
if end < len(content):
context = context + "..."
results.append(f"{relative_path}: {context}")
if results:
return f"Found {len(results)} matches for '{pattern}':\n\n" + "\n\n".join(results[:10]) + (
f"\n\n...and {len(results) - 10} more matches." if len(results) > 10 else ""
)
else:
return f"No matches found for '{pattern}'."
except re.error:
return f"Invalid regex pattern: {pattern}"
def run(self, query):
"""Runs the agent on a user query."""
# First, determine which tool to use
tool_selection_prompt = f"""
Based on the user query: "{query}"
Select the most appropriate tool from the following options:
{', '.join([tool.name for tool in self.tools])}
Just respond with the tool name only.
"""
selected_tool = self.llm(tool_selection_prompt).strip()
# Find the selected tool
tool = next((t for t in self.tools if t.name == selected_tool), None)
if not tool:
return f"I'm not sure how to handle that query. Could you rephrase it?"
# Determine the input for the tool
tool_input_prompt = f"""
Based on the user query: "{query}"
And the selected tool: "{selected_tool}"
What should be the input parameter for this tool? Be concise and only provide the parameter value.
"""
tool_input = self.llm(tool_input_prompt).strip()
# Execute the tool
tool_result = tool.func(tool_input)
# Generate a response based on the tool's output
response_prompt = f"""
The user asked: "{query}"
I used the {selected_tool} tool with input: {tool_input}
The tool returned the following result:
{tool_result}
Based on this information, provide a helpful response to the user's query.
"""
response = self.llm(response_prompt)
# Update memory
self.memory.chat_memory.add_user_message(query)
self.memory.chat_memory.add_ai_message(response)
return response
# Example usage
agent = CodeAnalysisAgent("./my_project")
response = agent.run("Find all functions that handle authentication")
print(response)
response = agent.run("What are the dependencies of the user profile module?")
print(response)
This example demonstrates a code analysis agent that can navigate and analyze a codebase. The agent has tools for listing files, reading file contents, finding functions and classes, analyzing dependencies, and searching for patterns. It uses these tools to answer user queries about the codebase.
The agent approach is particularly valuable for complex analysis tasks that require multiple steps and reasoning. Rather than trying to fit an entire analysis into a single prompt, the agent can break down the task, gather information as needed, and synthesize insights based on what it discovers.
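Note that the run method above selects a single tool per query. To let the model chain tools across several steps, the same tools can be handed to a framework-driven reasoning loop instead. A sketch using LangChain's classic agent API (the same era as the imports above; the query is illustrative):
from langchain.agents import initialize_agent, AgentType

agent_executor = initialize_agent(
    tools=agent.tools,
    llm=agent.llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,  # ReAct-style select-act-observe loop
    verbose=True
)
print(agent_executor.run("Trace how a login request flows through the codebase"))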
Code Generation Techniques Using LLMs
LLMs have demonstrated impressive capabilities for code generation, from completing small snippets to creating entire modules. However, generating high-quality code for large, complex systems requires specialized techniques to ensure consistency, correctness, and integration with existing codebases.
Iterative Refinement
Iterative refinement involves generating code in stages, with each stage improving upon the previous one. This approach is particularly valuable for complex code generation tasks where the first attempt may be incomplete or contain errors.
The process typically involves:
- Generating an initial implementation
- Evaluating the code for correctness, style, and edge cases
- Refining the code based on the evaluation
- Repeating until the code meets the desired quality standards
Here's an example of implementing iterative refinement for code generation:
from langchain.llms import OpenAI
import ast
import re
class IterativeCodeGenerator:
def __init__(self):
self.llm = OpenAI(temperature=0.2)
def generate_initial_code(self, specification):
"""Generate an initial implementation based on the specification."""
prompt = f"""
Write a Python function that implements the following specification:
{specification}
Only include the function code, no explanations.
"""
initial_code = self.llm(prompt)
return self.extract_function_code(initial_code)
def extract_function_code(self, text):
"""Extract function code from text that might contain explanations."""
# Try to find a Python function definition
function_match = re.search(r'def\s+\w+\s*\(.*?\).*?:', text, re.DOTALL)
if not function_match:
return text
# Find the function body by parsing indentation
lines = text[function_match.start():].split('\n')
function_lines = [lines[0]] # First line (def statement)
# Determine the indentation of the function body
for i in range(1, len(lines)):
line = lines[i]
if line.strip() and not line.startswith(' ') and not line.startswith('\t'):
# This line is not indented, so it's outside the function
break
function_lines.append(line)
return '\n'.join(function_lines)
def evaluate_code(self, code, specification):
"""Evaluate the code for correctness, style, and edge cases."""
prompt = f"""
Evaluate the following Python function against this specification:
Specification:
{specification}
Code:
{code}
Identify any issues with:
1. Correctness: Does it implement the specification correctly?
2. Edge cases: Does it handle all possible inputs and edge cases?
3. Style: Does it follow good Python coding practices?
4. Performance: Are there any obvious performance issues?
For each issue, provide a specific explanation of the problem.
"""
evaluation = self.llm(prompt)
return evaluation
def refine_code(self, code, specification, evaluation):
"""Refine the code based on the evaluation."""
prompt = f"""
Improve the following Python function based on the evaluation:
Specification:
{specification}
Current code:
{code}
Evaluation:
{evaluation}
Provide an improved version of the function that addresses the issues in the evaluation.
Only include the function code, no explanations.
"""
refined_code = self.llm(prompt)
return self.extract_function_code(refined_code)
def check_syntax(self, code):
"""Check if the code has valid Python syntax."""
try:
ast.parse(code)
return True, None
except SyntaxError as e:
return False, str(e)
def generate_tests(self, code, specification):
"""Generate test cases for the function."""
prompt = f"""
Write comprehensive test cases for the following Python function:
Specification:
{specification}
Function code:
{code}
Include tests for normal cases and edge cases. Use pytest format.
Only include the test code, no explanations.
"""
tests = self.llm(prompt)
return tests
def generate_code_iteratively(self, specification, max_iterations=3):
"""Generate code through iterative refinement."""
# Generate initial code
code = self.generate_initial_code(specification)
iterations = []
iterations.append({
"iteration": 1,
"code": code,
"syntax_valid": self.check_syntax(code)[0]
})
# Iterative refinement
for i in range(max_iterations - 1):
# Check syntax
syntax_valid, syntax_error = self.check_syntax(code)
if not syntax_valid:
evaluation = f"Syntax error: {syntax_error}"
else:
# Evaluate the code
evaluation = self.evaluate_code(code, specification)
# Refine the code
code = self.refine_code(code, specification, evaluation)
iterations.append({
"iteration": i + 2,
"code": code,
"evaluation": evaluation,
"syntax_valid": self.check_syntax(code)[0]
})
# If the code has valid syntax and no major issues, we can stop
if syntax_valid and "no issues" in evaluation.lower():
break
# Generate tests for the final code
tests = self.generate_tests(code, specification)
return {
"final_code": code,
"iterations": iterations,
"tests": tests
}
# Example usage
generator = IterativeCodeGenerator()
specification = """
Create a function called 'parse_log_entry' that parses a log entry string and extracts key information.
The log entry format is as follows:
[TIMESTAMP] [LEVEL] [MODULE] - Message
Where:
- TIMESTAMP is in the format YYYY-MM-DD HH:MM:SS.mmm
- LEVEL is one of: DEBUG, INFO, WARNING, ERROR, CRITICAL
- MODULE is the name of the module that generated the log
- Message is the actual log message
The function should return a dictionary with the following keys:
- timestamp: a datetime object
- level: the log level as a string
- module: the module name as a string
- message: the log message as a string
If the log entry doesn't match the expected format, the function should raise a ValueError.
"""
result = generator.generate_code_iteratively(specification)
print("Final code:")
print(result["final_code"])
print("\nTests:")
print(result["tests"])
print("\nIteration history:")
for iteration in result["iterations"]:
print(f"Iteration {iteration['iteration']} - Syntax valid: {iteration['syntax_valid']}")
if "evaluation" in iteration:
print(f"Evaluation: {iteration['evaluation'][:100]}...")
print("---")
This example demonstrates iterative code generation. It starts with an initial implementation based on the specification, evaluates it for correctness and edge cases, and then refines it based on the evaluation. This process repeats for a specified number of iterations or until the code meets quality standards.
Iterative refinement is particularly valuable for complex code generation tasks where the first attempt is unlikely to be perfect. By evaluating and refining the code in stages, the system can produce higher-quality results than would be possible with a single generation step.
Test-Driven Generation
Test-driven generation applies the principles of test-driven development to LLM-based code generation. It involves:
- Generating test cases based on the specification
- Generating code that passes the tests
- Refining both the tests and the code iteratively
This approach helps ensure that the generated code meets the functional requirements and handles edge cases correctly.
Here's an example of implementing test-driven code generation:
from langchain.llms import OpenAI
import ast
import subprocess
import tempfile
import os
class TestDrivenCodeGenerator:
def __init__(self):
self.llm = OpenAI(temperature=0.2)
def generate_tests(self, specification):
"""Generate test cases based on the specification."""
prompt = f"""
Write comprehensive pytest test cases for a function with the following specification:
{specification}
Include tests for normal cases and edge cases. The tests should be detailed enough to fully validate the function's behavior.
Only include the test code, no explanations.
"""
tests = self.llm(prompt)
return tests
def generate_implementation(self, specification, tests):
"""Generate an implementation that passes the tests."""
prompt = f"""
Write a Python function that implements the following specification and passes the provided tests:
Specification:
{specification}
Tests:
{tests}
Only include the function code, no explanations.
"""
implementation = self.llm(prompt)
return implementation
def run_tests(self, implementation, tests):
"""Run the tests against the implementation."""
# Create temporary files
with tempfile.NamedTemporaryFile(suffix='.py', delete=False) as impl_file:
impl_file.write(implementation.encode('utf-8'))
impl_path = impl_file.name
with tempfile.NamedTemporaryFile(suffix='.py', delete=False) as test_file:
# Modify the test to import from the implementation file
impl_module = os.path.basename(impl_path)[:-3] # Remove .py
test_content = f"from {impl_module} import *\n\n{tests}"
test_file.write(test_content.encode('utf-8'))
test_path = test_file.name
try:
# Run pytest
result = subprocess.run(
['pytest', test_path, '-v'],
capture_output=True,
text=True
)
return {
'success': result.returncode == 0,
'output': result.stdout + result.stderr
}
except Exception as e:
return {
'success': False,
'output': str(e)
}
finally:
# Clean up temporary files
os.unlink(impl_path)
os.unlink(test_path)
def refine_implementation(self, specification, implementation, tests, test_results):
"""Refine the implementation based on test results."""
prompt = f"""
Improve the following Python function to pass the failing tests:
Specification:
{specification}
Current implementation:
{implementation}
Tests:
{tests}
Test results:
{test_results['output']}
Provide an improved version of the function that passes all tests.
Only include the function code, no explanations.
"""
refined_implementation = self.llm(prompt)
return refined_implementation
def refine_tests(self, specification, implementation, tests):
"""Refine the tests to cover more edge cases."""
prompt = f"""
Improve the following test cases to more thoroughly test the function:
Specification:
{specification}
Implementation:
{implementation}
Current tests:
{tests}
Add additional test cases for edge cases that might not be covered.
Only include the test code, no explanations.
"""
refined_tests = self.llm(prompt)
return refined_tests
def generate_code_with_tdd(self, specification, max_iterations=3):
"""Generate code using test-driven development approach."""
# Generate initial tests
tests = self.generate_tests(specification)
# Generate initial implementation
implementation = self.generate_implementation(specification, tests)
iterations = []
iterations.append({
"iteration": 1,
"implementation": implementation,
"tests": tests
})
# Run tests
test_results = self.run_tests(implementation, tests)
# Iterative refinement
for i in range(max_iterations - 1):
if test_results['success']:
# If tests pass, refine the tests to cover more cases
tests = self.refine_tests(specification, implementation, tests)
test_results = self.run_tests(implementation, tests)
# If the refined tests fail, refine the implementation
if not test_results['success']:
implementation = self.refine_implementation(
specification, implementation, tests, test_results
)
else:
# If tests fail, refine the implementation
implementation = self.refine_implementation(
specification, implementation, tests, test_results
)
iterations.append({
"iteration": i + 2,
"implementation": implementation,
"tests": tests,
"test_results": test_results
})
# Run tests again
test_results = self.run_tests(implementation, tests)
# If tests pass after refinement, we can stop
if test_results['success']:
break
return {
"final_implementation": implementation,
"final_tests": tests,
"final_test_results": test_results,
"iterations": iterations
}
# Example usage
generator = TestDrivenCodeGenerator()
specification = """
Create a function called 'validate_email' that checks if a given string is a valid email address.
The function should:
1. Return True if the email is valid, False otherwise
2. Consider an email valid if it:
- Contains exactly one @ symbol
- Has at least one character before the @
- Has at least one character between @ and the last dot
- Has at least two characters after the last dot
- Contains only alphanumeric characters, dots, underscores, hyphens, and the @ symbol
- Does not have consecutive dots
"""
result = generator.generate_code_with_tdd(specification)
print("Final implementation:")
print(result["final_implementation"])
print("\nFinal tests:")
print(result["final_tests"])
print("\nTest results:")
print(result["final_test_results"]["output"])
This example demonstrates test-driven code generation. It first generates test cases based on the specification, then generates an implementation that should pass those tests. It runs the tests and, if they fail, refines the implementation. If the tests pass, it refines the tests to cover more edge cases. This process continues iteratively until the code passes all tests or reaches the maximum number of iterations.
Test-driven generation is particularly valuable for ensuring that the generated code meets functional requirements and handles edge cases correctly. By focusing on test cases first, it helps ensure that the code is correct and robust.
Architecture-Aware Generation
Architecture-aware generation involves creating code that fits into an existing system architecture. This requires understanding the system's components, their interactions, and the architectural patterns and principles in use.
For large codebases, architecture-aware generation might involve:
- Analyzing the existing architecture to understand patterns and conventions
- Generating code that follows those patterns
- Ensuring proper integration with existing components
Here's an example of implementing architecture-aware code generation:
import os
import ast
import re

from langchain.llms import OpenAI

class ArchitectureAwareGenerator:
    def __init__(self, codebase_path):
        self.codebase_path = codebase_path
        self.llm = OpenAI(temperature=0.2)
        self.architecture_patterns = {}
        self.coding_conventions = {}
        self.dependencies = {}

    def analyze_architecture(self):
        """Analyze the existing codebase to extract architectural patterns."""
        self.analyze_module_structure()
        self.analyze_class_hierarchies()
        self.analyze_dependencies()
        self.analyze_coding_conventions()
    def analyze_module_structure(self):
        """Analyze the module structure of the codebase."""
        modules = {}
        for root, dirs, files in os.walk(self.codebase_path):
            for file in files:
                if file.endswith('.py') and file != '__init__.py':
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, self.codebase_path)
                    module_path = relative_path.replace('/', '.').replace('\\', '.')[:-3]
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()
                    # Extract the classes and functions each module is responsible for
                    try:
                        tree = ast.parse(content)
                    except (SyntaxError, ValueError):
                        continue  # Skip files that cannot be parsed
                    classes = [node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)]
                    functions = [node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
                    modules[module_path] = {
                        'path': relative_path,
                        'classes': classes,
                        'functions': functions
                    }
        self.architecture_patterns['modules'] = modules
    def analyze_class_hierarchies(self):
        """Analyze class hierarchies in the codebase."""
        class_hierarchies = {}
        for root, _, files in os.walk(self.codebase_path):
            for file in files:
                if file.endswith('.py'):
                    file_path = os.path.join(root, file)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()
                    try:
                        tree = ast.parse(content)
                    except (SyntaxError, ValueError):
                        continue  # Skip files that cannot be parsed
                    for node in ast.walk(tree):
                        if isinstance(node, ast.ClassDef):
                            base_classes = []
                            for base in node.bases:
                                if isinstance(base, ast.Name):
                                    base_classes.append(base.id)
                                elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
                                    # Handle bases of the form module.BaseClass
                                    base_classes.append(f"{base.value.id}.{base.attr}")
                            class_hierarchies[node.name] = {
                                'file': file_path,
                                'base_classes': base_classes
                            }
        self.architecture_patterns['class_hierarchies'] = class_hierarchies
    def analyze_dependencies(self):
        """Analyze dependency patterns in the codebase."""
        dependencies = {}
        for root, _, files in os.walk(self.codebase_path):
            for file in files:
                if file.endswith('.py'):
                    file_path = os.path.join(root, file)
                    relative_path = os.path.relpath(file_path, self.codebase_path)
                    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        content = f.read()
                    # Extract both plain and from-style imports
                    imports = []
                    try:
                        tree = ast.parse(content)
                        for node in ast.walk(tree):
                            if isinstance(node, ast.Import):
                                for name in node.names:
                                    imports.append(name.name)
                            elif isinstance(node, ast.ImportFrom):
                                if node.module:
                                    imports.append(node.module)
                    except (SyntaxError, ValueError):
                        pass  # Record an empty import list for unparsable files
                    dependencies[relative_path] = imports
        self.dependencies = dependencies
    def analyze_coding_conventions(self):
        """Analyze coding conventions used in the codebase."""
        # Sample up to 10 Python files
        python_files = []
        for root, _, files in os.walk(self.codebase_path):
            for file in files:
                if file.endswith('.py'):
                    python_files.append(os.path.join(root, file))
        sample_files = python_files[:10]

        naming_conventions = {
            'class_style': {},
            'function_style': {},
            'variable_style': {}
        }
        indentation = {'spaces': 0, 'tabs': 0}
        line_length = []

        for file_path in sample_files:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            lines = content.split('\n')
            # Count indentation characters and record line lengths
            for line in lines:
                if line.startswith(' '):
                    indentation['spaces'] += 1
                elif line.startswith('\t'):
                    indentation['tabs'] += 1
                line_length.append(len(line))
            # Tally naming styles for classes, functions, and variables
            try:
                tree = ast.parse(content)
            except (SyntaxError, ValueError):
                continue  # Skip files that cannot be parsed
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    style = self.determine_naming_style(node.name)
                    naming_conventions['class_style'][style] = naming_conventions['class_style'].get(style, 0) + 1
                elif isinstance(node, ast.FunctionDef):
                    style = self.determine_naming_style(node.name)
                    naming_conventions['function_style'][style] = naming_conventions['function_style'].get(style, 0) + 1
                elif isinstance(node, ast.Assign):
                    for target in node.targets:
                        if isinstance(target, ast.Name):
                            style = self.determine_naming_style(target.id)
                            naming_conventions['variable_style'][style] = naming_conventions['variable_style'].get(style, 0) + 1

        # Determine the dominant conventions, with sensible defaults
        self.coding_conventions = {
            'indentation': 'spaces' if indentation['spaces'] >= indentation['tabs'] else 'tabs',
            'avg_line_length': sum(line_length) // len(line_length) if line_length else 80,
            'class_naming': max(naming_conventions['class_style'].items(), key=lambda x: x[1])[0] if naming_conventions['class_style'] else 'pascal_case',
            'function_naming': max(naming_conventions['function_style'].items(), key=lambda x: x[1])[0] if naming_conventions['function_style'] else 'snake_case',
            'variable_naming': max(naming_conventions['variable_style'].items(), key=lambda x: x[1])[0] if naming_conventions['variable_style'] else 'snake_case'
        }
    def determine_naming_style(self, name):
        """Determine the naming style of a given identifier."""
        # Check upper_snake_case before pascal_case so constants are classified
        # correctly, and snake_case before camel_case so plain lowercase names
        # (e.g. 'run') count as snake_case rather than camelCase.
        if re.match(r'^[A-Z][A-Z0-9_]*$', name):
            return 'upper_snake_case'
        elif re.match(r'^[A-Z][a-zA-Z0-9]*$', name):
            return 'pascal_case'
        elif re.match(r'^[a-z][a-z0-9_]*$', name):
            return 'snake_case'
        elif re.match(r'^[a-z][a-zA-Z0-9]*$', name):
            return 'camel_case'
        else:
            return 'other'
    def find_similar_components(self, component_type, description):
        """Find existing components similar to the one we want to generate."""
        prompt = f"""
        Based on this description of a new component:
        "{description}"

        Find the most similar existing {component_type}s in this architecture:
        {self.architecture_patterns.get('modules', {})}

        Return the names of the 2-3 most similar components and briefly explain why they are similar.
        """
        return self.llm(prompt)

    def generate_component_skeleton(self, component_type, name, description):
        """Generate a skeleton for a new component based on architectural patterns."""
        similar_components = self.find_similar_components(component_type, description)
        prompt = f"""
        Generate a {component_type} skeleton for a new component called "{name}" that:
        {description}

        Similar existing components:
        {similar_components}

        Follow these coding conventions:
        - Indentation: {self.coding_conventions.get('indentation', 'spaces')}
        - Class naming: {self.coding_conventions.get('class_naming', 'pascal_case')}
        - Function naming: {self.coding_conventions.get('function_naming', 'snake_case')}
        - Variable naming: {self.coding_conventions.get('variable_naming', 'snake_case')}

        Only include the code, no explanations.
        """
        return self.llm(prompt)

    def identify_dependencies(self, component_code):
        """Identify required dependencies for the component."""
        prompt = f"""
        Analyze this component code:
        {component_code}

        Based on the existing architecture and dependencies in the codebase:
        {self.dependencies}

        List the imports that would be required for this component.
        Format as Python import statements.
        """
        return self.llm(prompt)

    def generate_architecture_aware_component(self, component_type, name, description):
        """Generate a new component that fits into the existing architecture."""
        # Make sure the architecture has been analyzed
        if not self.architecture_patterns:
            self.analyze_architecture()
        # Generate the component skeleton, then the imports it needs
        skeleton = self.generate_component_skeleton(component_type, name, description)
        dependencies = self.identify_dependencies(skeleton)
        return f"{dependencies}\n\n{skeleton}"
# Example usage
generator = ArchitectureAwareGenerator("./my_project")
generator.analyze_architecture()
new_component = generator.generate_architecture_aware_component(
    component_type="class",
    name="PaymentProcessor",
    description="Process payments using various payment methods (credit card, PayPal, etc.) and integrate with the existing user account system. Should handle payment validation, processing, and error handling."
)

print("Generated component:")
print(new_component)
This example demonstrates architecture-aware code generation. It first analyzes the existing codebase to understand its architecture, including module structure, class hierarchies, dependencies, and coding conventions. It then uses this information to generate a new component that fits into the existing architecture, following the same patterns and conventions.
Architecture-aware generation is particularly valuable for large, complex codebases where consistency and integration are important. By understanding the existing architecture, the system can generate code that follows established patterns and integrates smoothly with existing components.
Challenges and Limitations
Context Window Limitations
One of the most significant challenges when using LLMs for large codebases is the limited context window. Most LLMs can only process a fixed amount of text at once, typically ranging from a few thousand to a few hundred thousand tokens. This limitation makes it difficult to analyze or generate code for large systems that may contain millions of lines of code.
Several techniques can help address context window limitations:
- Chunking and retrieval: Breaking the codebase into smaller chunks and retrieving only the relevant ones when needed (as in RAG systems)
- Hierarchical processing: Analyzing code at different levels of abstraction, from high-level architecture to detailed implementation (see the sketch after this list)
- Compression: Using techniques like semantic compression to fit more code into the context window
- Iterative processing: Processing the codebase in multiple passes, carrying forward important information from each pass
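To make the hierarchical-processing idea concrete, here is a minimal sketch that summarizes each file on its own and then summarizes the summaries, so that no single LLM call needs the whole codebase in context. It reuses the LangChain OpenAI wrapper from the earlier examples; the prompt wording and the 6,000-character truncation limit are illustrative assumptions, not fixed values.
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

def summarize_file(path):
    """First pass: summarize one file in isolation so it fits the context window."""
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()
    # Truncate defensively; a single very large file could still overflow the window
    prompt = f"Summarize the responsibilities of this Python module in 3 sentences:\n{content[:6000]}"
    return llm(prompt)

def summarize_codebase(paths):
    """Second pass: condense the per-file summaries into one architectural overview."""
    summaries = "\n".join(f"{p}: {summarize_file(p)}" for p in paths)
    prompt = f"Given these per-module summaries, describe the overall architecture:\n{summaries}"
    return llm(prompt)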
Despite these techniques, context window limitations remain a significant challenge, particularly for tasks that require understanding complex relationships across a large codebase.
Accuracy and Hallucination Issues
LLMs can sometimes generate plausible-sounding but incorrect code or analysis, a phenomenon known as "hallucination." This is particularly problematic for code generation, where correctness is critical.
Several approaches can help mitigate hallucination issues:
- Grounding: Providing the LLM with accurate, relevant information from the codebase to ground its responses
- Verification: Using automated tests or static analysis to verify generated code (sketched after this list)
- Iterative refinement: Generating code in stages with verification at each stage
- Explicit uncertainty: Encouraging the LLM to express uncertainty when it doesn't have enough information
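As a minimal illustration of the verification step, the sketch below gates generated code behind a cheap static check and a test run before it is accepted. The `run_tests` callable is an assumption here: it stands in for whatever test runner your pipeline uses, and this sketch assumes it returns a dict with a 'success' flag, as in the test-driven example earlier.
import ast

def verify_generated_code(code, tests, run_tests):
    """Reject generated code that does not parse or does not pass its tests."""
    # Static check: reject code that is not even syntactically valid Python
    try:
        ast.parse(code)
    except SyntaxError as e:
        return {"accepted": False, "reason": f"syntax error: {e}"}
    # Dynamic check: run the accompanying tests via the supplied runner
    results = run_tests(code, tests)
    if not results["success"]:
        return {"accepted": False, "reason": results.get("output", "tests failed")}
    return {"accepted": True, "reason": "parsed and passed tests"}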
While these techniques can help, hallucination remains a challenge, especially for complex code generation tasks where the LLM may need to make assumptions about the system's behavior.
Performance Considerations
Working with large codebases and LLMs can be computationally expensive and time-consuming. Processing a large codebase may require significant computational resources, and generating complex code may take multiple iterations.
Performance optimizations include:
- Caching: Storing intermediate results to avoid redundant processing (see the sketch after this list)
- Parallelization: Processing independent parts of the codebase in parallel
- Incremental analysis: Only analyzing parts of the codebase that have changed
- Efficient retrieval: Using optimized vector databases and retrieval algorithms to quickly find relevant code
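A simple way to combine caching with incremental analysis is to key each file's analysis result by a hash of its contents, so that only changed files are reprocessed on the next run. This is a sketch under the assumption that `analyze_file` is whatever expensive per-file step (embedding, summarization, AST analysis) your pipeline performs.
import hashlib

class IncrementalAnalyzer:
    def __init__(self, analyze_file):
        self.analyze_file = analyze_file  # the expensive per-file analysis step
        self.cache = {}  # content hash -> cached analysis result

    def analyze(self, path):
        """Return a cached result when the file's contents are unchanged."""
        with open(path, 'rb') as f:
            digest = hashlib.sha256(f.read()).hexdigest()
        if digest not in self.cache:
            self.cache[digest] = self.analyze_file(path)
        return self.cache[digest]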
Balancing performance with accuracy and completeness is an ongoing challenge, particularly for real-time applications where developers expect quick responses.
Security Concerns
Using LLMs for code analysis and generation raises several security concerns:
- Sensitive information: LLMs might inadvertently leak sensitive information from the codebase
- Vulnerable code: LLMs might generate code with security vulnerabilities
- Dependency issues: Generated code might introduce dependencies with known vulnerabilities
- Data privacy: Sending code to external LLM services might violate data privacy policies
Addressing these concerns requires careful consideration of the deployment model (local vs. cloud-based), access controls, and security validation of generated code.
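One lightweight mitigation for the sensitive-information and data-privacy concerns is to redact likely secrets before any code is sent to an external service. The patterns below are illustrative and deliberately incomplete; a production deployment should rely on a dedicated secret scanner rather than a handful of regular expressions.
import re

# Illustrative patterns only; real systems should use a dedicated secret scanner
SECRET_PATTERNS = [
    re.compile(r'(?i)(api[_-]?key|secret|password|token)\s*[=:]\s*["\'][^"\']+["\']'),
    re.compile(r'AKIA[0-9A-Z]{16}'),  # matches the AWS access key ID format
]

def redact_secrets(code):
    """Replace likely secrets with a placeholder before sending code to an LLM."""
    for pattern in SECRET_PATTERNS:
        code = pattern.sub('[REDACTED]', code)
    return code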
Future Directions and Emerging Techniques
Multi-Agent Systems
Multi-agent systems involve multiple specialized agents working together to analyze or generate code. Each agent might focus on a specific aspect of the codebase, such as security, performance, or architecture.
These systems can potentially overcome the limitations of single-agent approaches by distributing the workload and leveraging specialized expertise. For example, one agent might focus on understanding the high-level architecture, while another focuses on implementation details.
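As a rough sketch of how that division of labor might look, the snippet below chains an "architecture" role and an "implementation" role as two specialized prompts, reusing the same LangChain wrapper as the earlier examples. Real multi-agent frameworks add planning, tool use, and inter-agent communication; this shows only the core idea.
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)

def architecture_agent(task, codebase_summary):
    """Agent 1: decide where and how the change fits into the architecture."""
    prompt = f"Given this architecture summary:\n{codebase_summary}\n\nPlan where and how to implement: {task}"
    return llm(prompt)

def implementation_agent(task, plan):
    """Agent 2: write code that follows the architectural plan."""
    prompt = f"Implement this task:\n{task}\n\nFollow this plan:\n{plan}\n\nOnly output code."
    return llm(prompt)

# Example: the plan from the first agent becomes context for the second
plan = architecture_agent("Add rate limiting to the API layer", codebase_summary="...")
code = implementation_agent("Add rate limiting to the API layer", plan)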
Hybrid Human-AI Workflows
Rather than fully automating code analysis or generation, hybrid workflows combine the strengths of humans and AI. Developers might use LLMs to generate initial code or identify potential issues, then refine and validate the results themselves.
These workflows leverage the creativity and contextual understanding of human developers while benefiting from the speed and pattern recognition capabilities of LLMs.
Specialized Code Models
While general-purpose LLMs have shown impressive capabilities for code-related tasks, specialized models trained specifically for code analysis or generation might offer better performance and accuracy.
These models might be trained on specific programming languages, frameworks, or domains, allowing them to develop deeper understanding of relevant patterns and conventions.
Continuous Learning Systems
Continuous learning systems improve over time by learning from their interactions with codebases and developers. They might track which suggestions were accepted or rejected, which generated code passed tests, and which analyses were found to be useful.
By continuously refining their understanding and capabilities, these systems can become increasingly valuable tools for working with large codebases.
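A first step toward such a system is simply recording the feedback signals so they can later steer prompts, retrieval, or fine-tuning. The sketch below logs each suggestion's outcome to a JSON-lines file; the record fields are assumptions about what is worth capturing, not an established schema.
import json
import time

def record_feedback(log_path, suggestion, accepted, tests_passed):
    """Append one feedback record; accumulated records can inform future generations."""
    record = {
        "timestamp": time.time(),
        "suggestion": suggestion,
        "accepted": accepted,          # did the developer keep the suggestion?
        "tests_passed": tests_passed,  # did the generated code pass its tests?
    }
    with open(log_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(record) + "\n")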
Conclusion
Large Language Models have opened new possibilities for analyzing and generating code at scales that were previously impractical. Techniques like RAG, GraphRAG, semantic partitioning, and agent-based approaches help overcome the inherent limitations of LLMs when working with large codebases. Similarly, approaches like iterative refinement, test-driven generation, and architecture-aware generation enable the creation of high-quality code that integrates well with existing systems.
Despite these advances, significant challenges remain, including context window limitations, accuracy concerns, performance issues, and security considerations. Addressing these challenges will require continued innovation in techniques for processing and generating code with LLMs.
Looking ahead, emerging approaches like multi-agent systems, hybrid human-AI workflows, specialized code models, and continuous learning systems promise to further enhance the capabilities of LLMs for code-related tasks.
As these technologies mature, they have the potential to transform how developers work with large codebases, making it easier to understand, maintain, and extend complex software systems. However, they are likely to augment rather than replace human developers, offering tools that enhance productivity and quality while still relying on human judgment and expertise for critical decisions.