Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 131 additions & 68 deletions opto/features/priority_search/module_regressor.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,115 @@
import numpy as np
import copy
from typing import Union
from opto.trainer.loader import DataLoader
from opto.trainer.utils import batch_run, async_run
from opto.optimizers.utils import print_color
# from opto.trainer.evaluators import evaluate
from typing import Union, List, Tuple, Dict, Any, Optional
from collections import deque
from opto.utils.llm import LLM # For the selector LLM
# from opto.trace.nodes import ParameterNode
import json
# import warnings
# from black import format_str, FileMode
import random
# import mathX
from opto.utils.auto_retry import retry_with_exponential_backoff
import litellm
import time
from opto.features.priority_search.priority_search import ModuleCandidate


def embed_text(model, text):
    """Request an embedding for *text* from the given *model*.

    Deliberately kept as a module-level free function: callers can swap in a
    custom embedding provider (a local model, a different API) simply by
    replacing this function, with no subclassing required. Whatever it returns
    must be litellm-compatible, i.e. expose ``response.data[0].embedding``.
    """
    response = litellm.embedding(model=model, input=text)
    return response


class RegressorTemplate:
    """Base class template for regression-based predictors for ModuleCandidate objects.

    Provides common functionality for embedding generation and candidate processing.
    Subclasses should implement update() and predict_scores() methods.

    Regressors can be built on this template by implementing the update() and
    predict_scores() methods. This class itself is enough for getting embeddings
    for candidates.
    """

    def __init__(self, embedding_model="gemini/gemini-embedding-001", num_threads=None, regularization_strength=1, linear_dim=None, rich_text=True, verbose: bool = False, max_candidates_to_predict=500, original_embedding_dim=768):
        '''
        Args:
            embedding_model: The embedding model to use.
            num_threads: The number of threads to use for the embedding generation.
            regularization_strength: The regularization strength for the logistic regression.
            linear_dim: The dimension of the linear space.
            rich_text: Whether to use rich text for the parameter text.
            verbose: Whether to print the verbose output.
            max_candidates_to_predict: The maximum number of candidates to predict.
            original_embedding_dim: The original dimension of the embedding.
        '''
        # BUG FIX: the original __init__ body consisted solely of the docstring,
        # so none of the constructor arguments were ever stored, yet
        # _get_embedding() reads self.embedding_model and self.random_projector.
        # Store every argument so the instance is usable.
        self.embedding_model = embedding_model
        self.num_threads = num_threads
        self.regularization_strength = regularization_strength
        self.linear_dim = linear_dim
        self.rich_text = rich_text
        self.verbose = verbose
        self.max_candidates_to_predict = max_candidates_to_predict
        self.original_embedding_dim = original_embedding_dim
        # Optional dimensionality-reduction transform applied to raw embeddings
        # in _get_embedding(); None disables the projection step.
        # NOTE(review): no construction site for the projector was visible in
        # this file -- confirm how it should be built when linear_dim is set.
        self.random_projector = None

    def _get_parameter_text(self, candidate):
        """Return a deterministic text rendering of a candidate's parameters.

        The candidate must expose an ``update_dict`` mapping parameter nodes
        (objects with a ``py_name`` attribute) to values.
        """
        if not hasattr(candidate, 'update_dict'):
            print(candidate)
        assert hasattr(candidate, 'update_dict'), "ModuleCandidate must have an update_dict"
        # Convert parameter nodes to readable names for deterministic embedding
        params_with_names = {k.py_name: v for k, v in candidate.update_dict.items()}
        return str(params_with_names)

    def _get_embedding(self, candidate, max_retries=10, base_delay=1.0):
        """Embed one candidate's parameter text.

        Retries the embedding API with exponential backoff; if a projector is
        configured, the raw embedding is projected down before being returned.

        Returns:
            A list of floats on success, or None when the API call failed even
            after all retries (the failure is logged, not raised).
        """
        parameter_text = self._get_parameter_text(candidate)

        try:
            response = retry_with_exponential_backoff(
                lambda: embed_text(self.embedding_model, parameter_text),
                max_retries=max_retries,
                base_delay=base_delay,
                operation_name="Embedding API call"
            )
            embedding = response.data[0].embedding
            if self.random_projector is not None:
                embedding_array = np.array(embedding).reshape(1, -1)
                projected = self.random_projector.transform(embedding_array)
                embedding = projected.flatten().tolist()
            return embedding
        except Exception as e:
            print_color(f"ERROR: Embedding API call failed after retries: {e}", "red")
            return None

    def add_embeddings_to_candidates(self, candidates: List[ModuleCandidate]):
        """Add embeddings to a list of candidates. This function could be used outside."""
        self._update_memory_embeddings_for_batch(candidates)

    def _update_memory_embeddings_for_batch(self, batch, max_workers=50, max_retries=10, base_delay=1.0):
        """Compute and attach embeddings for every candidate in *batch* that
        does not already carry an ``embedding`` attribute.

        Embedding calls run in parallel via async_run. Note that a failed call
        stores None as the embedding, and the hasattr() guard means such a
        candidate will not be retried on a later pass.
        """
        # Only candidates without an embedding attribute need work.
        candidates_needing_embeddings = [
            candidate for candidate in batch if not hasattr(candidate, "embedding")
        ]
        if not candidates_needing_embeddings:
            return

        # BUG FIX: max_retries/base_delay were accepted but never forwarded to
        # _get_embedding(); pass them through. The c=candidate default binds the
        # loop variable at definition time (late-binding closure pitfall).
        embedding_functions = [
            lambda c=candidate: self._get_embedding(c, max_retries=max_retries, base_delay=base_delay)
            for candidate in candidates_needing_embeddings
        ]

        # Run embedding generation in parallel
        new_embeddings = async_run(
            embedding_functions,
            max_workers=max_workers,
            description=f"Generating embeddings for {len(candidates_needing_embeddings)} candidates"
        )

        # Assign embeddings back to candidates
        for candidate, embedding in zip(candidates_needing_embeddings, new_embeddings):
            candidate.embedding = embedding

    def update(self, memory: List[Tuple[float, ModuleCandidate]]):
        """Update the regression model parameters. Should be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement the update method")

    def predict_scores(self, memory: List[Tuple[float, ModuleCandidate]]):
        """Predict scores for candidates. Should be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement the predict_scores method")

