"""Amazon Redshift platform adapter with S3 integration and data warehouse optimizations.
Provides Redshift-specific optimizations for analytical workloads,
including COPY command for efficient data loading and distribution key optimization.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any
from benchbox.core.sql_utils import normalize_table_name_in_sql
from benchbox.utils.clock import elapsed_seconds, mono_time
if TYPE_CHECKING:
from benchbox.core.tuning.interface import (
ForeignKeyConfiguration,
PlatformOptimizationConfiguration,
PrimaryKeyConfiguration,
TuningColumn,
UnifiedTuningConfiguration,
)
from ..core.exceptions import ConfigurationError
from ..utils.dependencies import (
check_platform_dependencies,
get_dependency_error_message,
get_dependency_group_packages,
)
from ..utils.file_format import detect_compression, get_delimiter_for_file, is_parquet_format
from .base import DriverIsolationCapability, PlatformAdapter
from .base.data_loading import DataSourceResolver, FileFormatRegistry, ManifestFileSource
try:
import redshift_connector
except ImportError:
try:
import psycopg2
redshift_connector = None
except ImportError:
psycopg2 = None
redshift_connector = None
try:
import boto3
from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
boto3 = None
[docs]
class RedshiftAdapter(PlatformAdapter):
"""Amazon Redshift platform adapter with S3 integration."""
driver_isolation_capability = DriverIsolationCapability.FEASIBLE_CLIENT_ONLY
supports_external_tables = True
[docs]
def __init__(self, **config):
    """Initialize the Redshift adapter from keyword configuration.

    Resolves the SQL driver (redshift-connector preferred, psycopg2 as a
    fallback), verifies shared helper dependencies, and captures
    connection, S3 staging, WLM, SSL, and optimization settings.

    Raises:
        ImportError: If no supported driver or a required shared helper
            package is installed.
        ValueError: If the staging location is not on S3, or COMPUPDATE is
            not one of ON/OFF/PRESET.
        ConfigurationError: If host, username, or password is missing.
    """
    super().__init__(**config)
    dependency_packages = get_dependency_group_packages("redshift")
    # Check dependencies - prefer redshift-connector, fallback to psycopg2
    if not redshift_connector and not psycopg2:
        available, missing = check_platform_dependencies("redshift")
        if not available:
            error_msg = get_dependency_error_message("redshift", missing)
            raise ImportError(error_msg)
    else:
        # Ensure shared helper libraries (e.g., boto3, cloudpathlib) are available
        # even when a driver import succeeded.
        shared_packages = [pkg for pkg in dependency_packages if pkg != "redshift-connector"]
        if shared_packages:
            available_shared, missing_shared = check_platform_dependencies("redshift", shared_packages)
            if not available_shared:
                error_msg = get_dependency_error_message("redshift", missing_shared)
                raise ImportError(error_msg)
    self._dialect = "redshift"
    # Redshift connection configuration
    self.host = config.get("host")
    self.port = config.get("port") if config.get("port") is not None else 5439
    self.database = config.get("database") or "dev"
    self.username = config.get("username")
    self.password = config.get("password")
    self.cluster_identifier = config.get("cluster_identifier")
    # Admin database for metadata operations (CREATE/DROP DATABASE, checking database existence)
    # Redshift requires connecting to an existing database for admin operations
    # Default: "dev" (Redshift Serverless default database)
    self.admin_database = config.get("admin_database") or "dev"
    # Schema configuration
    self.schema = config.get("schema") or "public"
    # Connection settings
    self.connect_timeout = config.get("connect_timeout") if config.get("connect_timeout") is not None else 10
    self.statement_timeout = config.get("statement_timeout") if config.get("statement_timeout") is not None else 0
    self.sslmode = config.get("sslmode") or "require"
    # WLM settings
    self.wlm_query_slot_count = (
        config.get("wlm_query_slot_count") if config.get("wlm_query_slot_count") is not None else 1
    )
    self.wlm_query_queue_name = config.get("wlm_query_queue_name")
    # SSL configuration (legacy compatibility)
    self.ssl_enabled = config.get("ssl_enabled") if config.get("ssl_enabled") is not None else True
    self.ssl_insecure = config.get("ssl_insecure") if config.get("ssl_insecure") is not None else False
    self.sslrootcert = config.get("sslrootcert")
    # S3 configuration for data loading
    # Check for staging_root first (set by orchestrator for CloudStagingPath)
    staging_root = config.get("staging_root")
    if staging_root:
        # Parse s3://bucket/path format to extract bucket and prefix
        from benchbox.utils.cloud_storage import get_cloud_path_info
        path_info = get_cloud_path_info(staging_root)
        if path_info["provider"] == "s3":
            self.s3_bucket = path_info["bucket"]
            # Use the path component if provided, otherwise use default
            self.s3_prefix = path_info["path"].strip("/") if path_info["path"] else "benchbox-data"
            self.logger.info(f"Using staging location from config: s3://{self.s3_bucket}/{self.s3_prefix}")
        else:
            # Redshift COPY can only load from S3; reject gs://, az://, etc.
            raise ValueError(f"Redshift requires S3 (s3://) staging location, got: {path_info['provider']}://")
    else:
        # Fall back to explicit s3_bucket configuration
        self.s3_bucket = config.get("s3_bucket")
        self.s3_prefix = config.get("s3_prefix") or "benchbox-data"
    self.iam_role = config.get("iam_role")
    self.aws_access_key_id = config.get("aws_access_key_id")
    self.aws_secret_access_key = config.get("aws_secret_access_key")
    self.aws_session_token = config.get("aws_session_token")
    self.aws_region = config.get("aws_region") or "us-east-1"
    # Redshift optimization settings
    self.workload_management_config = config.get("wlm_config")
    # COMPUPDATE controls automatic compression during COPY (PRESET | ON | OFF)
    # PRESET: Apply compression based on column data types (no sampling)
    # ON: Apply compression based on data sampling
    # OFF: Disable automatic compression
    compupdate_raw = config.get("compupdate") or "PRESET"
    self.compupdate = compupdate_raw.upper()  # Normalize to uppercase
    # Validate COMPUPDATE value
    valid_compupdate_values = {"ON", "OFF", "PRESET"}
    if self.compupdate not in valid_compupdate_values:
        raise ValueError(
            f"Invalid COMPUPDATE value: '{compupdate_raw}'. "
            f"Must be one of: {', '.join(sorted(valid_compupdate_values))}"
        )
    self.auto_vacuum = config.get("auto_vacuum") if config.get("auto_vacuum") is not None else True
    self.auto_analyze = config.get("auto_analyze") if config.get("auto_analyze") is not None else True
    # Result cache control - disable by default for accurate benchmarking
    self.disable_result_cache = config.get("disable_result_cache", True)
    # Validation strictness - raise errors if cache control validation fails
    self.strict_validation = config.get("strict_validation", True)
    # Cache deployment type for Serverless vs Provisioned branching
    self.deployment_type = self._detect_deployment_type(self.host or "")
    if not all([self.host, self.username, self.password]):
        missing = []
        if not self.host:
            missing.append("host (or REDSHIFT_HOST)")
        if not self.username:
            missing.append("username (or REDSHIFT_USER)")
        if not self.password:
            missing.append("password (or REDSHIFT_PASSWORD)")
        raise ConfigurationError(
            f"Redshift configuration is incomplete. Missing: {', '.join(missing)}\n"
            "Configure with one of:\n"
            " 1. CLI: benchbox platforms setup --platform redshift\n"
            " 2. Environment variables: REDSHIFT_HOST, REDSHIFT_USER, REDSHIFT_PASSWORD\n"
            " 3. CLI options: --platform-option host=<cluster>.redshift.amazonaws.com"
        )
@property
def platform_name(self) -> str:
    """Human-readable platform name used in logs and reports."""
    name = "Redshift"
    return name
def _build_ctas_sort_sql(self, table_name: str, sort_columns: list[TuningColumn]) -> str | list[str] | None:
"""Build opt-in sorted-ingestion SQL for Redshift."""
mode, method = self.resolve_sorted_ingestion_strategy()
if mode == "off":
return None
if method == "vacuum_sort":
return f"VACUUM SORT ONLY {self.schema}.{table_name}"
if method == "ctas":
ordered_cols = ", ".join(column.name for column in sort_columns)
temp_table = f"{table_name}__ctas_sort"
return [
f"CREATE TABLE {self.schema}.{temp_table} AS SELECT * FROM {self.schema}.{table_name} ORDER BY {ordered_cols}",
f"DROP TABLE {self.schema}.{table_name}",
f"ALTER TABLE {self.schema}.{temp_table} RENAME TO {table_name}",
]
raise ValueError(f"Sorted ingestion method '{method}' is not supported for Redshift.")
[docs]
@staticmethod
def add_cli_arguments(parser) -> None:
    """Register Redshift-specific options on an argparse-style parser."""
    group = parser.add_argument_group("Redshift Arguments")
    # (flag, type, default, help) — default None is identical to omitting it.
    option_specs = (
        ("--host", str, None, "Redshift cluster endpoint hostname"),
        ("--port", int, 5439, "Redshift cluster port"),
        ("--database", str, "dev", "Database name"),
        ("--username", str, None, "Database user with required privileges"),
        ("--password", str, None, "Password for the database user"),
        ("--iam-role", str, None, "IAM role ARN for COPY operations"),
        ("--s3-bucket", str, None, "S3 bucket for data staging"),
        ("--s3-prefix", str, "benchbox-data", "Prefix within the staging bucket"),
    )
    for flag, arg_type, default, help_text in option_specs:
        group.add_argument(flag, type=arg_type, default=default, help=help_text)
[docs]
@classmethod
def from_config(cls, config: dict[str, Any]):
    """Create Redshift adapter from unified configuration.

    An explicit non-empty "database" entry wins; otherwise a
    configuration-aware name is generated from the benchmark, scale
    factor, and tuning configuration. All recognized connection and
    staging keys are passed through unchanged.
    """
    from benchbox.utils.database_naming import generate_database_name
    adapter_config: dict[str, Any] = {}
    explicit_database = config.get("database")
    if explicit_database:
        # User explicitly provided database name - use it
        adapter_config["database"] = explicit_database
    else:
        # Generate configuration-aware database name
        adapter_config["database"] = generate_database_name(
            benchmark_name=config["benchmark"],
            scale_factor=config["scale_factor"],
            platform="redshift",
            tuning_config=config.get("tuning_config"),
        )
    # Pass-through keys: core connection first, then staging/optimization.
    passthrough_keys = (
        "host",
        "port",
        "username",
        "password",
        "schema",
        "iam_role",
        "s3_bucket",
        "s3_prefix",
        "staging_root",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
        "aws_region",
        "cluster_identifier",
        "admin_database",
        "connect_timeout",
        "statement_timeout",
        "sslmode",
        "ssl_enabled",
        "ssl_insecure",
        "sslrootcert",
        "wlm_query_slot_count",
        "wlm_query_queue_name",
        "wlm_config",
        "compupdate",
        "auto_vacuum",
        "auto_analyze",
    )
    adapter_config.update({key: config[key] for key in passthrough_keys if key in config})
    return cls(**adapter_config)
[docs]
def get_target_dialect(self) -> str:
    """Return the target SQL dialect identifier for Redshift."""
    dialect = "redshift"
    return dialect
def _detect_deployment_type(self, hostname: str) -> str:
"""Detect Redshift deployment type from hostname pattern.
Args:
hostname: Redshift endpoint hostname
Returns:
"serverless", "provisioned", or "unknown"
"""
if not hostname:
return "unknown"
if ".redshift-serverless.amazonaws.com" in hostname:
return "serverless"
elif ".redshift.amazonaws.com" in hostname:
return "provisioned"
else:
return "unknown"
def _extract_region_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
"""Extract AWS region from Redshift hostname.
Args:
hostname: Redshift endpoint hostname
deployment_type: "serverless" or "provisioned"
Returns:
AWS region string or None if not found
"""
if not hostname:
return None
parts = hostname.split(".")
if deployment_type == "serverless":
# Format: workgroup.account.region.redshift-serverless.amazonaws.com
return parts[2] if len(parts) > 2 else None
elif deployment_type == "provisioned":
# Format: cluster.region.redshift.amazonaws.com
return parts[1] if len(parts) > 1 else None
return None
def _extract_identifier_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
"""Extract workgroup name or cluster identifier from hostname.
Args:
hostname: Redshift endpoint hostname
deployment_type: "serverless" or "provisioned"
Returns:
Workgroup name, cluster identifier, or None
"""
if not hostname:
return None
parts = hostname.split(".")
if deployment_type == "serverless":
# Format: workgroup.account.region.redshift-serverless.amazonaws.com
return parts[0] if len(parts) > 0 else None
elif deployment_type == "provisioned":
# Format: cluster.region.redshift.amazonaws.com
return parts[0] if len(parts) > 0 else None
return None
def _get_serverless_metadata_sql(self, cursor: Any) -> dict[str, Any]:
"""Get Redshift Serverless metadata using SQL queries.
Args:
cursor: Active database cursor
Returns:
Dictionary with serverless metadata (empty if not serverless or queries fail)
"""
metadata = {}
try:
# Try to query sys_serverless_usage (serverless-only table)
cursor.execute("""
SELECT compute_capacity
FROM sys_serverless_usage
ORDER BY start_time DESC
LIMIT 1
""")
result = cursor.fetchone()
if result:
# Table exists and has data - this is serverless
metadata["current_rpu_capacity"] = result[0] if result[0] is not None else None
self.logger.debug("Detected Redshift Serverless via sys_serverless_usage table")
except Exception as e:
# Table doesn't exist or query failed - likely not serverless
self.logger.debug(f"sys_serverless_usage query failed (not serverless or no permissions): {e}")
return metadata
def _get_provisioned_metadata_sql(self, cursor: Any) -> dict[str, Any]:
"""Get Redshift Provisioned metadata using SQL queries.
Args:
cursor: Active database cursor
Returns:
Dictionary with provisioned metadata (empty if not provisioned or queries fail)
"""
metadata = {}
try:
# Query stv_cluster_configuration (provisioned-only table)
cursor.execute("""
SELECT node_type, cluster_version, COUNT(*) as num_nodes
FROM stv_cluster_configuration
GROUP BY node_type, cluster_version
""")
result = cursor.fetchone()
if result:
metadata["node_type"] = result[0] if len(result) > 0 else None
metadata["cluster_version"] = result[1] if len(result) > 1 else None
metadata["number_of_nodes"] = result[2] if len(result) > 2 else None
self.logger.debug(
f"Detected Redshift Provisioned: {metadata['node_type']} x{metadata['number_of_nodes']}"
)
except Exception as e:
# Table doesn't exist or query failed - likely not provisioned or no permissions
self.logger.debug(f"stv_cluster_configuration query failed (not provisioned or no permissions): {e}")
return metadata
def _get_serverless_metadata_api(self, workgroup_name: str, region: str) -> dict[str, Any]:
    """Get Redshift Serverless metadata using boto3 API.

    Best-effort: missing boto3, missing AWS credentials, and API errors
    are all logged at debug level; whatever was collected before the
    failure is returned.

    Args:
        workgroup_name: Workgroup name
        region: AWS region

    Returns:
        Dictionary with serverless metadata (empty if API call fails)
    """
    if not boto3:
        self.logger.debug("boto3 not available - skipping Serverless API metadata")
        return {}
    metadata = {}
    try:
        client = boto3.client("redshift-serverless", region_name=region)
        # Get workgroup details
        response = client.get_workgroup(workgroupName=workgroup_name)
        workgroup = response.get("workgroup", {})
        # Extract essential sizing information
        metadata["workgroup_name"] = workgroup.get("workgroupName")
        metadata["base_capacity_rpu"] = workgroup.get("baseCapacity")
        metadata["max_capacity_rpu"] = workgroup.get("maxCapacity")
        metadata["namespace_name"] = workgroup.get("namespaceName")
        metadata["enhanced_vpc_routing"] = workgroup.get("enhancedVpcRouting", False)
        metadata["status"] = workgroup.get("status")
        # Get namespace details for encryption info; failure here is
        # non-fatal — the workgroup fields above are already captured.
        if metadata.get("namespace_name"):
            try:
                namespace_response = client.get_namespace(namespaceName=metadata["namespace_name"])
                namespace = namespace_response.get("namespace", {})
                metadata["kms_key_id"] = namespace.get("kmsKeyId")
                metadata["encrypted"] = True  # Serverless is always encrypted
            except Exception as e:
                self.logger.debug(f"Could not fetch namespace details: {e}")
        self.logger.debug(f"Retrieved Serverless metadata via API: {metadata['base_capacity_rpu']} RPUs")
    except NoCredentialsError:
        self.logger.debug("No AWS credentials found - skipping Serverless API metadata")
    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        self.logger.debug(f"AWS API error querying Serverless metadata: {error_code} - {e}")
    except Exception as e:
        self.logger.debug(f"Failed to query Serverless API metadata: {e}")
    return metadata
def _get_provisioned_metadata_api(self, cluster_identifier: str, region: str) -> dict[str, Any]:
    """Get Redshift Provisioned metadata using boto3 API.

    Best-effort: missing boto3, missing AWS credentials, and API errors
    are all logged at debug level; whatever was collected before the
    failure is returned.

    Args:
        cluster_identifier: Cluster identifier
        region: AWS region

    Returns:
        Dictionary with provisioned metadata (empty if API call fails)
    """
    if not boto3:
        self.logger.debug("boto3 not available - skipping Provisioned API metadata")
        return {}
    metadata = {}
    try:
        client = boto3.client("redshift", region_name=region)
        # Get cluster details
        response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
        clusters = response.get("Clusters", [])
        if clusters:
            # describe_clusters with an explicit identifier returns at most one cluster.
            cluster = clusters[0]
            # Extract essential sizing information
            metadata["cluster_identifier"] = cluster.get("ClusterIdentifier")
            metadata["node_type"] = cluster.get("NodeType")
            metadata["number_of_nodes"] = cluster.get("NumberOfNodes")
            metadata["cluster_status"] = cluster.get("ClusterStatus")
            metadata["encrypted"] = cluster.get("Encrypted", False)
            metadata["kms_key_id"] = cluster.get("KmsKeyId")
            metadata["enhanced_vpc_routing"] = cluster.get("EnhancedVpcRouting", False)
            # Storage capacity (only present for some node types)
            total_storage_mb = cluster.get("TotalStorageCapacityInMegaBytes")
            if total_storage_mb:
                metadata["total_storage_capacity_mb"] = total_storage_mb
            self.logger.debug(
                f"Retrieved Provisioned metadata via API: {metadata['node_type']} x{metadata['number_of_nodes']}"
            )
    except NoCredentialsError:
        self.logger.debug("No AWS credentials found - skipping Provisioned API metadata")
    except ClientError as e:
        error_code = e.response.get("Error", {}).get("Code", "Unknown")
        self.logger.debug(f"AWS API error querying Provisioned metadata: {error_code} - {e}")
    except Exception as e:
        self.logger.debug(f"Failed to query Provisioned API metadata: {e}")
    return metadata
def _get_connection_params(self, **connection_config) -> dict[str, Any]:
"""Get standardized connection parameters."""
return {
"host": connection_config.get("host", self.host),
"port": connection_config.get("port", self.port),
"database": connection_config.get("database", self.database),
"user": connection_config.get("username", self.username),
"password": connection_config.get("password", self.password),
"sslmode": connection_config.get("sslmode", self.sslmode),
}
def _connect_with_driver(
    self,
    *,
    application_name: str,
    connect_timeout: int | None = None,
    tcp_keepalive: bool = False,
    tcp_keepalive_idle: int | None = None,
    tcp_keepalive_interval: int | None = None,
    tcp_keepalive_count: int | None = None,
    **connection_config,
) -> Any:
    """Create a Redshift connection using the configured driver.

    Prefers redshift_connector; falls back to psycopg2 when it is not
    installed. TCP keepalive options are honored by BOTH drivers: the
    psycopg2 path maps them onto the libpq ``keepalives*`` parameters
    (previously they were silently dropped on that path).

    Args:
        application_name: Reported to the server for session tracking.
        connect_timeout: Connection timeout in seconds (driver default if None).
        tcp_keepalive: Enable TCP keepalive probes on the socket.
        tcp_keepalive_idle: Seconds of idle before the first probe.
        tcp_keepalive_interval: Seconds between probes.
        tcp_keepalive_count: Probes before the connection is dropped.
        **connection_config: Overrides for host/port/database/username/password/sslmode.

    Returns:
        An open driver connection object.
    """
    params = self._get_connection_params(**connection_config)
    if redshift_connector:
        connect_kwargs: dict[str, Any] = {
            "host": params["host"],
            "port": params["port"],
            "database": params["database"],
            "user": params["user"],
            "password": params["password"],
            "ssl": self.ssl_enabled,
            "application_name": application_name,
        }
        if connect_timeout is not None:
            connect_kwargs["timeout"] = connect_timeout
        # redshift_connector only implements certificate verification modes explicitly.
        if self.ssl_enabled and params["sslmode"] in {"verify-ca", "verify-full"}:
            connect_kwargs["sslmode"] = params["sslmode"]
        if self.ssl_insecure:
            connect_kwargs["ssl_insecure"] = True
        if tcp_keepalive:
            connect_kwargs["tcp_keepalive"] = True
            if tcp_keepalive_idle is not None:
                connect_kwargs["tcp_keepalive_idle"] = tcp_keepalive_idle
            if tcp_keepalive_interval is not None:
                connect_kwargs["tcp_keepalive_interval"] = tcp_keepalive_interval
            if tcp_keepalive_count is not None:
                connect_kwargs["tcp_keepalive_count"] = tcp_keepalive_count
        return redshift_connector.connect(**connect_kwargs)
    # psycopg2 fallback — keepalive options map to libpq's keepalives,
    # keepalives_idle, keepalives_interval, and keepalives_count parameters.
    connect_kwargs = {
        "host": params["host"],
        "port": params["port"],
        "database": params["database"],
        "user": params["user"],
        "password": params["password"],
        "sslmode": params["sslmode"],
        "application_name": application_name,
    }
    if connect_timeout is not None:
        connect_kwargs["connect_timeout"] = connect_timeout
    if self.sslrootcert:
        connect_kwargs["sslrootcert"] = self.sslrootcert
    if tcp_keepalive:
        connect_kwargs["keepalives"] = 1
        if tcp_keepalive_idle is not None:
            connect_kwargs["keepalives_idle"] = tcp_keepalive_idle
        if tcp_keepalive_interval is not None:
            connect_kwargs["keepalives_interval"] = tcp_keepalive_interval
        if tcp_keepalive_count is not None:
            connect_kwargs["keepalives_count"] = tcp_keepalive_count
    return psycopg2.connect(**connect_kwargs)
def _resolve_connect_timeout(self) -> int:
"""Determine appropriate connection timeout based on cluster state.
For provisioned clusters, queries AWS API to detect paused/resuming status
and extends the timeout accordingly. For serverless clusters, the compute
warmth cannot be queried via API (the workgroup status is always AVAILABLE
even when compute is cold), so a fixed extended timeout is used to cover
cold-start latency (~15-30s for Serverless).
Result is cached on the instance so the API is only hit once per adapter.
Returns:
Connection timeout in seconds.
"""
if hasattr(self, "_cached_connect_timeout"):
return self._cached_connect_timeout
timeout = self._compute_connect_timeout()
self._cached_connect_timeout = timeout
return timeout
def _compute_connect_timeout(self) -> int:
    """Internal: compute timeout by checking cluster state via AWS API.

    Returns self.connect_timeout unchanged when boto3, the hostname, or
    the region/identifier cannot be resolved. Provisioned clusters that
    are paused/resuming/rebooting get a 120s floor; serverless always
    gets a 60s floor because compute warmth is not observable via the API.
    """
    if not boto3 or not self.host:
        return self.connect_timeout
    deployment_type = self.deployment_type
    region = self._extract_region_from_hostname(self.host, deployment_type)
    identifier = self._extract_identifier_from_hostname(self.host, deployment_type)
    if not region or not identifier:
        return self.connect_timeout
    if deployment_type == "provisioned":
        try:
            client = boto3.client("redshift", region_name=region)
            response = client.describe_clusters(ClusterIdentifier=identifier)
            clusters = response.get("Clusters", [])
            if clusters:
                status = clusters[0].get("ClusterStatus", "unknown").lower()
                if status == "available":
                    return self.connect_timeout
                elif status in ("resuming", "rebooting"):
                    self.log_verbose(f"Provisioned cluster is {status} — using extended timeout (120s)")
                    return max(self.connect_timeout, 120)
                elif status == "paused":
                    # Still attempt the connection with an extended timeout —
                    # the user may resume the cluster while we wait.
                    self.logger.warning(
                        f"Redshift cluster '{identifier}' is paused. "
                        "Resume it via the AWS console or: "
                        f"aws redshift resume-cluster --cluster-identifier {identifier}"
                    )
                    return max(self.connect_timeout, 120)
        except Exception as e:
            # Status check is best-effort; fall through to the default timeout.
            self.logger.debug(f"Could not check provisioned cluster status: {e}")
    elif deployment_type == "serverless":
        # Serverless compute warmth is not queryable via the AWS API.
        # The workgroup is always AVAILABLE even when compute is cold (auto-paused).
        # Use an extended timeout to accommodate cold-start latency.
        self.logger.debug("Serverless cluster: using extended timeout (60s) to cover potential cold-start")
        return max(self.connect_timeout, 60)
    return self.connect_timeout
def _long_running_timeout(self, floor: int = 300) -> int:
"""Timeout for long-running DDL operations (DROP DATABASE, VACUUM, ANALYZE).
Returns the adaptive timeout with a minimum floor, since these operations
routinely exceed normal connection timeouts. Uses the adaptive timeout as a
base so that cluster-state awareness (paused, resuming, cold-start) is preserved.
"""
return max(self._resolve_connect_timeout(), floor)
def _create_admin_connection(self, **connection_config) -> Any:
"""Create Redshift connection for admin operations.
Admin operations (CREATE DATABASE, DROP DATABASE, checking database existence)
require connecting to an existing database. This uses self.admin_database
(default: "dev") instead of the target database to avoid circular dependencies.
"""
connect_timeout = connection_config.pop("connect_timeout", self._resolve_connect_timeout())
admin_config = connection_config.copy()
admin_config["database"] = self.admin_database
return self._connect_with_driver(
application_name="BenchBox-Admin",
connect_timeout=connect_timeout,
**admin_config,
)
def _create_direct_connection(self, **connection_config) -> Any:
"""Create direct connection to target database for validation.
Connects directly to the specified database without:
- Calling handle_existing_database()
- Creating database if missing
- Setting database_was_reused flag
Used by validation framework to check existing database compatibility.
Args:
**connection_config: Connection configuration including database name
Returns:
Database connection object
Raises:
Exception: If connection fails (database doesn't exist, auth fails, etc.)
"""
connect_timeout = connection_config.pop("connect_timeout", self._resolve_connect_timeout())
connection = self._connect_with_driver(
application_name="BenchBox-Validation",
connect_timeout=connect_timeout,
**connection_config,
)
# Apply WLM queue settings if configured
if self.wlm_query_queue_name:
cursor = connection.cursor()
try:
# Escape single quotes in queue name for SQL safety
queue_name_escaped = self.wlm_query_queue_name.replace("'", "''")
cursor.execute(f"SET query_group TO '{queue_name_escaped}'")
finally:
cursor.close()
return connection
[docs]
def check_server_database_exists(self, **connection_config) -> bool:
    """Check if database exists in Redshift cluster.

    Connects to the admin database and queries pg_database for the target
    database name.

    Args:
        **connection_config: May override "database" (defaults to self.database).

    Returns:
        True if the database exists; False if it does not exist or the
        check itself fails (connection/permission errors are treated as
        "does not exist" by design).
    """
    connection = None
    cursor = None
    try:
        # Connect to admin database (not target database) — the target
        # may not exist yet.
        connection = self._create_admin_connection()
        cursor = connection.cursor()
        database = connection_config.get("database", self.database)
        # Parameterized lookup in the catalog for the target database.
        cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (database,))
        return cursor.fetchone() is not None
    except Exception:
        # If we can't connect or check, assume database doesn't exist
        return False
    finally:
        # Close the cursor too (the original leaked it), and never let a
        # close() failure mask the computed result.
        if cursor is not None:
            try:
                cursor.close()
            except Exception:
                pass
        if connection is not None:
            try:
                connection.close()
            except Exception:
                pass
[docs]
def drop_database(self, **connection_config) -> None:
    """Drop database in Redshift cluster.

    Connects to admin database to drop the target database.

    Args:
        **connection_config: May override "database" (defaults to self.database).

    Raises:
        RuntimeError: Wrapping the driver error if the drop ultimately fails.

    Note: DROP DATABASE must run with autocommit enabled.
    Note: Redshift doesn't support IF EXISTS for DROP DATABASE, so we check first.
    Note: If DROP DATABASE fails with SQLSTATE 55006 (database still has active connections),
    the method terminates backends again and retries on a *fresh* connection.
    redshift_connector v2.1.x enters an aborted transaction state after a failed DDL
    even with autocommit=True, causing any subsequent DDL on the same connection to
    fail with error 25001. Opening a new connection guarantees clean driver state.
    """
    database = connection_config.get("database", self.database)
    # Check if database exists first (Redshift doesn't support IF EXISTS)
    if not self.check_server_database_exists(database=database):
        self.log_verbose(f"Database {database} does not exist - nothing to drop")
        return
    try:
        # Connect to admin database with extended timeout — DROP DATABASE
        # can take 30-120+ seconds on Redshift depending on cluster state.
        connection = self._create_admin_connection(connect_timeout=self._long_running_timeout())
        connection.autocommit = True  # Enable autocommit for DROP DATABASE
        # Terminate any existing connections first to avoid "being accessed" errors
        # and to speed up the DROP. This is safe because we only kill sessions
        # attached to the *target* database, not the admin database we're on.
        cursor = connection.cursor()
        self.log_verbose(f"Terminating existing connections to {database}...")
        cursor.execute(
            """
            SELECT pg_terminate_backend(procpid)
            FROM pg_stat_activity
            WHERE datname = %s AND procpid <> pg_backend_pid()
            """,
            (database,),
        )
        # Drop the database (quote identifier for SQL safety)
        try:
            cursor.execute(f'DROP DATABASE "{database}"')
        except Exception as drop_error:
            # Re-raise if not SQLSTATE 55006 (object_in_use / database has active connections).
            # Check structured error code first (redshift_connector stores the server dict in
            # args[0] with key 'C'; psycopg2 exposes it as pgcode). Fall back to substring
            # matching only for unknown exception types so future driver versions don't silently
            # skip the retry.
            sqlstate = None
            first_arg = drop_error.args[0] if drop_error.args else None
            if isinstance(first_arg, dict):
                sqlstate = first_arg.get("C")  # redshift_connector wire-protocol dict
            if sqlstate is None:
                sqlstate = getattr(drop_error, "pgcode", None)  # psycopg2
            if sqlstate is not None:
                if sqlstate != "55006":
                    raise
            else:
                # Unknown driver — fall back to message text
                error_msg = str(drop_error).lower()
                if "active connection" not in error_msg and "being accessed" not in error_msg:
                    raise
            # Retry with a fresh connection — redshift_connector can enter an aborted
            # transaction state after a failed DDL, causing subsequent DDL to fail with
            # error 25001 even when autocommit=True. A fresh connection guarantees clean state.
            self.log_verbose("Database still has active connections after terminate, retrying...")
            cursor.close()
            try:
                connection.close()
            except Exception:
                pass  # best-effort; fresh connection is the goal
            connection = self._create_admin_connection(connect_timeout=self._long_running_timeout())
            connection.autocommit = True
            cursor = connection.cursor()
            # Terminate again in case new sessions attached since the first pass.
            cursor.execute(
                """
                SELECT pg_terminate_backend(procpid)
                FROM pg_stat_activity
                WHERE datname = %s AND procpid <> pg_backend_pid()
                """,
                (database,),
            )
            cursor.execute(f'DROP DATABASE "{database}"')
    except Exception as e:
        raise RuntimeError(f"Failed to drop Redshift database {database}: {e}") from e
    finally:
        # cursor/connection may be unbound if the admin connection failed.
        if "cursor" in locals():
            try:
                cursor.close()
            except Exception:
                pass
        if "connection" in locals() and connection:
            try:
                connection.close()
            except Exception:
                pass
[docs]
def create_connection(self, **connection_config) -> Any:
    """Create optimized Redshift connection.

    Ensures the target database exists (creating it through an admin
    connection when needed), connects with TCP keepalives enabled, then
    applies session settings: WLM slot count, statement timeout, and
    search_path. The returned connection has autocommit enabled.

    Args:
        **connection_config: Optional overrides (host, port, database,
            username, password, sslmode, connect_timeout).

    Returns:
        An open, autocommit-enabled driver connection.
    """
    self.log_operation_start("Redshift connection")
    # Handle existing database using base class method
    self.handle_existing_database(**connection_config)
    # Get connection parameters
    params = self._get_connection_params(**connection_config)
    target_database = params.get("database")
    # Create database if needed (before connecting to it)
    # Redshift requires connecting to an admin database to create new databases
    if not self.database_was_reused:
        # Check if target database exists
        database_exists = self.check_server_database_exists(database=target_database)
        if not database_exists:
            self.log_verbose(f"Creating database: {target_database}")
            # Create database using admin connection (connects to self.admin_database)
            # Note: CREATE DATABASE must run with autocommit enabled (cannot run in transaction block)
            try:
                admin_conn = self._create_admin_connection()
                admin_conn.autocommit = True  # Enable autocommit for CREATE DATABASE
                admin_cursor = admin_conn.cursor()
                try:
                    # Quote identifier for SQL safety
                    admin_cursor.execute(f'CREATE DATABASE "{target_database}"')
                    # No commit() needed - autocommit handles it automatically
                    self.logger.info(f"Created database {target_database}")
                finally:
                    admin_cursor.close()
                    admin_conn.close()
            except Exception as e:
                self.logger.error(f"Failed to create database {target_database}: {e}")
                raise
    self.log_very_verbose(f"Redshift connection params: host={params.get('host')}, database={target_database}")
    try:
        connect_timeout = connection_config.get("connect_timeout", self._resolve_connect_timeout())
        # Strip connect_timeout so it isn't forwarded twice to the driver.
        connect_config = {k: v for k, v in connection_config.items() if k != "connect_timeout"}
        # Keepalives guard against idle drops during long COPY/queries.
        connection = self._connect_with_driver(
            application_name="BenchBox",
            connect_timeout=connect_timeout,
            tcp_keepalive=True,
            tcp_keepalive_idle=600,
            tcp_keepalive_interval=30,
            tcp_keepalive_count=3,
            **connect_config,
        )
        # Enable autocommit immediately after connection creation (before any SQL operations)
        connection.autocommit = True
        # Apply WLM settings and schema search path
        cursor = connection.cursor()
        # Integer settings validated in __init__, safe to interpolate
        if self.wlm_query_slot_count > 1:
            cursor.execute(f"SET wlm_query_slot_count = {int(self.wlm_query_slot_count)}")
        if self.statement_timeout > 0:
            cursor.execute(f"SET statement_timeout = {int(self.statement_timeout)}")
        # Set search_path to ensure all unqualified table references use correct schema
        # Critical for database reuse when schema already exists but connection is new
        # Quote identifier for SQL safety
        cursor.execute(f'SET search_path TO "{self.schema}"')
        # Test connection
        cursor.execute("SELECT version()")
        cursor.fetchone()
        cursor.close()
        self.logger.info(f"Connected to Redshift cluster at {params['host']}:{params['port']}")
        self.log_operation_complete(
            "Redshift connection", details=f"Connected to {params['host']}:{params['port']}"
        )
        return connection
    except Exception as e:
        self.logger.error(f"Failed to connect to Redshift: {e}")
        raise
[docs]
def create_schema(self, benchmark, connection: Any) -> float:
    """Create benchmark tables using Redshift-optimized DDL.

    Args:
        benchmark: Benchmark providing the schema definition.
        connection: Active Redshift connection.

    Returns:
        Elapsed wall-clock seconds spent creating the schema.
    """
    started = mono_time()
    cursor = connection.cursor()
    try:
        # A non-default schema must exist before tables can land in it;
        # "public" is always present so it needs no CREATE.
        if self.schema and self.schema.lower() != "public":
            self.log_verbose(f"Creating schema: {self.schema}")
            cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.schema}"')
            self.logger.info(f"Created schema {self.schema}")
            self.log_very_verbose(f"Setting search_path to: {self.schema}")
            cursor.execute(f'SET search_path TO "{self.schema}"')
        schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")
        # Execute each semicolon-separated DDL statement individually.
        for piece in schema_sql.split(";"):
            statement = piece.strip()
            if not statement:
                continue
            # Lowercase table identifiers so CREATE, COPY and SELECT agree on case.
            statement = self._normalize_table_name_in_sql(statement)
            if statement.upper().startswith("CREATE TABLE"):
                # Redshift has no CREATE OR REPLACE TABLE; drop first for idempotency.
                table_name = self._extract_table_name(statement)
                if table_name:
                    bare_name = table_name.strip('"').lower()
                    drop_sql = f"DROP TABLE IF EXISTS {bare_name}"
                    cursor.execute(drop_sql)
                    self.logger.debug(f"Executed: {drop_sql}")
            # Attach Redshift distribution/sort hints (no-op for non-CREATE statements).
            statement = self._optimize_table_definition(statement)
            cursor.execute(statement)
            self.logger.debug(f"Executed schema statement: {statement[:100]}...")
        self.logger.info("Schema created")
    except Exception as e:
        self.logger.error(f"Schema creation failed: {e}")
        raise
    finally:
        cursor.close()
    return elapsed_seconds(started)
def _resolve_data_files(self, benchmark, data_dir: Path) -> dict:
    """Resolve per-table data files, preferring manifest-selected files when needed."""
    resolver = DataSourceResolver(
        platform_name=self.platform_name,
        table_mode=getattr(self, "table_mode", "native"),
        platform_config=getattr(self, "__dict__", None),
    )
    data_source = resolver.resolve(benchmark, data_dir)
    if self._should_use_manifest_selected_files(data_source):
        # Directory-valued native tables cannot be COPY'd directly; fall back
        # to the concrete file list recorded in the manifest, if one exists.
        manifest_source = None
        for provider in resolver.providers:
            if isinstance(provider, ManifestFileSource):
                manifest_source = provider
                break
        if manifest_source is not None:
            manifest_data_source = manifest_source.get_data_source(benchmark, data_dir)
            if manifest_data_source and manifest_data_source.tables:
                self.log_verbose(
                    "Redshift native mode detected directory-valued benchmark tables; "
                    "using manifest-selected native files instead"
                )
                data_source = manifest_data_source
    if not data_source or not data_source.tables:
        raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")
    return data_source.tables
def _should_use_manifest_selected_files(self, data_source: Any) -> bool:
"""Use manifest-selected files when native benchmark tables only point at directories."""
if getattr(self, "table_mode", "native") == "external":
return False
if not data_source or getattr(data_source, "source_type", None) not in {
"benchmark_tables",
"benchmark_impl_tables",
}:
return False
for table_paths in data_source.tables.values():
normalized_paths = table_paths if isinstance(table_paths, list) else [table_paths]
for path_like in normalized_paths:
path = Path(path_like)
if path.exists() and path.is_dir():
return True
return False
def _create_s3_client(self):
    """Build a boto3 S3 client, validating credential configuration up front."""
    try:
        # Partial explicit credentials are rejected: either both the key ID and
        # secret are configured, or both are omitted (environment/IAM fallback).
        has_key = bool(self.aws_access_key_id)
        has_secret = bool(self.aws_secret_access_key)
        if has_key != has_secret:
            raise ValueError(
                "Explicit S3 upload credentials require both aws_access_key_id and aws_secret_access_key."
            )
        session = boto3.Session(
            aws_access_key_id=self.aws_access_key_id,
            aws_secret_access_key=self.aws_secret_access_key,
            aws_session_token=self.aws_session_token,
            region_name=self.aws_region,
        )
        if session.get_credentials() is None:
            raise ValueError(
                "AWS credentials not found. Configure credentials via:\n"
                " 1. Explicit S3 upload credentials in config\n"
                " 2. AWS CLI (aws configure)\n"
                " 3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
            )
        # request_checksum_calculation="when_required" keeps botocore (1.35+) from
        # wrapping upload streams in AwsChunkedWrapper; otherwise retried uploads
        # fail with "Need to rewind the stream... but stream is not seekable".
        client_config = BotocoreConfig(request_checksum_calculation="when_required")
        return session.client("s3", config=client_config)
    except NoCredentialsError as e:
        raise ValueError(
            "AWS credentials not found. Configure credentials via:\n"
            " 1. IAM role (aws_access_key_id/aws_secret_access_key in config)\n"
            " 2. AWS CLI (aws configure)\n"
            " 3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
        ) from e
    except ClientError as e:
        raise ValueError(f"Failed to create AWS S3 client: {e}") from e
def _upload_file_to_s3(self, s3_client, file_path: Path, table_name: str, file_idx: int) -> str:
    """Upload one data file to the staging bucket and return its s3:// URI.

    The full multi-part extension (e.g. ".tbl.1.zst") is preserved in the
    object key so format and compression stay detectable downstream.
    """
    if "." in file_path.stem:
        # Chunked files carry extra dots in the stem; keep everything after
        # the first dot of the full filename as the suffix.
        name_parts = file_path.name.split(".", 1)
        suffix = "." + name_parts[1] if len(name_parts) > 1 else file_path.suffix
    else:
        suffix = file_path.suffix
    s3_key = f"{self.s3_prefix}/{table_name}_{file_idx}{suffix}"
    try:
        s3_client.upload_file(str(file_path), self.s3_bucket, s3_key)
    except ClientError as e:
        code = e.response.get("Error", {}).get("Code", "Unknown")
        # Translate the most common S3 failures into actionable messages.
        if code == "NoSuchBucket":
            raise ValueError(
                f"S3 bucket '{self.s3_bucket}' does not exist. Create the bucket or update your configuration."
            ) from e
        if code == "AccessDenied":
            raise ValueError(
                f"Access denied to S3 bucket '{self.s3_bucket}'. Check IAM permissions for s3:PutObject."
            ) from e
        raise ValueError(
            f"Failed to upload {file_path.name} to s3://{self.s3_bucket}/{s3_key}: {code} - {e}"
        ) from e
    return f"s3://{self.s3_bucket}/{s3_key}"
def _build_s3_copy_source(self, s3_client, s3_uris: list[str], table_name: str) -> tuple[str, str]:
"""Build COPY source path and manifest option from S3 URIs.
Returns:
Tuple of (copy_from_path, manifest_option)
"""
if len(s3_uris) > 1:
manifest = {"entries": [{"url": uri, "mandatory": True} for uri in s3_uris]}
manifest_key = f"{self.s3_prefix}/{table_name}_manifest.json"
try:
s3_client.put_object(
Bucket=self.s3_bucket,
Key=manifest_key,
Body=json.dumps(manifest),
)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
raise ValueError(
f"Failed to upload manifest file to s3://{self.s3_bucket}/{manifest_key}: {error_code} - {e}"
) from e
return f"s3://{self.s3_bucket}/{manifest_key}", "manifest"
else:
return s3_uris[0], ""
@staticmethod
def _sanitize_copy_credential(value: str, name: str) -> str:
"""Reject credential values containing single quotes to prevent COPY SQL injection."""
if "'" in value:
raise ValueError(
f"Credential '{name}' contains a single-quote character, which is not allowed in COPY SQL."
)
return value
def _get_copy_credentials_clause(self) -> str:
    """Build the credentials clause for the Redshift COPY command.

    Precedence: configured IAM role first, then explicit access keys (with an
    optional session token), then empty — in which case COPY falls back to the
    cluster's default IAM role.
    """
    if self.iam_role:
        role = self._sanitize_copy_credential(self.iam_role, "iam_role")
        return f"IAM_ROLE '{role}'"
    if self.aws_access_key_id and self.aws_secret_access_key:
        key_id = self._sanitize_copy_credential(self.aws_access_key_id, "aws_access_key_id")
        secret = self._sanitize_copy_credential(self.aws_secret_access_key, "aws_secret_access_key")
        clause = f"ACCESS_KEY_ID '{key_id}' SECRET_ACCESS_KEY '{secret}'"
        if self.aws_session_token:
            token = self._sanitize_copy_credential(self.aws_session_token, "aws_session_token")
            clause = f"{clause} SESSION_TOKEN '{token}'"
        return clause
    self.log_verbose("No explicit credentials configured for COPY; using cluster default IAM role")
    return ""
def _load_table_via_s3(self, cursor, s3_client, table_name: str, valid_files: list[Path], connection: Any) -> int:
    """Upload table files to S3 and load them with a single COPY command.

    Args:
        cursor: Open cursor on the main Redshift connection.
        s3_client: boto3 S3 client used for staging uploads.
        table_name: Logical table name (lowercased for Redshift).
        valid_files: Existing, non-empty source files for the table.
        connection: Connection used for post-load CTAS sort tuning.

    Returns:
        Number of rows in the table after the load.

    Raises:
        ValueError: If source files mix parquet with delimited formats, or mix
            compression codecs — COPY applies a single format and a single
            compression option to every file it reads.
    """
    table_name_lower = table_name.lower()
    parquet_modes = {is_parquet_format(file_path) for file_path in valid_files}
    if len(parquet_modes) > 1:
        raise ValueError(
            f"Redshift COPY requires a uniform source format per table; received mixed parquet and delimited "
            f"files for '{table_name_lower}'."
        )
    is_parquet_copy = parquet_modes == {True}
    # Detect compression from file extensions. COPY takes one compression
    # option for all files, so mixed codecs cannot load correctly — fail fast
    # with a clear message (before uploading anything) instead of letting COPY
    # misread part of the files.
    compressions = {detect_compression(f) for f in valid_files}
    if not is_parquet_copy and len(compressions) > 1:
        raise ValueError(
            f"Redshift COPY requires a uniform compression codec per table; received mixed codecs "
            f"{sorted(str(c) for c in compressions)} for '{table_name_lower}'."
        )
    if "zstd" in compressions:
        compression_option = "ZSTD"
    elif "gzip" in compressions:
        compression_option = "GZIP"
    else:
        compression_option = ""
    # Upload all files to S3; multiple files are wrapped in a manifest.
    s3_uris = [
        self._upload_file_to_s3(s3_client, file_path, table_name, file_idx)
        for file_idx, file_path in enumerate(valid_files)
    ]
    copy_from_path, manifest_option = self._build_s3_copy_source(s3_client, s3_uris, table_name)
    credentials_clause = self._get_copy_credentials_clause()
    # Build and execute the COPY command.
    qualified_table = f"{self.schema}.{table_name_lower}"
    if is_parquet_copy:
        copy_sql = f"""
        COPY {qualified_table}
        FROM '{copy_from_path}'
        {credentials_clause}
        {manifest_option}
        FORMAT AS PARQUET
        """
    else:
        delimiter = get_delimiter_for_file(valid_files[0])
        copy_sql = f"""
        COPY {qualified_table}
        FROM '{copy_from_path}'
        {credentials_clause}
        {manifest_option}
        {compression_option}
        DELIMITER '{delimiter}'
        IGNOREHEADER 0
        COMPUPDATE {self.compupdate}
        """
    cursor.execute(copy_sql)
    # Apply optional CTAS-based sort tuning after the bulk load.
    effective_tuning = self.get_effective_tuning_configuration()
    if effective_tuning is not None:
        self.apply_ctas_sort(table_name_lower, effective_tuning, connection)
    # Report the post-load row count.
    cursor.execute(f"SELECT COUNT(*) FROM {qualified_table}")
    row_count = cursor.fetchone()[0]
    # Refresh optimizer statistics when configured.
    if self.auto_analyze:
        cursor.execute(f"ANALYZE {qualified_table}")
    return row_count
def _load_table_via_insert(self, cursor, table_name: str, valid_files: list[Path], connection: Any) -> int:
    """Load a table with batched INSERT statements (the no-S3 fallback path).

    Returns:
        Total number of rows inserted across all input files.
    """
    target = table_name.lower()
    total_rows = 0
    batch_size = 1000
    for idx, data_file in enumerate(valid_files):
        chunk_note = f" (chunk {idx + 1}/{len(valid_files)})" if len(valid_files) > 1 else ""
        self.log_very_verbose(f"Loading {table_name}{chunk_note} from {data_file.name}")
        delimiter = get_delimiter_for_file(data_file)
        handler = FileFormatRegistry.get_compression_handler(data_file)
        with handler.open(data_file) as stream:
            file_rows = 0
            pending: list[str] = []
            for raw_line in stream:
                record = raw_line.strip()
                # Data files may carry a trailing field delimiter; drop it.
                if record and record.endswith(delimiter):
                    record = record[:-1]
                # Quote each field and double embedded quotes for SQL safety.
                quoted = ["'" + str(field).replace("'", "''") + "'" for field in record.split(delimiter)]
                pending.append(f"({', '.join(quoted)})")
                if len(pending) >= batch_size:
                    cursor.execute(f"INSERT INTO {target} VALUES " + ", ".join(pending))
                    file_rows += len(pending)
                    total_rows += len(pending)
                    pending = []
            # Flush the final partial batch.
            if pending:
                cursor.execute(f"INSERT INTO {target} VALUES " + ", ".join(pending))
                file_rows += len(pending)
                total_rows += len(pending)
        self.log_very_verbose(f"Loaded {file_rows:,} rows from {data_file.name}")
    effective_tuning = self.get_effective_tuning_configuration()
    if effective_tuning is not None:
        self.apply_ctas_sort(target, effective_tuning, connection)
    return total_rows
def _filter_valid_files(self, file_paths) -> list[Path]:
"""Normalize to list and filter out non-existent or empty files."""
if not isinstance(file_paths, list):
file_paths = [file_paths]
valid_files = []
for file_path in file_paths:
file_path = Path(file_path)
if file_path.exists() and file_path.stat().st_size > 0:
valid_files.append(file_path)
return valid_files
[docs]
def validate_external_table_requirements(self) -> None:
    """Validate prerequisites for Redshift external table mode.

    Raises:
        ValueError: If S3 staging or the Spectrum IAM role is not configured.
    """
    bucket_missing = (
        "Redshift external mode requires S3 staging (set --platform-option s3_bucket=<bucket> "
        "or provide --platform-option staging_root=s3://bucket/path)."
    )
    role_missing = (
        "Redshift external mode requires IAM role credentials for Spectrum "
        "(set --platform-option iam_role=<arn>)."
    )
    if not self.s3_bucket:
        raise ValueError(bucket_missing)
    if not self.iam_role:
        raise ValueError(role_missing)
@staticmethod
def _map_external_column_type(column_type: str) -> str:
"""Map benchmark schema types to Redshift Spectrum-compatible types."""
normalized = str(column_type).strip().upper()
if not normalized:
return "VARCHAR(65535)"
if normalized.startswith(("DECIMAL", "NUMERIC", "VARCHAR", "CHAR", "TIMESTAMP")):
return normalized
if "BIGINT" in normalized:
return "BIGINT"
if "SMALLINT" in normalized:
return "SMALLINT"
if "INT" in normalized:
return "INTEGER"
if "DOUBLE" in normalized:
return "DOUBLE PRECISION"
if "REAL" in normalized or "FLOAT" in normalized:
return "REAL"
if "DATE" in normalized:
return "DATE"
if "BOOLEAN" in normalized or normalized == "BOOL":
return "BOOLEAN"
return "VARCHAR(65535)"
def _build_external_column_definitions(self, benchmark: Any, table_name: str) -> str:
    """Render Spectrum column definitions ("name type, ...") from the benchmark schema.

    Raises:
        ValueError: If the benchmark exposes no schema, the table is not in it,
            or no usable column definitions can be produced.
    """
    if not hasattr(benchmark, "get_schema"):
        raise ValueError(
            f"Benchmark schema metadata unavailable for '{table_name}'. "
            "Redshift external mode requires benchmark.get_schema()."
        )
    schema = benchmark.get_schema() or {}
    # Table keys may be stored in any case; try exact, lower, then upper.
    table_schema = schema.get(table_name) or schema.get(table_name.lower()) or schema.get(table_name.upper())
    if not table_schema:
        raise ValueError(f"No schema definition found for table '{table_name}'.")
    columns = table_schema.get("columns", [])
    if not columns:
        raise ValueError(f"No columns found in schema definition for table '{table_name}'.")
    rendered = []
    for column in columns:
        name = str(column.get("name", "")).strip().lower()
        if not name:
            # Skip nameless column entries instead of emitting invalid DDL.
            continue
        spectrum_type = self._map_external_column_type(str(column.get("type", "")))
        rendered.append(f"{name} {spectrum_type}")
    if not rendered:
        raise ValueError(f"No valid column definitions were generated for table '{table_name}'.")
    return ", ".join(rendered)
def _upload_external_parquet_files_to_s3(self, s3_client: Any, table_name: str, parquet_files: list[Path]) -> str:
    """Upload a table's Parquet files and return the Spectrum LOCATION prefix."""
    prefix = f"{self.s3_prefix}/{self.database.lower()}_external/{table_name.lower()}"
    for parquet_file in parquet_files:
        s3_client.upload_file(str(parquet_file), self.s3_bucket, f"{prefix}/{parquet_file.name}")
    # Spectrum LOCATION expects the trailing-slash prefix, not a single object.
    return f"s3://{self.s3_bucket}/{prefix}/"
def _upload_external_directory_to_s3(self, s3_client: Any, table_name: str, directory: Path) -> str:
    """Mirror a local directory tree to S3 for external Delta-style registration."""
    prefix = f"{self.s3_prefix}/{self.database.lower()}_external/{table_name.lower()}"
    for entry in directory.rglob("*"):
        if entry.is_file():
            # Preserve the relative layout (e.g. _delta_log/) under the prefix.
            relative_key = entry.relative_to(directory).as_posix()
            s3_client.upload_file(str(entry), self.s3_bucket, f"{prefix}/{relative_key}")
    return f"s3://{self.s3_bucket}/{prefix}/"
[docs]
def create_external_tables(
    self, benchmark: Any, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
    """Upload external-table sources to S3 and register Redshift Spectrum tables.

    Args:
        benchmark: Benchmark providing schema metadata and table data sources.
        connection: Active Redshift connection.
        data_dir: Local directory containing generated benchmark data.

    Returns:
        Tuple of (per-table row counts, total elapsed seconds, None); the
        third slot (per-table timing details) is not populated for Redshift.

    Raises:
        ValueError: If external-mode prerequisites are missing or a table has
            neither Parquet files nor a Delta directory.
    """
    self.validate_external_table_requirements()
    # validate_external_table_requirements() has already ensured iam_role is set.
    assert self.iam_role is not None
    start_time = mono_time()
    table_stats: dict[str, int] = {}
    # Derive a unique external schema name; warn if the base name already
    # looks like a previous external-mode artifact to avoid cascading suffixes.
    if self.schema.endswith("_external"):
        self.logger.warning(
            f"Schema '{self.schema}' already ends with '_external'. "
            "The derived external schema will be '%s_external'.",
            self.schema,
        )
    external_schema = f"{self.schema}_external"
    # Escape embedded single quotes before splicing into single-quoted SQL literals.
    escaped_iam_role = self.iam_role.replace("'", "''")
    escaped_db = self.database.lower().replace("'", "''")
    cursor = connection.cursor()
    try:
        data_files = self._resolve_data_files(benchmark, data_dir)
        s3_client = self._create_s3_client()
        # Register (or reuse) the Spectrum external schema backed by the data catalog.
        cursor.execute(
            f"""
            CREATE EXTERNAL SCHEMA IF NOT EXISTS {external_schema}
            FROM DATA CATALOG
            DATABASE '{escaped_db}_external'
            IAM_ROLE '{escaped_iam_role}'
            CREATE EXTERNAL DATABASE IF NOT EXISTS
            """
        )
        for table_name, file_paths in data_files.items():
            table_name_lower = table_name.lower()
            valid_files = self._filter_valid_files(file_paths)
            # A Delta source is a directory containing a _delta_log subdirectory;
            # it takes precedence over loose Parquet files.
            delta_dirs = [path for path in valid_files if path.is_dir() and (path / "_delta_log").is_dir()]
            parquet_files = [path for path in valid_files if path.suffix.lower() == ".parquet"]
            if delta_dirs:
                location = self._upload_external_directory_to_s3(s3_client, table_name_lower, delta_dirs[0])
                source_format = "delta"
            elif parquet_files:
                location = self._upload_external_parquet_files_to_s3(s3_client, table_name_lower, parquet_files)
                source_format = "parquet"
            else:
                raise ValueError(
                    f"Redshift external mode requires Parquet files or Delta directories for table "
                    f"'{table_name_lower}'. No supported sources were found."
                )
            column_defs = self._build_external_column_definitions(benchmark, table_name_lower)
            # Drop any stale registration so repeated runs stay idempotent.
            cursor.execute(f"DROP TABLE IF EXISTS {external_schema}.{table_name_lower}")
            if source_format == "delta":
                cursor.execute(
                    f"""
                    CREATE EXTERNAL TABLE {external_schema}.{table_name_lower}
                    ({column_defs})
                    STORED AS PARQUET
                    LOCATION '{location}'
                    TABLE PROPERTIES ('table_type'='DELTA')
                    """
                )
            else:
                cursor.execute(
                    f"""
                    CREATE EXTERNAL TABLE {external_schema}.{table_name_lower}
                    ({column_defs})
                    STORED AS PARQUET
                    LOCATION '{location}'
                    """
                )
            # COUNT(*) both validates the registration and records the row count.
            cursor.execute(f"SELECT COUNT(*) FROM {external_schema}.{table_name_lower}")
            result = cursor.fetchone()
            table_stats[table_name_lower] = int(result[0]) if result else 0
    finally:
        cursor.close()
    total_time = elapsed_seconds(start_time)
    return table_stats, total_time, None
[docs]
def load_data(
    self, benchmark, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
    """Load data using Redshift COPY command with S3 integration.

    Uses the S3 + COPY bulk path when an S3 bucket is configured and boto3
    is importable; otherwise falls back to batched INSERT statements.
    Per-table failures are recorded as zero rows and do not abort the run.

    Args:
        benchmark: Benchmark providing per-table data files.
        connection: Active Redshift connection.
        data_dir: Local directory containing generated data files.

    Returns:
        Tuple of (per-table row counts keyed by lowercase table name, total
        elapsed seconds, None — per-table timing details not collected yet).
    """
    start_time = mono_time()
    table_stats = {}
    cursor = connection.cursor()
    try:
        data_files = self._resolve_data_files(benchmark, data_dir)
        # Upload files to S3 and load via COPY command
        if self.s3_bucket and boto3:
            s3_client = self._create_s3_client()
            for table_name, file_paths in data_files.items():
                valid_files = self._filter_valid_files(file_paths)
                if not valid_files:
                    self.logger.warning(f"Skipping {table_name} - no valid data files")
                    table_stats[table_name.lower()] = 0
                    continue
                chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                self.log_verbose(f"Loading data for table: {table_name}{chunk_info}")
                try:
                    load_start = mono_time()
                    row_count = self._load_table_via_s3(cursor, s3_client, table_name, valid_files, connection)
                    table_stats[table_name.lower()] = row_count
                    load_time = elapsed_seconds(load_start)
                    self.logger.info(
                        f"✅ Loaded {row_count:,} rows into {table_name.lower()}{chunk_info} in {load_time:.2f}s"
                    )
                except Exception as e:
                    # Record the failure as zero rows and continue with the rest.
                    self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                    table_stats[table_name.lower()] = 0
        else:
            # Direct loading without S3 (less efficient)
            self.logger.warning("No S3 bucket configured, using direct INSERT loading")
            for table_name, file_paths in data_files.items():
                valid_files = self._filter_valid_files(file_paths)
                if not valid_files:
                    self.logger.warning(f"Skipping {table_name} - no valid data files")
                    table_stats[table_name.lower()] = 0
                    continue
                try:
                    self.log_verbose(f"Direct loading data for table: {table_name}")
                    load_start = mono_time()
                    total_rows_loaded = self._load_table_via_insert(cursor, table_name, valid_files, connection)
                    table_stats[table_name.lower()] = total_rows_loaded
                    load_time = elapsed_seconds(load_start)
                    chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                    self.logger.info(
                        f"✅ Loaded {total_rows_loaded:,} rows into {table_name.lower()}{chunk_info} in {load_time:.2f}s"
                    )
                except Exception as e:
                    # Record the failure as zero rows and continue with the rest.
                    self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                    table_stats[table_name.lower()] = 0
        total_time = elapsed_seconds(start_time)
        total_rows = sum(table_stats.values())
        self.logger.info(f"✅ Loaded {total_rows:,} total rows in {total_time:.2f}s")
    except Exception as e:
        self.logger.error(f"Data loading failed: {e}")
        raise
    finally:
        cursor.close()
    # Redshift doesn't provide detailed per-table timings yet
    return table_stats, total_time, None
def _run_vacuum_analyze_isolated(self, main_connection: Any) -> None:
    """Run VACUUM/ANALYZE on a dedicated connection to protect the main one.

    VACUUM and ANALYZE are long-running operations that can trigger
    socket-level timeouts, especially on paused serverless clusters.
    A socket timeout permanently breaks the TCP connection, so these
    operations must be isolated from the benchmark connection.

    Args:
        main_connection: The main benchmark connection (used only to
            query the table list; VACUUM/ANALYZE run on a separate conn).
    """
    # Query table list using the main connection (fast metadata query).
    # Parameterized for consistency with _get_existing_tables and to keep an
    # unusual schema name from breaking the SQL.
    main_cursor = main_connection.cursor()
    try:
        main_cursor.execute(
            """
            SELECT schemaname, tablename
            FROM pg_tables
            WHERE schemaname = %s
            """,
            (self.schema,),
        )
        tables = main_cursor.fetchall()
    finally:
        main_cursor.close()
    if not tables:
        return
    # Create a separate connection for VACUUM/ANALYZE.
    # The finally block is the single owner of cleanup — error paths
    # just return and let finally close the connection.
    maint_conn = None
    try:
        maint_conn = self._connect_with_driver(
            application_name="BenchBox-Maintenance",
            connect_timeout=self._long_running_timeout(),
        )
        maint_conn.autocommit = True
        maint_cursor = maint_conn.cursor()
        # Set search_path to match the main connection
        maint_cursor.execute(f'SET search_path TO "{self.schema}"')
        for _schema, table in tables:
            ops = []
            if self.auto_vacuum:
                ops.append(("VACUUM", f'VACUUM "{table}"'))
            if self.auto_analyze:
                ops.append(("ANALYZE", f'ANALYZE "{table}"'))
            for op_name, sql in ops:
                try:
                    maint_cursor.execute(sql)
                except Exception as e:
                    # A per-table failure is recoverable; a dead connection is
                    # not. Previously ANY failure logged "connection lost" and
                    # abandoned all remaining maintenance. Probe the connection
                    # with a trivial query to tell the two cases apart.
                    self.logger.warning(f"{op_name} failed for {table}: {e}")
                    try:
                        maint_cursor.execute("SELECT 1")
                        maint_cursor.fetchone()
                    except Exception:
                        self.logger.warning(
                            f"Maintenance connection lost during {op_name} — skipping remaining VACUUM/ANALYZE"
                        )
                        return
    except Exception as e:
        self.logger.warning(f"Could not create maintenance connection for VACUUM/ANALYZE: {e}")
    finally:
        if maint_conn is not None:
            try:
                maint_conn.close()
            except Exception:
                pass
[docs]
def validate_session_cache_control(self, connection: Any) -> dict[str, Any]:
    """Validate that session-level cache control settings were successfully applied.

    Args:
        connection: Active Redshift database connection

    Returns:
        dict with: validated (bool), cache_disabled (bool), settings (dict),
        warnings (list[str]), errors (list[str]).

    Raises:
        ConfigurationError: If cache control validation fails and
            strict_validation=True
    """
    # Imported lazily; the shared helper does the actual probing and checks.
    from benchbox.platforms.cloud_shared import validate_session_cache_control

    check_kwargs = {
        "connection": connection,
        "query": "SELECT current_setting('enable_result_cache_for_session') as value",
        "setting_key": "enable_result_cache_for_session",
        "disabled_value": "off",
        "enabled_value": "on",
        "normalize": "lower",
        "platform_name": "Redshift",
        "disable_result_cache": self.disable_result_cache,
        "strict_validation": self.strict_validation,
        "adapter_logger": self.logger,
    }
    return validate_session_cache_control(**check_kwargs)
[docs]
def execute_query(
    self,
    connection: Any,
    query: str,
    query_id: str,
    benchmark_type: str | None = None,
    scale_factor: float | None = None,
    validate_row_count: bool = True,
    stream_id: int | None = None,
) -> dict[str, Any]:
    """Execute query with detailed timing and performance tracking.

    Args:
        connection: Active Redshift connection.
        query: SQL text to execute (already translated to the Redshift
            dialect by the base adapter).
        query_id: Benchmark identifier for the query; used for validation
            and reporting, not sent to Redshift.
        benchmark_type: Benchmark family used for row-count validation.
        scale_factor: Scale factor used to derive validation expectations.
        validate_row_count: Whether to validate the returned row count.
        stream_id: Throughput-test stream identifier, if applicable.

    Returns:
        Result dict with timing, row count, validation outcome and
        Redshift-specific statistics. On error, a FAILED result dict with
        the exception details is returned instead of raising.
    """
    start_time = mono_time()
    cursor = connection.cursor()
    try:
        # Execute the query
        # Note: Query dialect translation is now handled automatically by the base adapter
        cursor.execute(query)
        result = cursor.fetchall()
        execution_time = elapsed_seconds(start_time)
        actual_row_count = len(result) if result else 0
        # Get query statistics (best-effort telemetry; must never fail the query)
        try:
            query_stats = self._get_query_statistics(connection, query_id)
            # Add execution time for cost calculation
            query_stats["execution_time_seconds"] = execution_time
        except Exception:
            query_stats = {"execution_time_seconds": execution_time}
        # Validate row count if enabled and benchmark type is provided
        validation_result = None
        if validate_row_count and benchmark_type:
            from benchbox.core.validation.query_validation import QueryValidator

            validator = QueryValidator()
            validation_result = validator.validate_query_result(
                benchmark_type=benchmark_type,
                query_id=query_id,
                actual_row_count=actual_row_count,
                scale_factor=scale_factor,
                stream_id=stream_id,
            )
            # Log validation result
            if validation_result.warning_message:
                self.log_verbose(f"Row count validation: {validation_result.warning_message}")
            elif not validation_result.is_valid:
                self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
            else:
                self.log_very_verbose(
                    f"Row count validation PASSED: {actual_row_count} rows "
                    f"(expected: {validation_result.expected_row_count})"
                )
        # Use base helper to build result with consistent validation field mapping
        result_dict = self._build_query_result_with_validation(
            query_id=query_id,
            execution_time=execution_time,
            actual_row_count=actual_row_count,
            first_row=result[0] if result else None,
            validation_result=validation_result,
        )
        # Include Redshift-specific fields
        result_dict["translated_query"] = None  # Translation handled by base adapter
        result_dict["query_statistics"] = query_stats
        # Map query_statistics to resource_usage for cost calculation
        result_dict["resource_usage"] = query_stats
        return result_dict
    except Exception as e:
        # Failures are reported as a FAILED result dict (not raised) so a
        # benchmark run can continue with the remaining queries.
        execution_time = elapsed_seconds(start_time)
        return {
            "query_id": query_id,
            "status": "FAILED",
            "execution_time_seconds": execution_time,
            "rows_returned": 0,
            "error": str(e),
            "error_type": type(e).__name__,
        }
    finally:
        cursor.close()
def _extract_table_name(self, statement: str) -> str | None:
    """Return the table name parsed from a CREATE TABLE statement, or None."""
    # Delegates to the shared SQL-parsing helper (imported lazily).
    from benchbox.core.sql_utils import extract_table_name

    parsed_name = extract_table_name(statement)
    return parsed_name
def _normalize_table_name_in_sql(self, sql: str) -> str:
    """Lowercase table identifiers in *sql* to match Redshift's case folding."""
    normalized_sql = normalize_table_name_in_sql(sql)
    return normalized_sql
def _optimize_table_definition(self, statement: str) -> str:
"""Optimize table definition for Redshift."""
if not statement.upper().startswith("CREATE TABLE"):
return statement
# Include distribution and sort keys for better performance
if "DISTSTYLE" not in statement.upper():
# Include AUTO distribution style (Redshift will choose appropriate distribution)
statement += " DISTSTYLE AUTO"
if "SORTKEY" not in statement.upper():
# Include sort key on first column (simple heuristic)
# In production, this would be more sophisticated
statement += " SORTKEY AUTO"
return statement
def _get_platform_metadata(self, connection: Any) -> dict[str, Any]:
    """Get Redshift-specific metadata and system information.

    Collects version, cluster topology (serverless RPU capacity or
    provisioned node layout), session details, and the table inventory for
    the configured schema. Collection is best-effort: a top-level failure is
    recorded under "metadata_error"; cluster-info failures are skipped.

    Args:
        connection: Active Redshift connection.

    Returns:
        Metadata dict; keys vary with deployment type and query success.
    """
    metadata = {
        "platform": self.platform_name,
        "host": self.host,
        "port": self.port,
        "database": self.database,
        "result_cache_enabled": not self.disable_result_cache,
    }
    cursor = connection.cursor()
    try:
        # Get Redshift version
        cursor.execute("SELECT version()")
        result = cursor.fetchone()
        metadata["redshift_version"] = result[0] if result else "unknown"
        # Get cluster information (STV tables unavailable on Serverless)
        try:
            if self.deployment_type == "serverless":
                # Most recent usage row carries the current RPU capacity.
                cursor.execute("""
                    SELECT compute_capacity
                    FROM sys_serverless_usage
                    ORDER BY start_time DESC
                    LIMIT 1
                """)
                result = cursor.fetchone()
                if result:
                    metadata["cluster_info"] = {
                        "deployment_type": "serverless",
                        "compute_capacity_rpu": result[0],
                    }
            else:
                cursor.execute("""
                    SELECT
                        node_type,
                        num_nodes,
                        cluster_version,
                        publicly_accessible
                    FROM stv_cluster_configuration
                    LIMIT 1
                """)
                result = cursor.fetchone()
                if result:
                    metadata["cluster_info"] = {
                        "deployment_type": "provisioned",
                        "node_type": result[0],
                        "num_nodes": result[1],
                        "cluster_version": result[2],
                        "publicly_accessible": result[3],
                    }
        except Exception as e:
            # Cluster topology is nice-to-have; log at debug and continue.
            self.logger.debug(f"Could not query cluster information: {e}")
        # Get current session information
        cursor.execute("""
            SELECT
                current_user,
                current_database(),
                current_schema(),
                inet_client_addr(),
                inet_client_port()
        """)
        result = cursor.fetchone()
        if result:
            metadata["session_info"] = {
                "current_user": result[0],
                "current_database": result[1],
                "current_schema": result[2],
                "client_addr": result[3],
                "client_port": result[4],
            }
        # Get table information
        # NOTE(review): self.schema is interpolated directly into the SQL;
        # assumed to come from trusted configuration — confirm upstream validation.
        cursor.execute(f"""
            SELECT
                schemaname,
                tablename,
                tableowner,
                tablespace,
                hasindexes,
                hasrules,
                hastriggers
            FROM pg_tables
            WHERE schemaname = '{self.schema}'
        """)
        tables = cursor.fetchall()
        metadata["tables"] = [
            {
                "schema": row[0],
                "table": row[1],
                "owner": row[2],
                "tablespace": row[3],
                "has_indexes": row[4],
                "has_rules": row[5],
                "has_triggers": row[6],
            }
            for row in tables
        ]
    except Exception as e:
        metadata["metadata_error"] = str(e)
    finally:
        cursor.close()
    return metadata
def _get_existing_tables(self, connection: Any) -> list[str]:
"""Get list of existing tables from Redshift (normalized to lowercase).
Queries information_schema for tables in the current schema and returns
them as lowercase names to match Redshift's identifier normalization.
Args:
connection: Database connection
Returns:
List of table names (lowercase)
"""
cursor = connection.cursor()
try:
# Query information_schema for tables in current schema
cursor.execute(
"""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = %s
AND table_type = 'BASE TABLE'
""",
(self.schema,),
)
# Return lowercase table names
return [row[0].lower() for row in cursor.fetchall()]
finally:
cursor.close()
def _get_query_statistics(self, connection: Any, query_id: str) -> dict[str, Any]:
    """Get query statistics from Redshift system tables.

    Note: This retrieves performance telemetry only - not used for validation.
    Query row count validation uses the actual result set from cursor.fetchall().

    Args:
        connection: Connection the query just ran on; the session-scoped
            pg_last_query_id() is used to find the corresponding system row.
        query_id: Benchmark query identifier (not used for the lookup; kept
            for interface symmetry).

    Returns:
        Statistics dict, or {} when no statistics are available or any
        lookup step fails (this path is best-effort by design).
    """
    cursor = connection.cursor()
    try:
        # Get the actual Redshift query ID for the most recent query
        # pg_last_query_id() returns the query ID of the last executed query in this session
        cursor.execute("SELECT pg_last_query_id()")
        result = cursor.fetchone()
        if not result or result[0] == -1:
            # No queries executed yet, or query ran only on leader node
            return {}
        redshift_query_id = result[0]
        # Query system tables for statistics (STL unavailable on Serverless)
        if self.deployment_type == "serverless":
            # sys_query_history lacks byte/slot telemetry in this projection,
            # so those columns are constant placeholders and cpu_time simply
            # mirrors the wall-clock duration.
            cursor.execute(
                """
                SELECT
                    query_id,
                    DATEDIFF('microseconds', start_time, end_time) as duration_microsecs,
                    DATEDIFF('microseconds', start_time, end_time) as cpu_time_microsecs,
                    0 as bytes_scanned,
                    0 as bytes_returned,
                    1 as slots,
                    1 as wlm_slots,
                    status
                FROM sys_query_history
                WHERE query_id = %s
                ORDER BY start_time DESC
                LIMIT 1
                """,
                (redshift_query_id,),
            )
        else:
            cursor.execute(
                """
                SELECT
                    query,
                    DATEDIFF('microseconds', starttime, endtime) as duration_microsecs,
                    DATEDIFF('microseconds', starttime, endtime) as cpu_time_microsecs,
                    0 as bytes_scanned,
                    0 as bytes_returned,
                    1 as slots,
                    1 as wlm_slots,
                    aborted
                FROM stl_query
                WHERE query = %s
                ORDER BY starttime DESC
                LIMIT 1
                """,
                (redshift_query_id,),
            )
        result = cursor.fetchone()
        if result:
            # sys_query_history returns status string; stl_query returns aborted int
            raw_aborted = result[7]
            if isinstance(raw_aborted, str):
                aborted = raw_aborted not in ("success", "running")
            else:
                aborted = bool(raw_aborted) if raw_aborted is not None else False
            return {
                "query_id": str(result[0]),
                "duration_microsecs": result[1] or 0,
                "cpu_time_microsecs": result[2] or 0,
                "bytes_scanned": result[3] or 0,
                "bytes_returned": result[4] or 0,
                "slots": result[5] or 1,
                "wlm_slots": result[6] or 1,
                "aborted": aborted,
            }
        else:
            return {}
    except Exception:
        # Telemetry is strictly best-effort; never let it break query execution.
        return {}
    finally:
        cursor.close()
[docs]
def analyze_table(self, connection: Any, table_name: str) -> None:
    """Refresh planner statistics for *table_name* via ANALYZE.

    Failures are logged as warnings rather than raised, since stale
    statistics are not fatal to benchmark execution.
    """
    cur = connection.cursor()
    try:
        cur.execute(f"ANALYZE {table_name.lower()}")
    except Exception as exc:
        self.logger.warning(f"Failed to analyze table {table_name}: {exc}")
    finally:
        cur.close()
[docs]
def vacuum_table(self, connection: Any, table_name: str) -> None:
    """Reclaim space and re-sort rows for *table_name* via VACUUM.

    Errors are downgraded to warnings; vacuuming is an optimization,
    not a correctness requirement.
    """
    cur = connection.cursor()
    try:
        cur.execute(f"VACUUM {table_name.lower()}")
    except Exception as exc:
        self.logger.warning(f"Failed to vacuum table {table_name}: {exc}")
    finally:
        cur.close()
[docs]
def get_query_plan(self, connection: Any, query: str) -> str:
    """Return the EXPLAIN output for *query* as a newline-joined string.

    On failure the error text is returned instead of raised so callers
    can record it alongside benchmark results.
    """
    cur = connection.cursor()
    try:
        # Dialect translation already happened in the base adapter, so the
        # query can be passed to EXPLAIN verbatim.
        cur.execute(f"EXPLAIN {query}")
        return "\n".join(row[0] for row in cur.fetchall())
    except Exception as exc:
        return f"Could not get query plan: {exc}"
    finally:
        cur.close()
[docs]
def close_connection(self, connection: Any) -> None:
    """Close a Redshift connection, swallowing and logging any error."""
    try:
        closeable = connection and hasattr(connection, "close")
        if closeable:
            connection.close()
    except Exception as exc:
        self.logger.warning(f"Error closing connection: {exc}")
[docs]
def supports_tuning_type(self, tuning_type) -> bool:
    """Report whether Redshift can honor *tuning_type*.

    Supported: DISTRIBUTION (DISTSTYLE/DISTKEY clauses), SORTING
    (SORTKEY clause), and PARTITIONING (table-design patterns).

    Args:
        tuning_type: The type of tuning to check support for.

    Returns:
        True if Redshift supports the tuning type; False when it does not
        or when the tuning interface cannot be imported.
    """
    try:
        # Local import avoids a circular dependency at module load time.
        from benchbox.core.tuning.interface import TuningType
    except ImportError:
        return False
    supported = (
        TuningType.DISTRIBUTION,
        TuningType.SORTING,
        TuningType.PARTITIONING,
    )
    return tuning_type in supported
[docs]
def generate_tuning_clause(self, table_tuning) -> str:
    """Build the Redshift DISTSTYLE/DISTKEY/SORTKEY suffix for CREATE TABLE.

    Emitted clauses:
    - DISTSTYLE KEY + DISTKEY (col) when a distribution column is tuned,
      otherwise DISTSTYLE EVEN.
    - SORTKEY (col, ...) using compound sort keys (good default for OLAP;
      INTERLEAVED could be made configurable later).

    Args:
        table_tuning: The tuning configuration for the table.

    Returns:
        SQL clause string to append to the CREATE TABLE statement; empty
        when no tuning applies or the tuning interface is unavailable.
    """
    if not table_tuning or not table_tuning.has_any_tuning():
        return ""
    parts: list[str] = []
    try:
        # Local import avoids a circular dependency at module load time.
        from benchbox.core.tuning.interface import TuningType

        dist_cols = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
        if dist_cols:
            # The lowest-order column becomes the single distribution key.
            key_col = min(dist_cols, key=lambda col: col.order)
            parts.append("DISTSTYLE KEY")
            parts.append(f"DISTKEY ({key_col.name})")
        else:
            parts.append("DISTSTYLE EVEN")

        sort_cols = table_tuning.get_columns_by_type(TuningType.SORTING)
        if sort_cols:
            ordered = sorted(sort_cols, key=lambda col: col.order)
            parts.append(f"SORTKEY ({', '.join(col.name for col in ordered)})")

        # PARTITIONING is realized through table-design patterns rather than
        # DDL clauses, and CLUSTERING has no direct Redshift equivalent, so
        # neither contributes SQL here.
    except ImportError:
        # Tuning interface unavailable: emit nothing.
        pass
    return " ".join(parts)
[docs]
def apply_table_tunings(self, table_tuning, connection: Any) -> None:
    """Apply tuning configurations to a Redshift table.

    Redshift fixes distribution and sort keys at CREATE TABLE time, so this
    method only:
    1. Inspects the live table layout and logs a warning when it differs
       from the desired tuning (recreation would be required to change it).
    2. Runs maintenance (ANALYZE, and VACUUM when ``self.auto_vacuum``).
    3. Logs partitioning/clustering intent, which has no direct DDL.

    Args:
        table_tuning: The tuning configuration to apply.
        connection: Redshift connection.

    Raises:
        ValueError: If the tuning application fails unexpectedly.
    """
    if not table_tuning or not table_tuning.has_any_tuning():
        return
    table_name = table_tuning.table_name.lower()
    self.logger.info(f"Applying Redshift tunings for table: {table_name}")
    cursor = connection.cursor()
    try:
        # Import here to avoid circular imports
        from benchbox.core.tuning.interface import TuningType

        layout = self._fetch_current_layout(cursor, table_name)
        if layout is None:
            self.logger.warning(f"Could not find table configuration for {table_name}")
        else:
            current_diststyle, current_distkey, current_sortkeys = layout
            self.logger.info(f"Current configuration for {table_name}:")
            self.logger.info(f"  Distribution style: {current_diststyle}")
            self.logger.info(f"  Distribution key: {current_distkey}")
            self.logger.info(f"  Sort keys: {current_sortkeys}")

            needs_recreation = False
            distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
            if distribution_columns:
                # Lowest-order column is the intended distribution key.
                desired_distkey = min(distribution_columns, key=lambda col: col.order).name
                if current_distkey != desired_distkey or current_diststyle != "KEY":
                    needs_recreation = True
                    self.logger.info(
                        f"Distribution key mismatch: current='{current_distkey}', desired='{desired_distkey}'"
                    )
            sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
            if sort_columns:
                desired_sortkeys = [col.name for col in sorted(sort_columns, key=lambda col: col.order)]
                if current_sortkeys != desired_sortkeys:
                    needs_recreation = True
                    self.logger.info(
                        f"Sort keys mismatch: current={current_sortkeys}, desired={desired_sortkeys}"
                    )
            if needs_recreation:
                self.logger.warning(
                    f"Table {table_name} configuration differs from desired tuning. "
                    "Redshift requires table recreation to change distribution/sort keys."
                )

        # Maintenance: refresh optimizer statistics and optionally
        # reclaim space / re-sort rows. Best-effort only.
        try:
            cursor.execute(f"ANALYZE {table_name}")
            self.logger.info(f"Analyzed table statistics for {table_name}")
            if self.auto_vacuum:
                cursor.execute(f"VACUUM {table_name}")
                self.logger.info(f"Vacuumed table {table_name}")
        except Exception as e:
            self.logger.warning(f"Failed to perform maintenance operations on {table_name}: {e}")

        # Partitioning and clustering have no direct Redshift DDL; log intent.
        partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
        if partition_columns:
            names = [col.name for col in sorted(partition_columns, key=lambda col: col.order)]
            self.logger.info(
                f"Partitioning strategy for {table_name}: {', '.join(names)} (handled via table design patterns)"
            )
        cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
        if cluster_columns:
            names = [col.name for col in sorted(cluster_columns, key=lambda col: col.order)]
            self.logger.info(
                f"Clustering strategy for {table_name} achieved via sort keys: {', '.join(names)}"
            )
    except ImportError:
        self.logger.warning("Tuning interface not available - skipping tuning application")
    except Exception as e:
        raise ValueError(f"Failed to apply tunings to Redshift table {table_name}: {e}") from e
    finally:
        cursor.close()

def _fetch_current_layout(self, cursor, table_name: str):
    """Return ``(diststyle, distkey, sortkeys)`` for *table_name*, or None.

    BUG FIX: the previous implementation selected ``diststyle``, ``distkey``
    and ``sortkey1..sortkey4`` from ``pg_table_def``, which exposes none of
    those columns, so the lookup always failed; it also interpolated the
    table name into the SQL instead of using parameters. Distribution style
    is reported by SVV_TABLE_INFO (as e.g. ``KEY(col)``, ``EVEN``, ``ALL``),
    while per-column sort-key positions come from ``pg_table_def.sortkey``.
    """
    cursor.execute(
        "SELECT diststyle FROM svv_table_info WHERE \"schema\" = 'public' AND \"table\" = %s",
        (table_name,),
    )
    row = cursor.fetchone()
    if not row:
        return None
    raw_style = (row[0] or "").strip()
    # 'KEY(l_orderkey)' -> style 'KEY', key 'l_orderkey'; 'EVEN'/'ALL' carry no key.
    if "(" in raw_style:
        style, _, remainder = raw_style.partition("(")
        distkey = remainder.rstrip(")")
    else:
        style, distkey = raw_style, None
    # Compound sort-key columns have positive positions; interleaved ones are
    # reported as negative and are excluded here (compound is what we emit).
    cursor.execute(
        "SELECT \"column\" FROM pg_table_def "
        "WHERE schemaname = 'public' AND tablename = %s AND sortkey > 0 "
        "ORDER BY sortkey",
        (table_name,),
    )
    sortkeys = [r[0] for r in cursor.fetchall()]
    return style, distkey, sortkeys
[docs]
def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None:
"""Apply unified tuning configuration to Redshift.
Args:
unified_config: Unified tuning configuration to apply
connection: Redshift connection
"""
if not unified_config:
return
# Apply constraint configurations
self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection)
# Apply platform optimizations
if unified_config.platform_optimizations:
self.apply_platform_optimizations(unified_config.platform_optimizations, connection)
# Apply table-level tunings
for _table_name, table_tuning in unified_config.table_tunings.items():
self.apply_table_tunings(table_tuning, connection)
[docs]
def apply_constraint_configuration(
self,
primary_key_config: PrimaryKeyConfiguration,
foreign_key_config: ForeignKeyConfiguration,
connection: Any,
) -> None:
"""Apply constraint configurations to Redshift.
Note: Redshift supports PRIMARY KEY and FOREIGN KEY constraints for query optimization,
but they are informational only (not enforced). Constraints must be applied during
table creation time.
Args:
primary_key_config: Primary key constraint configuration
foreign_key_config: Foreign key constraint configuration
connection: Redshift connection
"""
# Redshift constraints are applied at table creation time for query optimization
# This method is called after tables are created, so log the configurations
if primary_key_config and primary_key_config.enabled:
self.logger.info(
"Primary key constraints enabled for Redshift (informational only, applied during table creation)"
)
if foreign_key_config and foreign_key_config.enabled:
self.logger.info(
"Foreign key constraints enabled for Redshift (informational only, applied during table creation)"
)
# Redshift constraints are informational and used for query optimization
# No additional work to do here as they're applied during CREATE TABLE
def _build_redshift_config(
    platform: str,
    options: dict[str, Any],
    overrides: dict[str, Any],
    info: Any,
) -> Any:
    """Build a Redshift DatabaseConfig, merging saved credentials with CLI input.

    Loads saved credentials from the CredentialManager and layers CLI options
    and runtime overrides on top. Precedence (lowest to highest):
    saved credentials < CLI options < runtime overrides.

    Args:
        platform: Platform name (should be 'redshift').
        options: CLI platform options from --platform-option flags.
        overrides: Runtime overrides from orchestrator.
        info: Platform info from registry.

    Returns:
        DatabaseConfig with credentials loaded.
    """
    from benchbox.core.schemas import DatabaseConfig
    from benchbox.security.credentials import CredentialManager

    saved = CredentialManager().get_platform_credentials("redshift") or {}
    # Later sources win: saved credentials < options < overrides.
    merged = {**saved, **options, **overrides}

    # Platform-specific fields sit at the top level so RedshiftAdapter.__init__()
    # and from_config() can read them via config.get().
    cfg: dict[str, Any] = {
        "type": "redshift",
        "name": info.display_name if info else "Amazon Redshift",
        "options": merged or {},  # never None (Pydantic v2 keeps an explicit None)
        "driver_package": info.driver_package if info else "redshift-connector",
        "driver_version": overrides.get("driver_version") or options.get("driver_version"),
        "driver_auto_install": bool(overrides.get("driver_auto_install", options.get("driver_auto_install", False))),
        "host": merged.get("host"),
        "port": merged.get("port"),
        # NOTE: "database" is intentionally omitted here - from_config()
        # derives it from the benchmark context unless explicitly overridden
        # via --platform-option database=... (handled below).
        "username": merged.get("username"),
        "password": merged.get("password"),
        "schema": merged.get("schema"),
        # S3 and AWS configuration
        "s3_bucket": merged.get("s3_bucket"),
        "s3_prefix": merged.get("s3_prefix"),
        "staging_root": merged.get("staging_root"),
        "iam_role": merged.get("iam_role"),
        "aws_access_key_id": merged.get("aws_access_key_id"),
        "aws_secret_access_key": merged.get("aws_secret_access_key"),
        "aws_session_token": merged.get("aws_session_token"),
        "aws_region": merged.get("aws_region"),
        # Optional settings
        "cluster_identifier": merged.get("cluster_identifier"),
        "admin_database": merged.get("admin_database", "dev"),
        "connect_timeout": merged.get("connect_timeout"),
        "statement_timeout": merged.get("statement_timeout"),
        "sslmode": merged.get("sslmode"),
        # Benchmark context for config-aware database naming (from overrides)
        "benchmark": overrides.get("benchmark"),
        "scale_factor": overrides.get("scale_factor"),
        "tuning_config": overrides.get("tuning_config"),
    }
    # Only an explicit override may pin the database name; saved credentials
    # must not defeat generated, benchmark-specific names.
    if overrides.get("database"):
        cfg["database"] = overrides["database"]
    return DatabaseConfig(**cfg)
# Register the config builder with the platform hook registry.
# This must happen when the module is imported so the CLI can resolve the
# builder before any Redshift platform is instantiated.
try:
    from benchbox.cli.platform_hooks import PlatformHookRegistry

    PlatformHookRegistry.register_config_builder("redshift", _build_redshift_config)
except ImportError:
    # Platform hooks may not be available in all contexts (e.g., core-only usage)
    pass