"""Base class for all benchmarks.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""
import logging
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
from benchbox.core.connection import DatabaseConnection
from benchbox.utils.cloud_storage import create_path_handler
from benchbox.utils.scale_factor import format_scale_factor
from benchbox.utils.verbosity import VerbosityMixin, compute_verbosity
if TYPE_CHECKING:
import sqlglot
from benchbox.core.results.models import BenchmarkResults
else:
try:
import sqlglot
except ImportError:
sqlglot = None # type: ignore
class BaseBenchmark(VerbosityMixin, ABC):
"""Base class for all benchmarks.
All benchmarks inherit from this class.
"""
def __init__(
self,
scale_factor: float = 1.0,
output_dir: Optional[Union[str, Path]] = None,
**kwargs: Any,
) -> None:
"""Initialize a benchmark.
Args:
scale_factor: Scale factor (1.0 = standard size)
output_dir: Data output directory
**kwargs: Additional options
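        Example (illustrative sketch; ``MyBenchmark`` is a hypothetical
        concrete subclass, not part of this module)::

            bench = MyBenchmark(scale_factor=0.1, output_dir="/tmp/benchbox_data")
            bench = MyBenchmark(scale_factor=10)   # scale factors >= 1 must be whole numbers
            # MyBenchmark(scale_factor=1.5)        # raises ValueError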
"""
if scale_factor <= 0:
raise ValueError("Scale factor must be positive")
        # Validate that scale factors >= 1 are whole numbers
        if scale_factor >= 1 and scale_factor != int(scale_factor):
            raise ValueError(
                f"Scale factors >= 1 must be whole numbers. Got: {scale_factor}. "
                f"Use integer values like 1, 2, 10, etc. for large scale factors "
                f"and fractional values like 0.1, 0.01, 0.001, etc. for small scale factors."
)
self.scale_factor = scale_factor
if output_dir is None:
# Use CLI-compatible default path: benchmark_runs/datagen/{benchmark}_{sf}
benchmark_name = self._get_benchmark_name().lower()
sf_str = format_scale_factor(scale_factor)
self.output_dir = Path.cwd() / "benchmark_runs" / "datagen" / f"{benchmark_name}_{sf_str}"
else:
# Support both local and cloud storage paths
self.output_dir = create_path_handler(output_dir)
# Verbosity and quiet handling (normalize bool/int)
verbose_value = kwargs.pop("verbose", 0)
quiet_value = kwargs.pop("quiet", False)
verbosity_settings = compute_verbosity(verbose_value, quiet_value)
self.apply_verbosity(verbosity_settings)
# Logger for core benchmarks
self.logger = logging.getLogger(f"benchbox.core.{self._get_benchmark_name()}")
# Store remaining kwargs as attributes
for key, value in kwargs.items():
setattr(self, key, value)
def _validate_scale_factor_type(self, scale_factor: float) -> None:
"""Validate scale factor is a number (int or float).
Args:
scale_factor: Scale factor to validate
Raises:
TypeError: If scale_factor is not a number
"""
if not isinstance(scale_factor, (int, float)):
raise TypeError(f"scale_factor must be a number, got {type(scale_factor).__name__}")
def _initialize_benchmark_implementation(
self,
implementation_class,
scale_factor: float,
output_dir: Optional[Union[str, Path]],
**kwargs,
):
"""Common initialization pattern for benchmark implementations.
Args:
implementation_class: The benchmark implementation class to instantiate
scale_factor: Scale factor for the benchmark
output_dir: Directory to output generated data files
**kwargs: Additional implementation-specific options
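        Example (sketch of how a hypothetical subclass might delegate to a
        separate implementation class; ``MyBenchmarkImpl`` is illustrative)::

            class MyBenchmark(BaseBenchmark):
                def __init__(self, scale_factor=1.0, output_dir=None, **kwargs):
                    super().__init__(scale_factor, output_dir, **kwargs)
                    self._initialize_benchmark_implementation(
                        MyBenchmarkImpl, scale_factor, output_dir
                    )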
"""
# Extract verbose and force_regenerate from kwargs to avoid passing them twice
verbose = kwargs.pop("verbose", False)
force_regenerate = kwargs.pop("force_regenerate", False)
self._impl = implementation_class(
scale_factor=scale_factor,
output_dir=output_dir,
verbose=verbose,
force_regenerate=force_regenerate,
**kwargs,
)
def _get_benchmark_name(self) -> str:
"""Derive benchmark name from class name.
Maps class names to benchmark names:
- TPCH -> tpch
- TPCDS -> tpcds
- ClickBench -> clickbench
- AMPLab -> amplab
- H2ODB -> h2odb
- JoinOrder -> joinorder
- TPCHavoc -> tpchavoc
- NYCTaxi -> nyctaxi
- etc.
Returns:
Lowercase benchmark name suitable for directory paths
"""
class_name = self.__class__.__name__
# Handle special cases first
special_mappings = {
"ClickBench": "clickbench",
"AMPLab": "amplab",
"AMPLabBenchmark": "amplab",
"H2ODB": "h2odb",
"H2OBenchmark": "h2odb",
"JoinOrder": "joinorder",
"JoinOrderBenchmark": "joinorder",
"TPCHavoc": "tpchavoc",
"TPCHBenchmark": "tpch",
"TPCDSBenchmark": "tpcds",
"TPCDIBenchmark": "tpcdi",
"SSBBenchmark": "ssb",
"PrimitivesBenchmark": "primitives",
"MergeBenchmark": "merge",
"CoffeeShopBenchmark": "coffeeshop",
"NYCTaxi": "nyctaxi",
"NYCTaxiBenchmark": "nyctaxi",
"TSBSDevOps": "tsbs_devops",
"TSBSDevOpsBenchmark": "tsbs_devops",
}
if class_name in special_mappings:
return special_mappings[class_name]
# For standard cases, just lowercase
# TPCH -> tpch, TPCDS -> tpcds, etc.
return class_name.lower()
def get_data_source_benchmark(self) -> Optional[str]:
"""Return the canonical source benchmark when data is shared.
Benchmarks that reuse data generated by another benchmark (for example,
``Primitives`` reusing ``TPC-H`` datasets) should override this method and
return the lower-case identifier of the source benchmark. Benchmarks that
produce their own data should return ``None`` (default).
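        Example (hypothetical override for a benchmark that reuses TPC-H data)::

            def get_data_source_benchmark(self) -> Optional[str]:
                return "tpch"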
"""
return None
@abstractmethod
def generate_data(self) -> list[Union[str, Path]]:
"""Generate benchmark data.
Returns:
List of data file paths
"""
@abstractmethod
def get_queries(self) -> dict[str, str]:
"""Get all benchmark queries.
Returns:
Dictionary mapping query IDs to query strings
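        Example of the expected shape (query text elided)::

            {"1": "SELECT ...", "2": "SELECT ..."}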
"""
@abstractmethod
def get_query(self, query_id: Union[int, str], *, params: Optional[dict[str, Any]] = None) -> str:
"""Get a benchmark query.
Args:
query_id: Query ID
params: Optional parameters
Returns:
Query string with parameters resolved
Raises:
ValueError: If query_id is invalid
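        Example (illustrative; ``region`` is a hypothetical parameter name,
        actual parameters are benchmark-specific)::

            sql = benchmark.get_query(1)
            sql = benchmark.get_query("1", params={"region": "EUROPE"})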
"""
def _load_data(self, connection: DatabaseConnection) -> None:
"""Load benchmark data into database.
        Subclasses override this method to implement their own data loading.
Args:
connection: Database connection
Raises:
NotImplementedError: If not implemented by subclass
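        Example (sketch of a minimal override in a hypothetical subclass; the
        table name and COPY statement are illustrative and platform-dependent)::

            def _load_data(self, connection: DatabaseConnection) -> None:
                for path in self.generate_data():
                    connection.execute(f"COPY my_table FROM '{path}'")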
"""
# Default implementation - benchmarks should override
# Non-abstract for backward compatibility
# Raises NotImplementedError if not overridden
raise NotImplementedError(
f"{self.__class__.__name__} must implement _load_data() method to support database execution functionality"
)
def setup_database(self, connection: DatabaseConnection) -> None:
"""Set up database with schema and data.
Creates necessary database schema and loads
benchmark data into the database.
Args:
connection: Database connection to set up
Raises:
ValueError: If data generation fails
Exception: If database setup fails
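        Example (illustrative; assumes a DatabaseConnection has already been
        opened elsewhere)::

            benchmark.setup_database(connection)
            results = benchmark.run_benchmark(connection, setup_database=False)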
"""
logger = logging.getLogger(__name__)
try:
logger.info("Setting up database schema and loading data...")
start_time = time.time()
# Generate data if not already generated
if not hasattr(self, "_data_generated") or not self._data_generated:
logger.info("Generating benchmark data...")
self.generate_data()
self._data_generated = True
# Load data into database
self._load_data(connection)
setup_time = time.time() - start_time
logger.info(f"Database setup completed in {setup_time:.2f} seconds")
except Exception as e:
logger.error(f"Database setup failed: {str(e)}")
raise
def run_query(
self,
query_id: Union[int, str],
connection: DatabaseConnection,
params: Optional[dict[str, Any]] = None,
fetch_results: bool = False,
) -> dict[str, Any]:
"""Execute single query and return timing and results.
Args:
query_id: ID of the query to execute
connection: Database connection to execute query on
params: Optional parameters for query customization
fetch_results: Whether to fetch and return query results
Returns:
Dictionary containing:
- query_id: Executed query ID
- execution_time: Time taken to execute query in seconds
- query_text: Executed query text
- results: Query results if fetch_results=True, otherwise None
- row_count: Number of rows returned (if results fetched)
Raises:
ValueError: If query_id is invalid
Exception: If query execution fails
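        Example (illustrative; assumes a concrete subclass and an open
        connection)::

            result = benchmark.run_query(1, connection, fetch_results=True)
            print(result["execution_time"], result["row_count"])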
"""
logger = logging.getLogger(__name__)
try:
# Get the query text
query_text = self.get_query(query_id, params=params)
logger.debug(f"Executing query {query_id}")
start_time = time.time()
# Execute the query
cursor = connection.execute(query_text)
# Fetch results if requested
results = None
row_count = 0
if fetch_results:
results = connection.fetchall(cursor)
row_count = len(results) if results else 0
execution_time = time.time() - start_time
logger.info(f"Query {query_id} completed in {execution_time:.3f} seconds")
if fetch_results:
logger.debug(f"Query {query_id} returned {row_count} rows")
return {
"query_id": query_id,
"execution_time": execution_time,
"query_text": query_text,
"results": results,
"row_count": row_count,
}
except Exception as e:
logger.error(f"Query {query_id} execution failed: {str(e)}")
raise
def run_benchmark(
self,
connection: DatabaseConnection,
query_ids: Optional[list[Union[int, str]]] = None,
fetch_results: bool = False,
setup_database: bool = True,
) -> dict[str, Any]:
"""Run the complete benchmark suite.
Args:
connection: Database connection to execute queries on
query_ids: Optional list of specific query IDs to run (defaults to all)
fetch_results: Whether to fetch and return query results
setup_database: Whether to set up the database first
Returns:
Dictionary containing:
- benchmark_name: Name of the benchmark
- total_execution_time: Total time for all queries
- total_queries: Number of queries executed
- successful_queries: Number of queries that succeeded
- failed_queries: Number of queries that failed
- query_results: List of individual query results
- setup_time: Time taken for database setup (if performed)
Raises:
Exception: If benchmark execution fails
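        Example (illustrative; query IDs are benchmark-specific)::

            summary = benchmark.run_benchmark(connection, query_ids=[1, 2, 3])
            print(f"{summary['successful_queries']}/{summary['total_queries']} succeeded")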
"""
logger = logging.getLogger(__name__)
benchmark_start_time = time.time()
setup_time = 0.0
try:
# Set up database if requested
if setup_database:
logger.info("Setting up database for benchmark...")
setup_start_time = time.time()
self.setup_database(connection)
setup_time = time.time() - setup_start_time
# Determine which queries to run
if query_ids is None:
all_queries = self.get_queries()
query_ids = list(all_queries.keys())
logger.info(f"Running benchmark with {len(query_ids)} queries...")
# Execute all queries
query_results = []
successful_queries = 0
failed_queries = 0
for query_id in query_ids:
try:
result = self.run_query(query_id, connection, fetch_results=fetch_results)
query_results.append(result)
successful_queries += 1
except Exception as e:
logger.error(f"Query {query_id} failed: {str(e)}")
failed_queries += 1
# Include failed query in results
query_results.append(
{
"query_id": query_id,
"execution_time": 0.0,
"query_text": None,
"results": None,
"row_count": 0,
"error": str(e),
}
)
total_execution_time = time.time() - benchmark_start_time - setup_time
# Calculate summary statistics
query_times = [r["execution_time"] for r in query_results if "error" not in r]
benchmark_result = {
"benchmark_name": self.__class__.__name__,
"total_execution_time": total_execution_time,
"total_queries": len(query_ids),
"successful_queries": successful_queries,
"failed_queries": failed_queries,
"query_results": query_results,
"setup_time": setup_time,
"average_query_time": sum(query_times) / len(query_times) if query_times else 0.0,
"min_query_time": min(query_times) if query_times else 0.0,
"max_query_time": max(query_times) if query_times else 0.0,
}
logger.info(f"Benchmark completed: {successful_queries}/{len(query_ids)} queries successful")
logger.info(f"Total execution time: {total_execution_time:.2f} seconds")
return benchmark_result
except Exception as e:
logger.error(f"Benchmark execution failed: {str(e)}")
raise
def _get_default_benchmark_type(self) -> str:
"""Get the default benchmark type for platform optimizations.
Subclasses can override this to specify their workload characteristics.
Returns:
Default benchmark type ('olap', 'oltp', 'mixed', 'analytics')
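        Example (hypothetical override for a transactional workload)::

            def _get_default_benchmark_type(self) -> str:
                return "oltp"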
"""
# Most benchmarks in BenchBox are OLAP-focused
return "olap"
def _format_time(self, seconds: float) -> str:
"""Format execution time for display.
Args:
seconds: Time in seconds
Returns:
Formatted time string
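        Example::

            self._format_time(0.25)   # "250.0ms"
            self._format_time(12.5)   # "12.50s"
            self._format_time(75.0)   # "1m 15.0s"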
"""
if seconds < 1:
return f"{seconds * 1000:.1f}ms"
elif seconds < 60:
return f"{seconds:.2f}s"
else:
minutes = int(seconds // 60)
remaining_seconds = seconds % 60
return f"{minutes}m {remaining_seconds:.1f}s"
def translate_query(self, query_id: Union[int, str], dialect: str) -> str:
"""Translate a query to a specific SQL dialect.
Args:
query_id: The ID of the query to translate
dialect: The target SQL dialect
Returns:
The translated query string
        Raises:
            ValueError: If the query_id is invalid or the dialect is not supported
            ImportError: If sqlglot is not installed
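        Example (illustrative; dialect support ultimately depends on sqlglot)::

            duckdb_sql = benchmark.translate_query(1, "duckdb")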
"""
if sqlglot is None:
raise ImportError("sqlglot is required for query translation. Install it with `pip install sqlglot`.")
from benchbox.utils.dialect_utils import normalize_dialect_for_sqlglot
query = self.get_query(query_id)
# Normalize dialect for SQLGlot compatibility
normalized_dialect = normalize_dialect_for_sqlglot(dialect)
# Apply translation for specific SQL syntax
try:
# Use identify=True to quote identifiers and prevent reserved keyword conflicts
translated = sqlglot.transpile( # type: ignore[attr-defined]
query, read="postgres", write=normalized_dialect, identify=True
)[0]
return translated
except (ValueError, AttributeError) as e:
raise ValueError(f"Error translating to dialect '{dialect}': {e}")
@property
def benchmark_name(self) -> str:
"""Get the human-readable benchmark name."""
        # Try to get from implementation first, then fall back to the class name
if hasattr(self, "_impl"):
if hasattr(self._impl, "_name"):
return self._impl._name
elif hasattr(self._impl, "benchmark_name"):
return self._impl.benchmark_name
# For classes without _impl (like core implementation classes)
return getattr(self, "_name", type(self).__name__)
def create_enhanced_benchmark_result(
self,
platform: str,
query_results: list[dict[str, Any]],
execution_metadata: Optional[dict[str, Any]] = None,
phases: Optional[dict[str, dict[str, Any]]] = None,
resource_utilization: Optional[dict[str, Any]] = None,
performance_characteristics: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> "BenchmarkResults":
"""Create a BenchmarkResults object with standardized fields.
This centralizes the logic for creating benchmark results that was previously
        duplicated across platform adapters and the CLI orchestrator.
Args:
platform: Platform name (e.g., "DuckDB", "ClickHouse")
query_results: List of query execution results
execution_metadata: Optional execution metadata
phases: Optional phase tracking information
resource_utilization: Optional resource usage metrics
performance_characteristics: Optional performance analysis
**kwargs: Additional fields to override defaults
Returns:
Fully configured BenchmarkResults object
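        Example (illustrative sketch; each query_results dictionary carries the
        ``status`` and ``execution_time`` keys consumed below, and extra kwargs
        such as ``total_rows_loaded`` feed the setup phases)::

            results = benchmark.create_enhanced_benchmark_result(
                platform="DuckDB",
                query_results=[
                    {"query_id": "1", "status": "SUCCESS", "execution_time": 0.42},
                ],
                total_rows_loaded=1_000_000,
            )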
"""
# Delegate to implementation's create method if available, otherwise use our own
if hasattr(self, "_impl") and hasattr(self._impl, "create_enhanced_benchmark_result"):
return self._impl.create_enhanced_benchmark_result(
platform=platform,
query_results=query_results,
execution_metadata=execution_metadata,
phases=phases,
resource_utilization=resource_utilization,
performance_characteristics=performance_characteristics,
**kwargs,
)
# Fallback implementation for wrapper class
        import uuid
        from datetime import datetime
# Calculate basic metrics from query results
total_queries = len(query_results)
successful_queries = len([r for r in query_results if r.get("status") == "SUCCESS"])
failed_queries = total_queries - successful_queries
# Calculate timing metrics
successful_results = [r for r in query_results if r.get("status") == "SUCCESS"]
total_execution_time = sum(r.get("execution_time", 0.0) for r in successful_results)
average_query_time = total_execution_time / max(successful_queries, 1)
# Generate standard identifiers
execution_id = kwargs.get(
"execution_id",
f"{self.benchmark_name.lower().replace(' ', '_')}_{int(time.time())}",
)
timestamp = kwargs.get("timestamp", datetime.now())
# Import required classes for creating proper results (from core)
from benchbox.core.results.models import (
BenchmarkResults,
DataGenerationPhase,
DataLoadingPhase,
ExecutionPhases,
QueryDefinition,
SchemaCreationPhase,
SetupPhase,
ValidationPhase,
)
# Create query definitions from query results
query_definitions = {}
if hasattr(self, "get_queries"):
queries = self.get_queries()
for query_id, sql in queries.items():
query_definitions.setdefault("stream_1", {})[query_id] = QueryDefinition(sql=sql, parameters=None)
# Create basic execution phases with proper structure and timing
data_gen_time = kwargs.get("data_generation_time", 0.0)
schema_time = kwargs.get("schema_creation_time", 0.0)
loading_time = kwargs.get("data_loading_time", 0.0)
validation_time = kwargs.get("validation_time", 0.0)
setup_phase = SetupPhase(
data_generation=DataGenerationPhase(
duration_ms=int(data_gen_time * 1000),
status="SUCCESS",
tables_generated=kwargs.get("tables_generated", 0),
total_rows_generated=kwargs.get("total_rows_generated", 0),
total_data_size_bytes=kwargs.get("total_data_size_bytes", 0),
per_table_stats={},
),
schema_creation=SchemaCreationPhase(
duration_ms=int(schema_time * 1000),
status="SUCCESS",
tables_created=kwargs.get("tables_created", 0),
constraints_applied=kwargs.get("constraints_applied", 0),
indexes_created=kwargs.get("indexes_created", 0),
per_table_creation={},
),
data_loading=DataLoadingPhase(
duration_ms=int(loading_time * 1000),
status="SUCCESS",
total_rows_loaded=kwargs.get("total_rows_loaded", 0),
tables_loaded=kwargs.get("tables_loaded", 0),
per_table_stats={},
),
validation=ValidationPhase(
duration_ms=int(validation_time * 1000),
row_count_validation=kwargs.get("row_count_validation", "SKIPPED"),
schema_validation=kwargs.get("schema_validation", "SKIPPED"),
data_integrity_checks=kwargs.get("data_integrity_checks", "SKIPPED"),
validation_details=kwargs.get("validation_details", {}),
),
)
# Use provided execution phases if available, otherwise create basic setup phase
provided_phases = phases or kwargs.get("execution_phases") or kwargs.get("phases")
if provided_phases and hasattr(provided_phases, "power_test"):
# Type checker doesn't understand the hasattr check narrows the type
assert isinstance(provided_phases, ExecutionPhases)
execution_phases = provided_phases
else:
execution_phases = ExecutionPhases(setup=setup_phase)
# Create result with proper structured data
result = BenchmarkResults(
# Core benchmark info
benchmark_name=self.benchmark_name,
platform=platform,
scale_factor=self.scale_factor,
execution_id=execution_id,
timestamp=timestamp,
duration_seconds=kwargs.get("duration_seconds", total_execution_time),
# Required structured data
query_results=query_results,
query_definitions=query_definitions,
execution_phases=execution_phases,
# Summary metrics (calculated from phases)
total_queries=total_queries,
successful_queries=successful_queries,
failed_queries=failed_queries,
total_execution_time=total_execution_time,
average_query_time=average_query_time,
# Setup metrics (from execution phases)
data_loading_time=loading_time,
schema_creation_time=schema_time,
total_rows_loaded=kwargs.get("total_rows_loaded", 0),
data_size_mb=kwargs.get("total_data_size_bytes", 0) / (1024 * 1024),
table_statistics=kwargs.get("table_statistics", {}),
# Optional fields with defaults
test_execution_type=kwargs.get("test_execution_type", "standard"),
validation_status=kwargs.get("validation_status", "UNKNOWN"),
validation_details=kwargs.get("validation_details", {}),
platform_info=kwargs.get("platform_info"),
tunings_applied=kwargs.get("tunings_applied", {}),
tuning_validation_status=kwargs.get("tuning_validation_status", "NOT_APPLIED"),
tuning_metadata_saved=kwargs.get("tuning_metadata_saved", False),
system_profile=kwargs.get("system_profile", {}),
anonymous_machine_id=kwargs.get("anonymous_machine_id", str(uuid.uuid4())[:8]),
execution_metadata=execution_metadata or {},
)
if performance_characteristics is not None:
result.performance_characteristics = performance_characteristics
else:
result.performance_characteristics = {}
snapshot_payload = kwargs.get("performance_snapshot")
if snapshot_payload is not None:
try:
from benchbox.monitoring.performance import PerformanceSnapshot, attach_snapshot_to_result
if isinstance(snapshot_payload, PerformanceSnapshot):
attach_snapshot_to_result(result, snapshot_payload)
elif isinstance(snapshot_payload, dict):
result.performance_summary = dict(snapshot_payload)
if not result.performance_characteristics:
result.performance_characteristics = dict(snapshot_payload)
except ImportError:
if isinstance(snapshot_payload, dict):
result.performance_summary = dict(snapshot_payload)
if not result.performance_characteristics:
result.performance_characteristics = dict(snapshot_payload)
elif performance_characteristics and not result.performance_summary:
result.performance_summary = dict(performance_characteristics)
return result