class ModuleCandidateRegressor:
"""
Expand All @@ -25,9 +118,8 @@ class ModuleCandidateRegressor:
predict_scores has no parameters, it could return predicted scores for all candidates in the memory.
predict_scores_for_batch has one parameter, a batch of candidates, it could return predicted scores for the batch of candidates."""

def __init__(self, memory=None, embedding_model="gemini/text-embedding-004", num_threads=None, learning_rate=0.2, regularization_strength=1e-4, max_iterations=20000, tolerance=5e-3):
# In the regressor, no need for calling LLM to make the prediction. So we could predict the entire memory at once.
self.max_candidates_to_predict = 500
def __init__(self, memory=None, embedding_model="gemini/text-embedding-004", num_threads=None, learning_rate=0.2, regularization_strength=1e-4, max_iterations=20000, tolerance=5e-3, max_candidates_to_predict=500, original_embedding_dim=768,patience=20,lr_decay_factor=0.8):
self.max_candidates_to_predict = max_candidates_to_predict
self.memory = memory
self.embedding_model = embedding_model
self.num_threads = num_threads
Expand All @@ -36,10 +128,9 @@ def __init__(self, memory=None, embedding_model="gemini/text-embedding-004", num
self.regularization_strength = regularization_strength # L2 regularization strength (lambda)
self.max_iterations = max_iterations
self.tolerance = tolerance
self.patience = 20 # Early stopping patience
self.lr_decay_factor = 0.8 # Learning rate decay factor
# default linear dimension is 768
self.linear_dim = 768
self.patience = patience # Early stopping patience
self.lr_decay_factor = lr_decay_factor # Learning rate decay factor
self.linear_dim = original_embedding_dim
# Initialize weights with larger values for more aggressive learning
self.weights = np.random.normal(0, 0.1, self.linear_dim)
self.bias = 0.0
Expand All @@ -50,42 +141,33 @@ def _sigmoid(self, z):

def _get_parameter_text(self, candidate):
"""Get the parameter text for a ModuleCandidate."""
if not candidate.update_dict:
# If update_dict is empty, use a default text or base module info
return "base_module_parameters"

# Get the first value from update_dict (similar to additional_instructions)
# TODO: support for multiple parameters
parameter_text = list(candidate.update_dict.values())[0]
return str(parameter_text)
if not hasattr(candidate, 'update_dict'):
print(candidate)
assert hasattr(candidate, 'update_dict'), "ModuleCandidate must have an update_dict"
# Convert parameter nodes to readable names for deterministic embedding
params_with_names = {k.py_name: v for k, v in candidate.update_dict.items()}
return str(params_with_names)

def _get_embedding(self, candidate):
def _get_embedding(self, candidate,max_retries=10,base_delay=1.0):
"""Get the embedding for a ModuleCandidate."""
parameter_text = self._get_parameter_text(candidate)

def single_embedding_call():
return litellm.embedding(
model=self.embedding_model,
input=parameter_text
)

try:
response = retry_with_exponential_backoff(
single_embedding_call,
max_retries=10,
base_delay=1.0,
lambda: embed_text(self.embedding_model, parameter_text),
max_retries=max_retries,
base_delay=base_delay,
operation_name="Embedding API call"
)
embedding = response.data[0].embedding
return embedding
except Exception as e:
print_color(f"ERROR: Embedding API call failed after retries: {e}", "red")
# Return a random embedding as fallback to prevent complete failure
print_color("Using random embedding as fallback", "yellow")
fallback_embedding = np.random.normal(0, 0.01, self.linear_dim)
return fallback_embedding / np.linalg.norm(fallback_embedding)

def _update_memory_embeddings_for_batch(self, batch):
def _update_memory_embeddings_for_batch(self, batch,max_workers=1000,max_retries=10,base_delay=1.0):
"""Update the embeddings for a batch of candidates."""
# Separate candidates that need embeddings from those that already have them
candidates_needing_embeddings = []
Expand All @@ -105,7 +187,7 @@ def get_embedding_for_candidate(candidate):
# Run embedding generation in parallel
new_embeddings = async_run(
embedding_functions,
max_workers=1000,
max_workers=max_workers,
description=f"Generating embeddings for {len(candidates_needing_embeddings)} candidates"
)

Expand All @@ -116,7 +198,8 @@ def get_embedding_for_candidate(candidate):
def update(self):
"""Update the regression model parameters using the current memory with logistic regression."""
start_time = time.time()
print_color("Updating regression model using the current memory with logistic regression...", "blue")
if self.verbose:
print_color("Updating regression model using the current memory with logistic regression...", "blue")
# Extract candidates from memory (memory contains (neg_score, candidate) tuples)
batch = [candidate for _, candidate in self.memory]
# Ensure all candidates have embeddings
Expand All @@ -126,10 +209,12 @@ def update(self):
training_candidates = [candidate for neg_score, candidate in self.memory if candidate.num_rollouts > 0 and candidate.mean_score() is not None]

if len(training_candidates) == 0:
print_color("Warning: No training data available for regression model.", "yellow")
if self.verbose:
print_color("Warning: No training data available for regression model.", "yellow")
end_time = time.time()
elapsed_time = end_time - start_time
print_color(f"Regressor update completed in {elapsed_time:.4f} seconds (no training data)", "cyan")
if self.verbose:
print_color(f"Regressor update completed in {elapsed_time:.4f} seconds (no training data)", "cyan")
return

# Extract raw binary training data from each candidate
Expand Down Expand Up @@ -169,7 +254,8 @@ def update(self):
print_color("Warning: No binary training samples generated.", "yellow")
end_time = time.time()
elapsed_time = end_time - start_time
print_color(f"Regressor update completed in {elapsed_time:.4f} seconds (no binary samples)", "cyan")
if self.verbose:
print_color(f"Regressor update completed in {elapsed_time:.4f} seconds (no binary samples)", "cyan")
return

# Convert to numpy arrays
Expand All @@ -183,21 +269,7 @@ def update(self):
self.weights = np.random.normal(0, 0.1, self.linear_dim)

# Convergence-based regularized logistic regression training using all raw binary data
m = len(X_list)
# print_color(f"Training regularized logistic regression with {m} binary samples from {len(training_candidates)} candidates until convergence.", "blue")
# print_color(f"Using L2 regularization strength: {self.regularization_strength}, learning rate: {self.learning_rate}", "blue")
# print_color(f"Max iterations: {self.max_iterations}, tolerance: {self.tolerance}", "blue")

# Debug: Print initial weight statistics
initial_weight_norm = np.linalg.norm(self.weights)
# print_color(f"Initial weight norm: {initial_weight_norm:.6f}", "yellow")

# Debug: Print embedding statistics
embedding_mean = np.mean(X)
embedding_std = np.std(X)
embedding_norm_mean = np.mean([np.linalg.norm(row) for row in X])
# print_color(f"Embedding stats - mean: {embedding_mean:.6f}, std: {embedding_std:.6f}, avg norm: {embedding_norm_mean:.6f}", "yellow")

m = len(X_list)
# Training loop until convergence with adaptive learning rate and early stopping
prev_cost = float('inf')
best_cost = float('inf')
Expand Down Expand Up @@ -253,15 +325,6 @@ def update(self):
self.weights -= self.learning_rate * dw
self.bias -= self.learning_rate * db

# Print progress periodically
# if iteration == 0 or (iteration + 1) % max(1, min(50, self.max_iterations // 20)) == 0:
# z_mean, z_std = np.mean(z), np.std(z)
# weight_norm = np.linalg.norm(self.weights)
# print_color(f"Iteration {iteration + 1}: Cost: {total_cost:.6f} (change: {cost_change:.8f}), LR: {self.learning_rate:.6f}, Weight norm: {weight_norm:.6f}, Gradient norm: {gradient_norm:.8f}", "cyan")
# print_color(f" Logits - mean: {z_mean:.6f}, std: {z_std:.6f}, range: [{np.min(z):.6f}, {np.max(z):.6f}]", "cyan")
# print_color(f" Predictions - range: [{np.min(predictions):.6f}, {np.max(predictions):.6f}], mean: {np.mean(predictions):.6f}", "cyan")
# print_color(f" Patience: {patience_counter}/{self.patience}", "cyan")

prev_cost = total_cost

# Final status
Expand Down
Loading
Loading