"""Amazon Redshift platform adapter with S3 integration and data warehouse optimizations.
Provides Redshift-specific optimizations for analytical workloads,
including the COPY command for efficient data loading and distribution key optimization.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from benchbox.core.tuning.interface import (
ForeignKeyConfiguration,
PlatformOptimizationConfiguration,
PrimaryKeyConfiguration,
UnifiedTuningConfiguration,
)
from ..core.exceptions import ConfigurationError
from ..utils.dependencies import (
check_platform_dependencies,
get_dependency_error_message,
get_dependency_group_packages,
)
from .base import PlatformAdapter
from .base.data_loading import FileFormatRegistry
try:
import redshift_connector
psycopg2 = None
except ImportError:
redshift_connector = None
try:
import psycopg2
except ImportError:
psycopg2 = None
try:
import boto3
from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
boto3 = None
# Placeholder exception types so the module-level names always resolve;
# the boto3-dependent code paths are guarded and never reach them.
ClientError = NoCredentialsError = Exception
class RedshiftAdapter(PlatformAdapter):
"""Amazon Redshift platform adapter with S3 integration."""
def __init__(self, **config):
super().__init__(**config)
dependency_packages = get_dependency_group_packages("redshift")
# Check dependencies - prefer redshift-connector, fallback to psycopg2
if not redshift_connector and not psycopg2:
available, missing = check_platform_dependencies("redshift")
if not available:
error_msg = get_dependency_error_message("redshift", missing)
raise ImportError(error_msg)
else:
# Ensure shared helper libraries (e.g., boto3, cloudpathlib) are available
shared_packages = [pkg for pkg in dependency_packages if pkg != "redshift-connector"]
if shared_packages:
available_shared, missing_shared = check_platform_dependencies("redshift", shared_packages)
if not available_shared:
error_msg = get_dependency_error_message("redshift", missing_shared)
raise ImportError(error_msg)
self._dialect = "redshift"
# Redshift connection configuration
self.host = config.get("host")
self.port = config.get("port") if config.get("port") is not None else 5439
self.database = config.get("database") or "dev"
self.username = config.get("username")
self.password = config.get("password")
self.cluster_identifier = config.get("cluster_identifier")
# Admin database for metadata operations (CREATE/DROP DATABASE, checking database existence)
# Redshift requires connecting to an existing database for admin operations
# Default: "dev" (Redshift Serverless default database)
self.admin_database = config.get("admin_database") or "dev"
# Schema configuration
self.schema = config.get("schema") or "public"
# Connection settings
self.connect_timeout = config.get("connect_timeout") if config.get("connect_timeout") is not None else 10
self.statement_timeout = config.get("statement_timeout") if config.get("statement_timeout") is not None else 0
self.sslmode = config.get("sslmode") or "require"
# WLM settings
self.wlm_query_slot_count = (
config.get("wlm_query_slot_count") if config.get("wlm_query_slot_count") is not None else 1
)
self.wlm_query_queue_name = config.get("wlm_query_queue_name")
# SSL configuration (legacy compatibility)
self.ssl_enabled = config.get("ssl_enabled") if config.get("ssl_enabled") is not None else True
self.ssl_insecure = config.get("ssl_insecure") if config.get("ssl_insecure") is not None else False
self.sslrootcert = config.get("sslrootcert")
# S3 configuration for data loading
# Check for staging_root first (set by orchestrator for CloudStagingPath)
staging_root = config.get("staging_root")
if staging_root:
# Parse s3://bucket/path format to extract bucket and prefix
from benchbox.utils.cloud_storage import get_cloud_path_info
path_info = get_cloud_path_info(staging_root)
if path_info["provider"] == "s3":
self.s3_bucket = path_info["bucket"]
# Use the path component if provided, otherwise use default
self.s3_prefix = path_info["path"].strip("/") if path_info["path"] else "benchbox-data"
self.logger.info(f"Using staging location from config: s3://{self.s3_bucket}/{self.s3_prefix}")
else:
raise ValueError(f"Redshift requires S3 (s3://) staging location, got: {path_info['provider']}://")
else:
# Fall back to explicit s3_bucket configuration
self.s3_bucket = config.get("s3_bucket")
self.s3_prefix = config.get("s3_prefix") or "benchbox-data"
self.iam_role = config.get("iam_role")
self.aws_access_key_id = config.get("aws_access_key_id")
self.aws_secret_access_key = config.get("aws_secret_access_key")
self.aws_region = config.get("aws_region") or "us-east-1"
# Redshift optimization settings
self.workload_management_config = config.get("wlm_config")
# COMPUPDATE controls automatic compression during COPY (PRESET | ON | OFF)
# PRESET: Apply compression based on column data types (no sampling)
# ON: Apply compression based on data sampling
# OFF: Disable automatic compression
compupdate_raw = config.get("compupdate") or "PRESET"
self.compupdate = compupdate_raw.upper() # Normalize to uppercase
# Validate COMPUPDATE value
valid_compupdate_values = {"ON", "OFF", "PRESET"}
if self.compupdate not in valid_compupdate_values:
raise ValueError(
f"Invalid COMPUPDATE value: '{compupdate_raw}'. "
f"Must be one of: {', '.join(sorted(valid_compupdate_values))}"
)
self.auto_vacuum = config.get("auto_vacuum") if config.get("auto_vacuum") is not None else True
self.auto_analyze = config.get("auto_analyze") if config.get("auto_analyze") is not None else True
# Result cache control - disable by default for accurate benchmarking
self.disable_result_cache = config.get("disable_result_cache", True)
# Validation strictness - raise errors if cache control validation fails
self.strict_validation = config.get("strict_validation", True)
if not all([self.host, self.username, self.password]):
missing = []
if not self.host:
missing.append("host (or REDSHIFT_HOST)")
if not self.username:
missing.append("username (or REDSHIFT_USER)")
if not self.password:
missing.append("password (or REDSHIFT_PASSWORD)")
raise ConfigurationError(
f"Redshift configuration is incomplete. Missing: {', '.join(missing)}\n"
"Configure with one of:\n"
" 1. CLI: benchbox platforms setup --platform redshift\n"
" 2. Environment variables: REDSHIFT_HOST, REDSHIFT_USER, REDSHIFT_PASSWORD\n"
" 3. CLI options: --platform-option host=<cluster>.redshift.amazonaws.com"
)
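# Illustrative adapter construction (hostname, bucket, and role ARN below are placeholders,
# following the serverless endpoint format documented later in this module):
#
#   adapter = RedshiftAdapter(
#       host="demo-wg.123456789012.us-east-1.redshift-serverless.amazonaws.com",
#       username="benchbox_user",
#       password="...",
#       s3_bucket="example-staging-bucket",
#       iam_role="arn:aws:iam::123456789012:role/ExampleRedshiftCopyRole",
#   )
#
# host, username, and password are required; the S3/IAM settings are only needed for COPY-based loading.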
@property
def platform_name(self) -> str:
return "Redshift"
@staticmethod
def add_cli_arguments(parser) -> None:
"""Add Redshift-specific CLI arguments."""
rs_group = parser.add_argument_group("Redshift Arguments")
rs_group.add_argument("--host", type=str, help="Redshift cluster endpoint hostname")
rs_group.add_argument("--port", type=int, default=5439, help="Redshift cluster port")
rs_group.add_argument("--database", type=str, default="dev", help="Database name")
rs_group.add_argument("--username", type=str, help="Database user with required privileges")
rs_group.add_argument("--password", type=str, help="Password for the database user")
rs_group.add_argument("--iam-role", type=str, help="IAM role ARN for COPY operations")
rs_group.add_argument("--s3-bucket", type=str, help="S3 bucket for data staging")
rs_group.add_argument("--s3-prefix", type=str, default="benchbox-data", help="Prefix within the staging bucket")
@classmethod
def from_config(cls, config: dict[str, Any]):
"""Create Redshift adapter from unified configuration."""
from benchbox.utils.database_naming import generate_database_name
adapter_config: dict[str, Any] = {}
# Generate proper database name using benchmark characteristics
# (unless explicitly overridden in config)
if "database" in config and config["database"]:
# User explicitly provided database name - use it
adapter_config["database"] = config["database"]
else:
# Generate configuration-aware database name
database_name = generate_database_name(
benchmark_name=config["benchmark"],
scale_factor=config["scale_factor"],
platform="redshift",
tuning_config=config.get("tuning_config"),
)
adapter_config["database"] = database_name
# Core connection parameters (database handled above)
for key in ["host", "port", "username", "password", "schema"]:
if key in config:
adapter_config[key] = config[key]
# Optional staging/optimization parameters
for key in [
"iam_role",
"s3_bucket",
"s3_prefix",
"aws_access_key_id",
"aws_secret_access_key",
"aws_region",
"cluster_identifier",
"admin_database",
"connect_timeout",
"statement_timeout",
"sslmode",
"ssl_enabled",
"ssl_insecure",
"sslrootcert",
"wlm_query_slot_count",
"wlm_query_queue_name",
"wlm_config",
"compupdate",
"auto_vacuum",
"auto_analyze",
]:
if key in config:
adapter_config[key] = config[key]
return cls(**adapter_config)
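# Illustrative from_config() input (keys are the ones read above; values are placeholders).
# Because "database" is omitted, a configuration-aware name is generated from the benchmark context:
#
#   adapter = RedshiftAdapter.from_config({
#       "benchmark": "tpch",
#       "scale_factor": 1,
#       "host": "example-cluster.us-east-1.redshift.amazonaws.com",
#       "username": "benchbox_user",
#       "password": "...",
#       "s3_bucket": "example-staging-bucket",
#       "iam_role": "arn:aws:iam::123456789012:role/ExampleRedshiftCopyRole",
#   })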
def get_target_dialect(self) -> str:
"""Return the target SQL dialect for Redshift."""
return "redshift"
def _detect_deployment_type(self, hostname: str) -> str:
"""Detect Redshift deployment type from hostname pattern.
Args:
hostname: Redshift endpoint hostname
Returns:
"serverless", "provisioned", or "unknown"
"""
if not hostname:
return "unknown"
if ".redshift-serverless.amazonaws.com" in hostname:
return "serverless"
elif ".redshift.amazonaws.com" in hostname:
return "provisioned"
else:
return "unknown"
def _extract_region_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
"""Extract AWS region from Redshift hostname.
Args:
hostname: Redshift endpoint hostname
deployment_type: "serverless" or "provisioned"
Returns:
AWS region string or None if not found
"""
if not hostname:
return None
parts = hostname.split(".")
if deployment_type == "serverless":
# Format: workgroup.account.region.redshift-serverless.amazonaws.com
return parts[2] if len(parts) > 2 else None
elif deployment_type == "provisioned":
# Format: cluster.region.redshift.amazonaws.com
return parts[1] if len(parts) > 1 else None
return None
def _extract_identifier_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
"""Extract workgroup name or cluster identifier from hostname.
Args:
hostname: Redshift endpoint hostname
deployment_type: "serverless" or "provisioned"
Returns:
Workgroup name, cluster identifier, or None
"""
if not hostname:
return None
parts = hostname.split(".")
if deployment_type == "serverless":
# Format: workgroup.account.region.redshift-serverless.amazonaws.com
return parts[0] if len(parts) > 0 else None
elif deployment_type == "provisioned":
# Format: cluster.region.redshift.amazonaws.com
return parts[0] if len(parts) > 0 else None
return None
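# Worked example for the hostname helpers above (workgroup and account values are illustrative):
#
#   hostname = "demo-wg.123456789012.us-east-1.redshift-serverless.amazonaws.com"
#   _detect_deployment_type(hostname)                          -> "serverless"
#   _extract_region_from_hostname(hostname, "serverless")      -> "us-east-1"
#   _extract_identifier_from_hostname(hostname, "serverless")  -> "demo-wg"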
def _get_serverless_metadata_sql(self, cursor: Any) -> dict[str, Any]:
"""Get Redshift Serverless metadata using SQL queries.
Args:
cursor: Active database cursor
Returns:
Dictionary with serverless metadata (empty if not serverless or queries fail)
"""
metadata = {}
try:
# Try to query sys_serverless_usage (serverless-only table)
cursor.execute("""
SELECT compute_capacity
FROM sys_serverless_usage
ORDER BY start_time DESC
LIMIT 1
""")
result = cursor.fetchone()
if result:
# Table exists and has data - this is serverless
metadata["current_rpu_capacity"] = result[0] if result[0] is not None else None
self.logger.debug("Detected Redshift Serverless via sys_serverless_usage table")
except Exception as e:
# Table doesn't exist or query failed - likely not serverless
self.logger.debug(f"sys_serverless_usage query failed (not serverless or no permissions): {e}")
return metadata
def _get_provisioned_metadata_sql(self, cursor: Any) -> dict[str, Any]:
"""Get Redshift Provisioned metadata using SQL queries.
Args:
cursor: Active database cursor
Returns:
Dictionary with provisioned metadata (empty if not provisioned or queries fail)
"""
metadata = {}
try:
# Query stv_cluster_configuration (provisioned-only table)
cursor.execute("""
SELECT node_type, cluster_version, COUNT(*) as num_nodes
FROM stv_cluster_configuration
GROUP BY node_type, cluster_version
""")
result = cursor.fetchone()
if result:
metadata["node_type"] = result[0] if len(result) > 0 else None
metadata["cluster_version"] = result[1] if len(result) > 1 else None
metadata["number_of_nodes"] = result[2] if len(result) > 2 else None
self.logger.debug(
f"Detected Redshift Provisioned: {metadata['node_type']} x{metadata['number_of_nodes']}"
)
except Exception as e:
# Table doesn't exist or query failed - likely not provisioned or no permissions
self.logger.debug(f"stv_cluster_configuration query failed (not provisioned or no permissions): {e}")
return metadata
def _get_serverless_metadata_api(self, workgroup_name: str, region: str) -> dict[str, Any]:
"""Get Redshift Serverless metadata using boto3 API.
Args:
workgroup_name: Workgroup name
region: AWS region
Returns:
Dictionary with serverless metadata (empty if API call fails)
"""
if not boto3:
self.logger.debug("boto3 not available - skipping Serverless API metadata")
return {}
metadata = {}
try:
client = boto3.client("redshift-serverless", region_name=region)
# Get workgroup details
response = client.get_workgroup(workgroupName=workgroup_name)
workgroup = response.get("workgroup", {})
# Extract essential sizing information
metadata["workgroup_name"] = workgroup.get("workgroupName")
metadata["base_capacity_rpu"] = workgroup.get("baseCapacity")
metadata["max_capacity_rpu"] = workgroup.get("maxCapacity")
metadata["namespace_name"] = workgroup.get("namespaceName")
metadata["enhanced_vpc_routing"] = workgroup.get("enhancedVpcRouting", False)
metadata["status"] = workgroup.get("status")
# Get namespace details for encryption info
if metadata.get("namespace_name"):
try:
namespace_response = client.get_namespace(namespaceName=metadata["namespace_name"])
namespace = namespace_response.get("namespace", {})
metadata["kms_key_id"] = namespace.get("kmsKeyId")
metadata["encrypted"] = True # Serverless is always encrypted
except Exception as e:
self.logger.debug(f"Could not fetch namespace details: {e}")
self.logger.debug(f"Retrieved Serverless metadata via API: {metadata['base_capacity_rpu']} RPUs")
except NoCredentialsError:
self.logger.debug("No AWS credentials found - skipping Serverless API metadata")
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
self.logger.debug(f"AWS API error querying Serverless metadata: {error_code} - {e}")
except Exception as e:
self.logger.debug(f"Failed to query Serverless API metadata: {e}")
return metadata
def _get_provisioned_metadata_api(self, cluster_identifier: str, region: str) -> dict[str, Any]:
"""Get Redshift Provisioned metadata using boto3 API.
Args:
cluster_identifier: Cluster identifier
region: AWS region
Returns:
Dictionary with provisioned metadata (empty if API call fails)
"""
if not boto3:
self.logger.debug("boto3 not available - skipping Provisioned API metadata")
return {}
metadata = {}
try:
client = boto3.client("redshift", region_name=region)
# Get cluster details
response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
clusters = response.get("Clusters", [])
if clusters:
cluster = clusters[0]
# Extract essential sizing information
metadata["cluster_identifier"] = cluster.get("ClusterIdentifier")
metadata["node_type"] = cluster.get("NodeType")
metadata["number_of_nodes"] = cluster.get("NumberOfNodes")
metadata["cluster_status"] = cluster.get("ClusterStatus")
metadata["encrypted"] = cluster.get("Encrypted", False)
metadata["kms_key_id"] = cluster.get("KmsKeyId")
metadata["enhanced_vpc_routing"] = cluster.get("EnhancedVpcRouting", False)
# Storage capacity
total_storage_mb = cluster.get("TotalStorageCapacityInMegaBytes")
if total_storage_mb:
metadata["total_storage_capacity_mb"] = total_storage_mb
self.logger.debug(
f"Retrieved Provisioned metadata via API: {metadata['node_type']} x{metadata['number_of_nodes']}"
)
except NoCredentialsError:
self.logger.debug("No AWS credentials found - skipping Provisioned API metadata")
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
self.logger.debug(f"AWS API error querying Provisioned metadata: {error_code} - {e}")
except Exception as e:
self.logger.debug(f"Failed to query Provisioned API metadata: {e}")
return metadata
def _get_connection_params(self, **connection_config) -> dict[str, Any]:
"""Get standardized connection parameters."""
return {
"host": connection_config.get("host", self.host),
"port": connection_config.get("port", self.port),
"database": connection_config.get("database", self.database),
"user": connection_config.get("username", self.username),
"password": connection_config.get("password", self.password),
"sslmode": connection_config.get("sslmode", self.sslmode),
}
def _create_admin_connection(self, **connection_config) -> Any:
"""Create Redshift connection for admin operations.
Admin operations (CREATE DATABASE, DROP DATABASE, checking database existence)
require connecting to an existing database. This uses self.admin_database
(default: "dev") instead of the target database to avoid circular dependencies.
"""
params = self._get_connection_params(**connection_config)
# Override database with admin database for admin operations
# This prevents trying to connect to the target database to check if it exists
admin_params = params.copy()
admin_params["database"] = self.admin_database
# Use same driver as main connection
if redshift_connector:
# Use redshift_connector (preferred)
# Note: redshift_connector doesn't support connect_timeout or sslmode parameters
# It only accepts ssl=True/False (not sslmode like psycopg2)
return redshift_connector.connect(
host=admin_params["host"],
port=admin_params["port"],
database=admin_params["database"],
user=admin_params["user"],
password=admin_params["password"],
ssl=self.ssl_enabled,
application_name="BenchBox-Admin",
)
else:
# Fall back to psycopg2
admin_params["connect_timeout"] = 30
return psycopg2.connect(**admin_params)
def _create_direct_connection(self, **connection_config) -> Any:
"""Create direct connection to target database for validation.
Connects directly to the specified database without:
- Calling handle_existing_database()
- Creating database if missing
- Setting database_was_reused flag
Used by validation framework to check existing database compatibility.
Args:
**connection_config: Connection configuration including database name
Returns:
Database connection object
Raises:
Exception: If connection fails (database doesn't exist, auth fails, etc.)
"""
# Get connection parameters for target database
params = self._get_connection_params(**connection_config)
# Use redshift_connector if available, otherwise fall back to psycopg2
if redshift_connector:
# Note: redshift_connector only accepts ssl=True/False (not sslmode like psycopg2)
connection = redshift_connector.connect(
host=params["host"],
database=params["database"],
port=params["port"],
user=params["user"],
password=params["password"],
ssl=self.ssl_enabled,
application_name="BenchBox-Validation",
)
else:
# Fall back to psycopg2
params["connect_timeout"] = 30
connection = psycopg2.connect(**params)
# Apply WLM queue settings if configured
if self.wlm_query_queue_name:
cursor = connection.cursor()
try:
# Escape single quotes in queue name for SQL safety
queue_name_escaped = self.wlm_query_queue_name.replace("'", "''")
cursor.execute(f"SET query_group TO '{queue_name_escaped}'")
finally:
cursor.close()
return connection
def check_server_database_exists(self, **connection_config) -> bool:
"""Check if database exists in Redshift cluster.
Connects to admin database to query pg_database for the target database.
"""
try:
# Connect to admin database (not target database)
connection = self._create_admin_connection()
cursor = connection.cursor()
database = connection_config.get("database", self.database)
# Check if database exists
cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (database,))
result = cursor.fetchone()
return result is not None
except Exception:
# If we can't connect or check, assume database doesn't exist
return False
finally:
if "connection" in locals() and connection:
connection.close()
def drop_database(self, **connection_config) -> None:
"""Drop database in Redshift cluster.
Connects to admin database to drop the target database.
Note: DROP DATABASE must run with autocommit enabled.
Note: Redshift doesn't support IF EXISTS for DROP DATABASE, so we check first.
"""
database = connection_config.get("database", self.database)
# Check if database exists first (Redshift doesn't support IF EXISTS)
if not self.check_server_database_exists(database=database):
self.log_verbose(f"Database {database} does not exist - nothing to drop")
return
try:
# Connect to admin database (not target database)
connection = self._create_admin_connection()
connection.autocommit = True # Enable autocommit for DROP DATABASE
cursor = connection.cursor()
# Try to drop database first (graceful approach)
# Quote identifier for SQL safety
try:
cursor.execute(f'DROP DATABASE "{database}"')
except Exception as drop_error:
# If drop fails due to active connections, terminate them and retry
error_msg = str(drop_error).lower()
if "active connection" in error_msg or "being accessed" in error_msg:
self.log_verbose("Database has active connections, terminating them...")
# Terminate existing connections as fallback
# Note: Use 'pid' column (not deprecated 'procpid')
cursor.execute(
"""
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = %s AND pid <> pg_backend_pid()
""",
(database,),
)
# Retry drop after terminating connections (quoted for SQL safety)
cursor.execute(f'DROP DATABASE "{database}"')
else:
# Re-raise if not a connection issue
raise
except Exception as e:
raise RuntimeError(f"Failed to drop Redshift database {database}: {e}") from e
finally:
if "connection" in locals() and connection:
connection.close()
def create_connection(self, **connection_config) -> Any:
"""Create optimized Redshift connection."""
self.log_operation_start("Redshift connection")
# Handle existing database using base class method
self.handle_existing_database(**connection_config)
# Get connection parameters
params = self._get_connection_params(**connection_config)
target_database = params.get("database")
# Create database if needed (before connecting to it)
# Redshift requires connecting to an admin database to create new databases
if not self.database_was_reused:
# Check if target database exists
database_exists = self.check_server_database_exists(database=target_database)
if not database_exists:
self.log_verbose(f"Creating database: {target_database}")
# Create database using admin connection (connects to self.admin_database)
# Note: CREATE DATABASE must run with autocommit enabled (cannot run in transaction block)
try:
admin_conn = self._create_admin_connection()
admin_conn.autocommit = True # Enable autocommit for CREATE DATABASE
admin_cursor = admin_conn.cursor()
try:
# Quote identifier for SQL safety
admin_cursor.execute(f'CREATE DATABASE "{target_database}"')
# No commit() needed - autocommit handles it automatically
self.logger.info(f"Created database {target_database}")
finally:
admin_cursor.close()
admin_conn.close()
except Exception as e:
self.logger.error(f"Failed to create database {target_database}: {e}")
raise
self.log_very_verbose(f"Redshift connection params: host={params.get('host')}, database={target_database}")
try:
if redshift_connector:
# Use redshift_connector (preferred)
# Note: redshift_connector only accepts ssl=True/False (not sslmode like psycopg2)
connection = redshift_connector.connect(
host=params["host"],
port=params["port"],
database=params["database"],
user=params["user"],
password=params["password"],
ssl=self.ssl_enabled,
# Connection optimization
application_name="BenchBox",
tcp_keepalive=True,
tcp_keepalive_idle=600,
tcp_keepalive_interval=30,
tcp_keepalive_count=3,
)
# Enable autocommit immediately after connection creation (before any SQL operations)
connection.autocommit = True
else:
# Fall back to psycopg2 (already imported at top of file)
connection = psycopg2.connect(
host=params["host"],
port=params["port"],
database=params["database"],
user=params["user"],
password=params["password"],
sslmode=params["sslmode"],
application_name="BenchBox",
connect_timeout=self.connect_timeout,
)
# Enable autocommit for benchmark workloads (no transactions needed)
connection.autocommit = True
# Apply WLM settings and schema search path
cursor = connection.cursor()
# Numeric session settings are cast to int() below, so interpolation is injection-safe
if self.wlm_query_slot_count > 1:
cursor.execute(f"SET wlm_query_slot_count = {int(self.wlm_query_slot_count)}")
if self.statement_timeout > 0:
cursor.execute(f"SET statement_timeout = {int(self.statement_timeout)}")
# Set search_path to ensure all unqualified table references use correct schema
# Critical for database reuse when schema already exists but connection is new
# Quote identifier for SQL safety
cursor.execute(f'SET search_path TO "{self.schema}"')
# Test connection
cursor.execute("SELECT version()")
cursor.fetchone()
cursor.close()
self.logger.info(f"Connected to Redshift cluster at {params['host']}:{params['port']}")
self.log_operation_complete(
"Redshift connection", details=f"Connected to {params['host']}:{params['port']}"
)
return connection
except Exception as e:
self.logger.error(f"Failed to connect to Redshift: {e}")
raise
def create_schema(self, benchmark, connection: Any) -> float:
"""Create schema using Redshift-optimized table definitions."""
start_time = time.time()
cursor = connection.cursor()
try:
# Create schema if needed (if not using default "public")
# Quote identifiers for SQL safety
if self.schema and self.schema.lower() != "public":
self.log_verbose(f"Creating schema: {self.schema}")
cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.schema}"')
self.logger.info(f"Created schema {self.schema}")
# Set search_path to use the correct schema
self.log_very_verbose(f"Setting search_path to: {self.schema}")
cursor.execute(f'SET search_path TO "{self.schema}"')
# Use common schema creation helper
schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")
# Split schema into individual statements and execute
statements = [stmt.strip() for stmt in schema_sql.split(";") if stmt.strip()]
for statement in statements:
# Normalize table names to lowercase for Redshift consistency
# This ensures CREATE, COPY, and SELECT all use the same case
statement = self._normalize_table_name_in_sql(statement)
# Ensure idempotency with DROP TABLE IF EXISTS
# (Redshift doesn't support CREATE OR REPLACE TABLE)
if statement.upper().startswith("CREATE TABLE"):
# Extract table name from CREATE TABLE statement
table_name = self._extract_table_name(statement)
if table_name:
# Ensure table name is lowercase
table_name_lower = table_name.strip('"').lower()
drop_statement = f"DROP TABLE IF EXISTS {table_name_lower}"
cursor.execute(drop_statement)
self.logger.debug(f"Executed: {drop_statement}")
# Optimize table definition for Redshift
statement = self._optimize_table_definition(statement)
cursor.execute(statement)
self.logger.debug(f"Executed schema statement: {statement[:100]}...")
self.logger.info("Schema created")
except Exception as e:
self.logger.error(f"Schema creation failed: {e}")
raise
finally:
cursor.close()
return time.time() - start_time
def load_data(
self, benchmark, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
"""Load data using Redshift COPY command with S3 integration."""
start_time = time.time()
table_stats = {}
cursor = connection.cursor()
try:
# Get data files from benchmark or manifest fallback
if hasattr(benchmark, "tables") and benchmark.tables:
data_files = benchmark.tables
else:
data_files = None
try:
manifest_path = Path(data_dir) / "_datagen_manifest.json"
if manifest_path.exists():
with open(manifest_path) as f:
manifest = json.load(f)
tables = manifest.get("tables") or {}
mapping = {}
for table, entries in tables.items():
if entries:
# Collect ALL chunk files, not just the first one
chunk_paths = []
for entry in entries:
rel = entry.get("path")
if rel:
chunk_paths.append(Path(data_dir) / rel)
if chunk_paths:
mapping[table] = chunk_paths
if mapping:
data_files = mapping
self.logger.debug("Using data files from _datagen_manifest.json")
except Exception as e:
self.logger.debug(f"Manifest fallback failed: {e}")
if not data_files:
# No data files available - benchmark should have generated data first
raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")
# Upload files to S3 and load via COPY command
if self.s3_bucket and boto3:
# Create S3 client with explicit error handling
try:
s3_client = boto3.client(
"s3",
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
region_name=self.aws_region,
)
except NoCredentialsError as e:
raise ValueError(
"AWS credentials not found. Configure credentials via:\n"
" 1. IAM role (aws_access_key_id/aws_secret_access_key in config)\n"
" 2. AWS CLI (aws configure)\n"
" 3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
) from e
except ClientError as e:
raise ValueError(f"Failed to create AWS S3 client: {e}") from e
for table_name, file_paths in data_files.items():
# Normalize to list (data resolver should always return lists now)
if not isinstance(file_paths, list):
file_paths = [file_paths]
# Filter out non-existent or empty files
valid_files = []
for file_path in file_paths:
file_path = Path(file_path)
if file_path.exists() and file_path.stat().st_size > 0:
valid_files.append(file_path)
if not valid_files:
self.logger.warning(f"Skipping {table_name} - no valid data files")
table_stats[table_name.lower()] = 0
continue
chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
self.log_verbose(f"Loading data for table: {table_name}{chunk_info}")
try:
load_start = time.time()
# Normalize table name to lowercase for Redshift consistency
table_name_lower = table_name.lower()
# Upload all files to S3 and collect S3 URIs
# Redshift supports compressed files and any delimiter natively
s3_uris = []
for file_idx, file_path in enumerate(valid_files):
file_path = Path(file_path)
# Detect file format to determine delimiter
# TPC-H uses .tbl (pipe-delimited), TPC-DS uses .dat (pipe-delimited)
# Use substring check to handle chunked files like customer.tbl.1 or customer.tbl.1.zst
file_str = str(file_path.name)
delimiter = "|" if ".tbl" in file_str or ".dat" in file_str else ","
# Upload file directly with original compression and format
# Preserve full multi-part suffix for chunked files (e.g., .tbl.1.zst)
# Extract all suffixes after table name (e.g., "customer.tbl.1.zst" -> ".tbl.1.zst")
file_stem = file_path.stem # e.g., "customer.tbl.1" or "customer"
# Get original suffix (e.g., ".zst")
original_suffix = file_path.suffix
# Check if stem has more suffixes (e.g., ".tbl.1" in "customer.tbl.1")
if "." in file_stem:
# Extract all suffixes: split at first dot and take the rest
parts = file_path.name.split(".", 1) # e.g., ["customer", "tbl.1.zst"]
if len(parts) > 1:
full_suffix = "." + parts[1] # e.g., ".tbl.1.zst"
else:
full_suffix = original_suffix
else:
full_suffix = original_suffix
s3_key = f"{self.s3_prefix}/{table_name}_{file_idx}{full_suffix}"
# Upload file with explicit error handling
try:
s3_client.upload_file(str(file_path), self.s3_bucket, s3_key)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
if error_code == "NoSuchBucket":
raise ValueError(
f"S3 bucket '{self.s3_bucket}' does not exist. "
f"Create the bucket or update your configuration."
) from e
elif error_code == "AccessDenied":
raise ValueError(
f"Access denied to S3 bucket '{self.s3_bucket}'. "
f"Check IAM permissions for s3:PutObject."
) from e
else:
raise ValueError(
f"Failed to upload {file_path.name} to s3://{self.s3_bucket}/{s3_key}: "
f"{error_code} - {e}"
) from e
s3_uris.append(f"s3://{self.s3_bucket}/{s3_key}")
# For multi-file loads, create manifest file and use that
if len(s3_uris) > 1:
manifest = {"entries": [{"url": uri, "mandatory": True} for uri in s3_uris]}
manifest_key = f"{self.s3_prefix}/{table_name}_manifest.json"
# Upload manifest with explicit error handling
try:
s3_client.put_object(
Bucket=self.s3_bucket,
Key=manifest_key,
Body=json.dumps(manifest),
)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
raise ValueError(
f"Failed to upload manifest file to s3://{self.s3_bucket}/{manifest_key}: "
f"{error_code} - {e}"
) from e
copy_from_path = f"s3://{self.s3_bucket}/{manifest_key}"
manifest_option = "manifest"
else:
copy_from_path = s3_uris[0]
manifest_option = ""
# Detect compression format from file extension
# Redshift auto-detects gzip, but we should be explicit for zstd
if any(str(f).endswith(".zst") for f in valid_files):
compression_option = "ZSTD"
elif any(str(f).endswith(".gz") for f in valid_files):
compression_option = "GZIP"
else:
compression_option = ""
# Load from S3 using COPY command with three-way credential handling:
# 1. IAM role (preferred)
# 2. Explicit access keys
# 3. Cluster default IAM role (no credentials in SQL)
if self.iam_role:
# Use IAM role for authentication
credentials_clause = f"IAM_ROLE '{self.iam_role}'"
elif self.aws_access_key_id and self.aws_secret_access_key:
# Use explicit access keys
credentials_clause = f"ACCESS_KEY_ID '{self.aws_access_key_id}' SECRET_ACCESS_KEY '{self.aws_secret_access_key}'"
else:
# No explicit credentials - rely on cluster's default IAM role
credentials_clause = ""
self.log_verbose(
"No explicit credentials configured for COPY; using cluster default IAM role"
)
# Build COPY command with credentials clause (may be empty)
# Fully qualify table name with schema for clarity and to avoid ambiguity
qualified_table = f"{self.schema}.{table_name_lower}"
copy_sql = f"""
COPY {qualified_table}
FROM '{copy_from_path}'
{credentials_clause}
{manifest_option}
{compression_option}
DELIMITER '{delimiter}'
IGNOREHEADER 0
COMPUPDATE {self.compupdate}
"""
cursor.execute(copy_sql)
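# Example of a rendered COPY statement for a chunked, zstd-compressed TPC-H table using an
# IAM role (bucket name and role ARN are placeholders):
#
#   COPY public.lineitem
#   FROM 's3://example-staging-bucket/benchbox-data/lineitem_manifest.json'
#   IAM_ROLE 'arn:aws:iam::123456789012:role/ExampleRedshiftCopyRole'
#   manifest
#   ZSTD
#   DELIMITER '|'
#   IGNOREHEADER 0
#   COMPUPDATE PRESET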
# Get row count
cursor.execute(f"SELECT COUNT(*) FROM {qualified_table}")
row_count = cursor.fetchone()[0]
table_stats[table_name_lower] = row_count
# Run ANALYZE if configured
if self.auto_analyze:
cursor.execute(f"ANALYZE {qualified_table}")
load_time = time.time() - load_start
self.logger.info(
f"✅ Loaded {row_count:,} rows into {table_name_lower}{chunk_info} in {load_time:.2f}s"
)
except Exception as e:
self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
table_stats[table_name.lower()] = 0
else:
# Direct loading without S3 (less efficient)
self.logger.warning("No S3 bucket configured, using direct INSERT loading")
for table_name, file_paths in data_files.items():
# Normalize to list (handle both single paths and lists for TPC-H vs TPC-DS)
if not isinstance(file_paths, list):
file_paths = [file_paths]
# Filter valid files
valid_files = []
for file_path in file_paths:
file_path = Path(file_path)
if file_path.exists() and file_path.stat().st_size > 0:
valid_files.append(file_path)
if not valid_files:
self.logger.warning(f"Skipping {table_name} - no valid data files")
table_stats[table_name.lower()] = 0
continue
try:
self.log_verbose(f"Direct loading data for table: {table_name}")
load_start = time.time()
# Normalize table name to lowercase for Redshift consistency
table_name_lower = table_name.lower()
# Load data row by row from all chunks (inefficient but works without S3)
total_rows_loaded = 0
for file_idx, file_path in enumerate(valid_files):
chunk_info = f" (chunk {file_idx + 1}/{len(valid_files)})" if len(valid_files) > 1 else ""
self.log_very_verbose(f"Loading {table_name}{chunk_info} from {file_path.name}")
# TPC-H uses .tbl files, TPC-DS uses .dat files - both are pipe-delimited
file_str = str(file_path.name)
delimiter = "|" if ".tbl" in file_str or ".dat" in file_str else ","
# Get compression handler (handles .zst, .gz, or uncompressed)
compression_handler = FileFormatRegistry.get_compression_handler(file_path)
with compression_handler.open(file_path) as f:
rows_loaded = 0
batch_size = 1000
batch_data = []
for line in f:
line = line.strip()
if line and line.endswith(delimiter):
line = line[:-1]
values = line.split(delimiter)
# Simple escaping for SQL
escaped_values = ["'" + str(v).replace("'", "''") + "'" for v in values]
batch_data.append(f"({', '.join(escaped_values)})")
if len(batch_data) >= batch_size:
insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
cursor.execute(insert_sql)
rows_loaded += len(batch_data)
total_rows_loaded += len(batch_data)
batch_data = []
# Insert remaining batch
if batch_data:
insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
cursor.execute(insert_sql)
rows_loaded += len(batch_data)
total_rows_loaded += len(batch_data)
self.log_very_verbose(f"Loaded {rows_loaded:,} rows from {file_path.name}")
table_stats[table_name_lower] = total_rows_loaded
load_time = time.time() - load_start
chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
self.logger.info(
f"✅ Loaded {total_rows_loaded:,} rows into {table_name_lower}{chunk_info} in {load_time:.2f}s"
)
except Exception as e:
self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
table_stats[table_name.lower()] = 0
total_time = time.time() - start_time
total_rows = sum(table_stats.values())
self.logger.info(f"✅ Loaded {total_rows:,} total rows in {total_time:.2f}s")
except Exception as e:
self.logger.error(f"Data loading failed: {e}")
raise
finally:
cursor.close()
# Redshift doesn't provide detailed per-table timings yet
return table_stats, total_time, None
def validate_session_cache_control(self, connection: Any) -> dict[str, Any]:
"""Validate that session-level cache control settings were successfully applied.
Args:
connection: Active Redshift database connection
Returns:
dict with:
- validated: bool - Whether validation passed
- cache_disabled: bool - Whether cache is actually disabled
- settings: dict - Actual session settings
- warnings: list[str] - Any validation warnings
- errors: list[str] - Any validation errors
Raises:
ConfigurationError: If cache control validation fails and strict_validation=True
"""
cursor = connection.cursor()
result = {
"validated": False,
"cache_disabled": False,
"settings": {},
"warnings": [],
"errors": [],
}
try:
# Query current session setting using current_setting() function
cursor.execute("SELECT current_setting('enable_result_cache_for_session') as value")
row = cursor.fetchone()
if row:
actual_value = str(row[0]).lower()
result["settings"]["enable_result_cache_for_session"] = actual_value
# Determine expected value based on configuration
expected_value = "off" if self.disable_result_cache else "on"
if actual_value == expected_value:
result["validated"] = True
result["cache_disabled"] = actual_value == "off"
self.logger.debug(
f"Cache control validated: enable_result_cache_for_session={actual_value} "
f"(expected {expected_value})"
)
else:
error_msg = (
f"Cache control validation failed: "
f"expected enable_result_cache_for_session={expected_value}, "
f"got {actual_value}"
)
result["errors"].append(error_msg)
self.logger.error(error_msg)
# Raise error if strict validation mode enabled
if self.strict_validation:
raise ConfigurationError(
"Redshift session cache control validation failed - "
"benchmark results may be incorrect due to cached query results",
details=result,
)
else:
warning_msg = "Could not retrieve enable_result_cache_for_session parameter from session"
result["warnings"].append(warning_msg)
self.logger.warning(warning_msg)
except Exception as e:
# If this is our ConfigurationError, re-raise it
if isinstance(e, ConfigurationError):
raise
# Otherwise log validation error
error_msg = f"Validation query failed: {e}"
result["errors"].append(error_msg)
self.logger.error(f"Cache control validation error: {e}")
# Raise if strict mode and query failed
if self.strict_validation:
raise ConfigurationError(
"Failed to validate Redshift cache control settings",
details={"original_error": str(e), "validation_result": result},
) from e
finally:
cursor.close()
return result
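# Illustrative outcome of the check above: with disable_result_cache=True a correctly
# configured session is expected to report
#
#   SELECT current_setting('enable_result_cache_for_session')  -->  'off'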
def execute_query(
self,
connection: Any,
query: str,
query_id: str,
benchmark_type: str | None = None,
scale_factor: float | None = None,
validate_row_count: bool = True,
stream_id: int | None = None,
) -> dict[str, Any]:
"""Execute query with detailed timing and performance tracking."""
start_time = time.time()
cursor = connection.cursor()
try:
# Execute the query
# Note: Query dialect translation is now handled automatically by the base adapter
cursor.execute(query)
result = cursor.fetchall()
execution_time = time.time() - start_time
actual_row_count = len(result) if result else 0
# Get query statistics
try:
query_stats = self._get_query_statistics(connection, query_id)
# Add execution time for cost calculation
query_stats["execution_time_seconds"] = execution_time
except Exception:
query_stats = {"execution_time_seconds": execution_time}
# Validate row count if enabled and benchmark type is provided
validation_result = None
if validate_row_count and benchmark_type:
from benchbox.core.validation.query_validation import QueryValidator
validator = QueryValidator()
validation_result = validator.validate_query_result(
benchmark_type=benchmark_type,
query_id=query_id,
actual_row_count=actual_row_count,
scale_factor=scale_factor,
stream_id=stream_id,
)
# Log validation result
if validation_result.warning_message:
self.log_verbose(f"Row count validation: {validation_result.warning_message}")
elif not validation_result.is_valid:
self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
else:
self.log_very_verbose(
f"Row count validation PASSED: {actual_row_count} rows "
f"(expected: {validation_result.expected_row_count})"
)
# Use base helper to build result with consistent validation field mapping
result_dict = self._build_query_result_with_validation(
query_id=query_id,
execution_time=execution_time,
actual_row_count=actual_row_count,
first_row=result[0] if result else None,
validation_result=validation_result,
)
# Include Redshift-specific fields
result_dict["translated_query"] = None # Translation handled by base adapter
result_dict["query_statistics"] = query_stats
# Map query_statistics to resource_usage for cost calculation
result_dict["resource_usage"] = query_stats
return result_dict
except Exception as e:
execution_time = time.time() - start_time
return {
"query_id": query_id,
"status": "FAILED",
"execution_time": execution_time,
"rows_returned": 0,
"error": str(e),
"error_type": type(e).__name__,
}
finally:
cursor.close()
def _extract_table_name(self, statement: str) -> str | None:
"""Extract table name from CREATE TABLE statement.
Args:
statement: CREATE TABLE SQL statement
Returns:
Table name or None if not found
"""
try:
# Simple extraction: find text between "CREATE TABLE" and "("
import re
match = re.search(r"CREATE\s+TABLE\s+([^\s(]+)", statement, re.IGNORECASE)
if match:
return match.group(1).strip()
except Exception:
pass
return None
def _normalize_table_name_in_sql(self, sql: str) -> str:
"""Normalize table names in SQL to lowercase for Redshift.
Redshift converts unquoted identifiers to lowercase, so we normalize
all table names to lowercase to ensure consistency across CREATE, COPY,
and SELECT operations. This prevents case sensitivity issues.
Args:
sql: SQL statement
Returns:
SQL with normalized (lowercase) table names
"""
import re
# Match CREATE TABLE "TABLENAME" or CREATE TABLE TABLENAME
# and convert to CREATE TABLE tablename (unquoted lowercase)
sql = re.sub(
r'CREATE\s+TABLE\s+"?([A-Za-z_][A-Za-z0-9_]*)"?',
lambda m: f"CREATE TABLE {m.group(1).lower()}",
sql,
flags=re.IGNORECASE,
)
# Match foreign key references to quoted/unquoted table names
# FOREIGN KEY ... REFERENCES "TABLENAME" → REFERENCES tablename
sql = re.sub(
r'REFERENCES\s+"?([A-Za-z_][A-Za-z0-9_]*)"?',
lambda m: f"REFERENCES {m.group(1).lower()}",
sql,
flags=re.IGNORECASE,
)
return sql
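# Example of the normalization above (illustrative input):
#
#   'CREATE TABLE "LINEITEM" (l_orderkey BIGINT REFERENCES "ORDERS" (o_orderkey))'
#   becomes
#   'CREATE TABLE lineitem (l_orderkey BIGINT REFERENCES orders (o_orderkey))'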
def _optimize_table_definition(self, statement: str) -> str:
"""Optimize table definition for Redshift."""
if not statement.upper().startswith("CREATE TABLE"):
return statement
# Include distribution and sort keys for better performance
if "DISTSTYLE" not in statement.upper():
# Include AUTO distribution style (Redshift will choose appropriate distribution)
statement += " DISTSTYLE AUTO"
if "SORTKEY" not in statement.upper():
# Use automatic sort key selection (Redshift chooses sort keys based on the workload)
statement += " SORTKEY AUTO"
return statement
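# Example: a bare CREATE TABLE statement gains automatic distribution and sort key clauses:
#
#   'CREATE TABLE nation (n_nationkey INT, n_name VARCHAR(25))'
#   becomes
#   'CREATE TABLE nation (n_nationkey INT, n_name VARCHAR(25)) DISTSTYLE AUTO SORTKEY AUTO'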
def _get_platform_metadata(self, connection: Any) -> dict[str, Any]:
"""Get Redshift-specific metadata and system information."""
metadata = {
"platform": self.platform_name,
"host": self.host,
"port": self.port,
"database": self.database,
"result_cache_enabled": not self.disable_result_cache,
}
cursor = connection.cursor()
try:
# Get Redshift version
cursor.execute("SELECT version()")
result = cursor.fetchone()
metadata["redshift_version"] = result[0] if result else "unknown"
# Get cluster information
cursor.execute("""
SELECT
node_type,
num_nodes,
cluster_version,
publicly_accessible
FROM stv_cluster_configuration
LIMIT 1
""")
result = cursor.fetchone()
if result:
metadata["cluster_info"] = {
"node_type": result[0],
"num_nodes": result[1],
"cluster_version": result[2],
"publicly_accessible": result[3],
}
# Get current session information
cursor.execute("""
SELECT
current_user,
current_database(),
current_schema(),
inet_client_addr(),
inet_client_port()
""")
result = cursor.fetchone()
if result:
metadata["session_info"] = {
"current_user": result[0],
"current_database": result[1],
"current_schema": result[2],
"client_addr": result[3],
"client_port": result[4],
}
# Get table information
cursor.execute(f"""
SELECT
schemaname,
tablename,
tableowner,
tablespace,
hasindexes,
hasrules,
hastriggers
FROM pg_tables
WHERE schemaname = '{self.schema}'
""")
tables = cursor.fetchall()
metadata["tables"] = [
{
"schema": row[0],
"table": row[1],
"owner": row[2],
"tablespace": row[3],
"has_indexes": row[4],
"has_rules": row[5],
"has_triggers": row[6],
}
for row in tables
]
except Exception as e:
metadata["metadata_error"] = str(e)
finally:
cursor.close()
return metadata
def _get_existing_tables(self, connection: Any) -> list[str]:
"""Get list of existing tables from Redshift (normalized to lowercase).
Queries information_schema for tables in the current schema and returns
them as lowercase names to match Redshift's identifier normalization.
Args:
connection: Database connection
Returns:
List of table names (lowercase)
"""
cursor = connection.cursor()
try:
# Query information_schema for tables in current schema
cursor.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
""",
(self.schema,),
)
# Return lowercase table names
return [row[0].lower() for row in cursor.fetchall()]
finally:
cursor.close()
def _get_query_statistics(self, connection: Any, query_id: str) -> dict[str, Any]:
"""Get query statistics from Redshift system tables.
Note: This retrieves performance telemetry only - not used for validation.
Query row count validation uses the actual result set from cursor.fetchall().
"""
cursor = connection.cursor()
try:
# Get the actual Redshift query ID for the most recent query
# pg_last_query_id() returns the query ID of the last executed query in this session
cursor.execute("SELECT pg_last_query_id()")
result = cursor.fetchone()
if not result or result[0] == -1:
# No queries executed yet, or query ran only on leader node
return {}
redshift_query_id = result[0]
# Query STL tables for query statistics using the actual Redshift query ID
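# stl_query does not expose CPU time or bytes scanned/returned directly, so elapsed time is
# reused as an approximation of CPU time and the byte/slot counters below are placeholders.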
cursor.execute(
"""
SELECT
query,
DATEDIFF('microseconds', starttime, endtime) as duration_microsecs,
DATEDIFF('microseconds', starttime, endtime) as cpu_time_microsecs,
0 as bytes_scanned,
0 as bytes_returned,
1 as slots,
1 as wlm_slots,
aborted
FROM stl_query
WHERE query = %s
ORDER BY starttime DESC
LIMIT 1
""",
(redshift_query_id,),
)
result = cursor.fetchone()
if result:
return {
"query_id": str(result[0]),
"duration_microsecs": result[1] or 0,
"cpu_time_microsecs": result[2] or 0,
"bytes_scanned": result[3] or 0,
"bytes_returned": result[4] or 0,
"slots": result[5] or 1,
"wlm_slots": result[6] or 1,
"aborted": bool(result[7]) if result[7] is not None else False,
}
else:
return {}
except Exception:
return {}
finally:
cursor.close()
def analyze_table(self, connection: Any, table_name: str) -> None:
"""Run ANALYZE on table for query optimization."""
cursor = connection.cursor()
try:
cursor.execute(f"ANALYZE {table_name.lower()}")
except Exception as e:
self.logger.warning(f"Failed to analyze table {table_name}: {e}")
finally:
cursor.close()
def vacuum_table(self, connection: Any, table_name: str) -> None:
"""Run VACUUM on table for space reclamation."""
cursor = connection.cursor()
try:
cursor.execute(f"VACUUM {table_name.lower()}")
except Exception as e:
self.logger.warning(f"Failed to vacuum table {table_name}: {e}")
finally:
cursor.close()
def get_query_plan(self, connection: Any, query: str) -> str:
"""Get query execution plan for analysis."""
cursor = connection.cursor()
try:
# Note: Query dialect translation is now handled automatically by the base adapter
cursor.execute(f"EXPLAIN {query}")
plan_rows = cursor.fetchall()
return "\n".join([row[0] for row in plan_rows])
except Exception as e:
return f"Could not get query plan: {e}"
finally:
cursor.close()
def close_connection(self, connection: Any) -> None:
"""Close Redshift connection."""
try:
if connection and hasattr(connection, "close"):
connection.close()
except Exception as e:
self.logger.warning(f"Error closing connection: {e}")
def supports_tuning_type(self, tuning_type) -> bool:
"""Check if Redshift supports a specific tuning type.
Redshift supports:
- DISTRIBUTION: Via DISTSTYLE and DISTKEY clauses
- SORTING: Via SORTKEY clause (compound and interleaved)
- PARTITIONING: Through table design patterns and date partitioning
Args:
tuning_type: The type of tuning to check support for
Returns:
True if the tuning type is supported by Redshift
"""
# Import here to avoid circular imports
try:
from benchbox.core.tuning.interface import TuningType
return tuning_type in {
TuningType.DISTRIBUTION,
TuningType.SORTING,
TuningType.PARTITIONING,
}
except ImportError:
return False
def generate_tuning_clause(self, table_tuning) -> str:
"""Generate Redshift-specific tuning clauses for CREATE TABLE statements.
Redshift supports:
- DISTSTYLE (EVEN | KEY | ALL) DISTKEY (column)
- SORTKEY (column1, column2, ...) or INTERLEAVED SORTKEY (column1, column2, ...)
Args:
table_tuning: The tuning configuration for the table
Returns:
SQL clause string to be appended to CREATE TABLE statement
"""
if not table_tuning or not table_tuning.has_any_tuning():
return ""
clauses = []
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Handle distribution strategy
distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
if distribution_columns:
# Sort by order and use first column as distribution key
sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
dist_col = sorted_cols[0]
# Use KEY distribution style with the specified column
clauses.append("DISTSTYLE KEY")
clauses.append(f"DISTKEY ({dist_col.name})")
else:
# Default to EVEN distribution if no distribution columns specified
clauses.append("DISTSTYLE EVEN")
# Handle sorting
sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
if sort_columns:
# Sort by order for sortkey
sorted_cols = sorted(sort_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
# Use compound sort key by default (better for most OLAP workloads)
# Could be made configurable to choose between COMPOUND and INTERLEAVED
sortkey_clause = f"SORTKEY ({', '.join(column_names)})"
clauses.append(sortkey_clause)
# Partitioning: Redshift has no native table partitioning clause
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
if partition_columns:
# Partitioning is handled through table design patterns rather than
# CREATE TABLE syntax, so no clause is emitted here
pass
# Clustering not directly supported in Redshift CREATE TABLE
except ImportError:
# If tuning interface not available, return empty string
pass
return " ".join(clauses)
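# Example clause produced for a table tuned with one distribution column (l_orderkey) and two
# sort columns (l_shipdate, l_orderkey) - column names are illustrative:
#
#   "DISTSTYLE KEY DISTKEY (l_orderkey) SORTKEY (l_shipdate, l_orderkey)"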
def apply_table_tunings(self, table_tuning, connection: Any) -> None:
"""Apply tuning configurations to a Redshift table.
Redshift tuning approach:
- DISTRIBUTION: Handled via DISTSTYLE/DISTKEY in CREATE TABLE
- SORTING: Handled via SORTKEY in CREATE TABLE
- Post-creation optimizations via ANALYZE and VACUUM
Args:
table_tuning: The tuning configuration to apply
connection: Redshift connection
Raises:
ValueError: If the tuning configuration is invalid for Redshift
"""
if not table_tuning or not table_tuning.has_any_tuning():
return
table_name = table_tuning.table_name.lower()
self.logger.info(f"Applying Redshift tunings for table: {table_name}")
cursor = connection.cursor()
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Redshift tuning is primarily handled at table creation time
# Post-creation optimizations are limited
# Verify table exists and get current distribution/sort key configuration.
# pg_table_def reports per-column distkey/sortkey flags for tables in the search_path.
cursor.execute(
"""
SELECT "column", distkey, sortkey
FROM pg_table_def
WHERE schemaname = %s AND tablename = %s
""",
(self.schema, table_name),
)
rows = cursor.fetchall()
if rows:
distkey_cols = [row[0] for row in rows if row[1]]
current_distkey = distkey_cols[0] if distkey_cols else None
# Compound sort key columns in declared order (positive sortkey positions)
sorted_rows = sorted((row for row in rows if row[2] and row[2] > 0), key=lambda row: row[2])
current_sortkeys = [row[0] for row in sorted_rows]
self.logger.info(f"Current configuration for {table_name}:")
self.logger.info(f" Distribution key: {current_distkey}")
self.logger.info(f" Sort keys: {current_sortkeys}")
# Check if configuration matches desired tuning
distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
needs_recreation = False
# Check distribution configuration
if distribution_columns:
sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
desired_distkey = sorted_cols[0].name.lower()
if current_distkey != desired_distkey:
needs_recreation = True
self.logger.info(
f"Distribution key mismatch: current='{current_distkey}', desired='{desired_distkey}'"
)
# Check sort key configuration
if sort_columns:
sorted_cols = sorted(sort_columns, key=lambda col: col.order)
desired_sortkeys = [col.name.lower() for col in sorted_cols]
if current_sortkeys != desired_sortkeys:
needs_recreation = True
self.logger.info(f"Sort keys mismatch: current={current_sortkeys}, desired={desired_sortkeys}")
if needs_recreation:
self.logger.warning(
f"Table {table_name} configuration differs from desired tuning. "
"Redshift requires table recreation to change distribution/sort keys."
)
else:
self.logger.warning(f"Could not find table configuration for {table_name}")
# Perform maintenance operations that can help with performance
try:
# Run ANALYZE to update table statistics
cursor.execute(f"ANALYZE {table_name}")
self.logger.info(f"Analyzed table statistics for {table_name}")
# Run VACUUM to reclaim space and re-sort data
if self.auto_vacuum:
cursor.execute(f"VACUUM {table_name}")
self.logger.info(f"Vacuumed table {table_name}")
except Exception as e:
self.logger.warning(f"Failed to perform maintenance operations on {table_name}: {e}")
# Handle partitioning strategy
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
if partition_columns:
sorted_cols = sorted(partition_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
self.logger.info(
f"Partitioning strategy for {table_name}: {', '.join(column_names)} (handled via table design patterns)"
)
# Clustering not directly supported in Redshift
cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
if cluster_columns:
sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
self.logger.info(
f"Clustering strategy for {table_name} achieved via sort keys: {', '.join(column_names)}"
)
except ImportError:
self.logger.warning("Tuning interface not available - skipping tuning application")
except Exception as e:
raise ValueError(f"Failed to apply tunings to Redshift table {table_name}: {e}") from e
finally:
cursor.close()
def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None:
"""Apply unified tuning configuration to Redshift.
Args:
unified_config: Unified tuning configuration to apply
connection: Redshift connection
"""
if not unified_config:
return
# Apply constraint configurations
self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection)
# Apply platform optimizations
if unified_config.platform_optimizations:
self.apply_platform_optimizations(unified_config.platform_optimizations, connection)
# Apply table-level tunings
for _table_name, table_tuning in unified_config.table_tunings.items():
self.apply_table_tunings(table_tuning, connection)
def apply_constraint_configuration(
self,
primary_key_config: PrimaryKeyConfiguration,
foreign_key_config: ForeignKeyConfiguration,
connection: Any,
) -> None:
"""Apply constraint configurations to Redshift.
Note: Redshift supports PRIMARY KEY and FOREIGN KEY constraints for query optimization,
but they are informational only (not enforced). Constraints must be applied during
table creation time.
Args:
primary_key_config: Primary key constraint configuration
foreign_key_config: Foreign key constraint configuration
connection: Redshift connection
"""
# Redshift constraints are applied at table creation time for query optimization
# This method is called after tables are created, so log the configurations
if primary_key_config and primary_key_config.enabled:
self.logger.info(
"Primary key constraints enabled for Redshift (informational only, applied during table creation)"
)
if foreign_key_config and foreign_key_config.enabled:
self.logger.info(
"Foreign key constraints enabled for Redshift (informational only, applied during table creation)"
)
# Redshift constraints are informational and used for query optimization
# No additional work to do here as they're applied during CREATE TABLE
def _build_redshift_config(
platform: str,
options: dict[str, Any],
overrides: dict[str, Any],
info: Any,
) -> Any:
"""Build Redshift database configuration with credential loading.
This function loads saved credentials from the CredentialManager and
merges them with CLI options and runtime overrides.
Args:
platform: Platform name (should be 'redshift')
options: CLI platform options from --platform-option flags
overrides: Runtime overrides from orchestrator
info: Platform info from registry
Returns:
DatabaseConfig with credentials loaded
"""
from benchbox.core.config import DatabaseConfig
from benchbox.security.credentials import CredentialManager
# Load saved credentials
cred_manager = CredentialManager()
saved_creds = cred_manager.get_platform_credentials("redshift") or {}
# Build merged options: saved_creds < options < overrides
merged_options = {}
merged_options.update(saved_creds)
merged_options.update(options)
merged_options.update(overrides)
# Extract credential fields for DatabaseConfig
name = info.display_name if info else "Amazon Redshift"
driver_package = info.driver_package if info else "redshift-connector"
# Build config dict with platform-specific fields at top-level
# This allows RedshiftAdapter.__init__() and from_config() to access them via config.get()
config_dict = {
"type": "redshift",
"name": name,
"options": merged_options or {}, # Ensure options is never None (Pydantic v2 uses None if explicitly passed)
"driver_package": driver_package,
"driver_version": overrides.get("driver_version") or options.get("driver_version"),
"driver_auto_install": bool(overrides.get("driver_auto_install", options.get("driver_auto_install", False))),
# Platform-specific fields at top-level (adapters expect these here)
"host": merged_options.get("host"),
"port": merged_options.get("port"),
# NOTE: database is NOT included here - from_config() generates it from benchmark context
# Only explicit overrides (via --platform-option database=...) should bypass generation
"username": merged_options.get("username"),
"password": merged_options.get("password"),
"schema": merged_options.get("schema"),
# S3 and AWS configuration
"s3_bucket": merged_options.get("s3_bucket"),
"s3_prefix": merged_options.get("s3_prefix"),
"iam_role": merged_options.get("iam_role"),
"aws_access_key_id": merged_options.get("aws_access_key_id"),
"aws_secret_access_key": merged_options.get("aws_secret_access_key"),
"aws_region": merged_options.get("aws_region"),
# Optional settings
"cluster_identifier": merged_options.get("cluster_identifier"),
"admin_database": merged_options.get("admin_database", "dev"),
"connect_timeout": merged_options.get("connect_timeout"),
"statement_timeout": merged_options.get("statement_timeout"),
"sslmode": merged_options.get("sslmode"),
# Benchmark context for config-aware database naming (from overrides)
"benchmark": overrides.get("benchmark"),
"scale_factor": overrides.get("scale_factor"),
"tuning_config": overrides.get("tuning_config"),
}
# Only include explicit database override if provided via CLI or overrides
# Saved credentials should NOT override generated database names
if "database" in overrides and overrides["database"]:
config_dict["database"] = overrides["database"]
return DatabaseConfig(**config_dict)
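# Merge precedence illustrated (values are placeholders): saved credentials are overridden by
# CLI options, which are in turn overridden by orchestrator overrides.
#
#   saved_creds = {"host": "old-cluster.us-east-1.redshift.amazonaws.com", "username": "benchbox_user"}
#   options     = {"host": "new-cluster.us-east-1.redshift.amazonaws.com"}
#   overrides   = {"benchmark": "tpch", "scale_factor": 1}
#   -> merged "host" is "new-cluster.us-east-1.redshift.amazonaws.com"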
# Register the config builder with the platform hook registry
# This must happen when the module is imported
try:
from benchbox.cli.platform_hooks import PlatformHookRegistry
PlatformHookRegistry.register_config_builder("redshift", _build_redshift_config)
except ImportError:
# Platform hooks may not be available in all contexts (e.g., core-only usage)
pass