Source code for benchbox.platforms.redshift

"""Amazon Redshift platform adapter with S3 integration and data warehouse optimizations.

Provides Redshift-specific optimizations for analytical workloads,
including COPY command for efficient data loading and distribution key optimization.

Copyright 2026 Joe Harris / BenchBox Project

Licensed under the MIT License. See LICENSE file in the project root for details.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING, Any

from benchbox.core.sql_utils import normalize_table_name_in_sql
from benchbox.utils.clock import elapsed_seconds, mono_time

if TYPE_CHECKING:
    from benchbox.core.tuning.interface import (
        ForeignKeyConfiguration,
        PlatformOptimizationConfiguration,
        PrimaryKeyConfiguration,
        TuningColumn,
        UnifiedTuningConfiguration,
    )

from ..core.exceptions import ConfigurationError
from ..utils.dependencies import (
    check_platform_dependencies,
    get_dependency_error_message,
    get_dependency_group_packages,
)
from ..utils.file_format import detect_compression, get_delimiter_for_file, is_parquet_format
from .base import DriverIsolationCapability, PlatformAdapter
from .base.data_loading import DataSourceResolver, FileFormatRegistry, ManifestFileSource

try:
    import redshift_connector

    psycopg2 = None  # fallback driver is unused when redshift_connector is available
except ImportError:
    redshift_connector = None
    try:
        import psycopg2
    except ImportError:
        psycopg2 = None

try:
    import boto3
    from botocore.config import Config as BotocoreConfig
    from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
    boto3 = None


class RedshiftAdapter(PlatformAdapter):
    """Amazon Redshift platform adapter with S3 integration."""

    driver_isolation_capability = DriverIsolationCapability.FEASIBLE_CLIENT_ONLY
    supports_external_tables = True

    def __init__(self, **config):
        super().__init__(**config)

        dependency_packages = get_dependency_group_packages("redshift")

        # Check dependencies - prefer redshift-connector, fall back to psycopg2
        if not redshift_connector and not psycopg2:
            available, missing = check_platform_dependencies("redshift")
            if not available:
                error_msg = get_dependency_error_message("redshift", missing)
                raise ImportError(error_msg)
        else:
            # Ensure shared helper libraries (e.g., boto3, cloudpathlib) are available
            shared_packages = [pkg for pkg in dependency_packages if pkg != "redshift-connector"]
            if shared_packages:
                available_shared, missing_shared = check_platform_dependencies("redshift", shared_packages)
                if not available_shared:
                    error_msg = get_dependency_error_message("redshift", missing_shared)
                    raise ImportError(error_msg)

        self._dialect = "redshift"

        # Redshift connection configuration
        self.host = config.get("host")
        self.port = config.get("port") if config.get("port") is not None else 5439
        self.database = config.get("database") or "dev"
        self.username = config.get("username")
        self.password = config.get("password")
        self.cluster_identifier = config.get("cluster_identifier")

        # Admin database for metadata operations (CREATE/DROP DATABASE, checking
        # database existence). Redshift requires connecting to an existing database
        # for admin operations. Default: "dev" (Redshift Serverless default database)
        self.admin_database = config.get("admin_database") or "dev"

        # Schema configuration
        self.schema = config.get("schema") or "public"

        # Connection settings
        self.connect_timeout = config.get("connect_timeout") if config.get("connect_timeout") is not None else 10
        self.statement_timeout = config.get("statement_timeout") if config.get("statement_timeout") is not None else 0
        self.sslmode = config.get("sslmode") or "require"

        # WLM settings
        self.wlm_query_slot_count = (
            config.get("wlm_query_slot_count") if config.get("wlm_query_slot_count") is not None else 1
        )
        self.wlm_query_queue_name = config.get("wlm_query_queue_name")

        # SSL configuration (legacy compatibility)
        self.ssl_enabled = config.get("ssl_enabled") if config.get("ssl_enabled") is not None else True
        self.ssl_insecure = config.get("ssl_insecure") if config.get("ssl_insecure") is not None else False
        self.sslrootcert = config.get("sslrootcert")

        # S3 configuration for data loading.
        # Check for staging_root first (set by orchestrator for CloudStagingPath).
        staging_root = config.get("staging_root")
        if staging_root:
            # Parse s3://bucket/path format to extract bucket and prefix
            from benchbox.utils.cloud_storage import get_cloud_path_info

            path_info = get_cloud_path_info(staging_root)
            if path_info["provider"] == "s3":
                self.s3_bucket = path_info["bucket"]
                # Use the path component if provided, otherwise use the default
                self.s3_prefix = path_info["path"].strip("/") if path_info["path"] else "benchbox-data"
                self.logger.info(f"Using staging location from config: s3://{self.s3_bucket}/{self.s3_prefix}")
            else:
                raise ValueError(f"Redshift requires S3 (s3://) staging location, got: {path_info['provider']}://")
        else:
            # Fall back to explicit s3_bucket configuration
            self.s3_bucket = config.get("s3_bucket")
            self.s3_prefix = config.get("s3_prefix") or "benchbox-data"

        self.iam_role = config.get("iam_role")
        self.aws_access_key_id = config.get("aws_access_key_id")
        self.aws_secret_access_key = config.get("aws_secret_access_key")
        self.aws_session_token = config.get("aws_session_token")
        self.aws_region = config.get("aws_region") or "us-east-1"

        # Redshift optimization settings
        self.workload_management_config = config.get("wlm_config")

        # COMPUPDATE controls automatic compression during COPY (PRESET | ON | OFF):
        #   PRESET: apply compression based on column data types (no sampling)
        #   ON:     apply compression based on data sampling
        #   OFF:    disable automatic compression
        compupdate_raw = config.get("compupdate") or "PRESET"
        self.compupdate = compupdate_raw.upper()  # Normalize to uppercase

        # Validate the COMPUPDATE value
        valid_compupdate_values = {"ON", "OFF", "PRESET"}
        if self.compupdate not in valid_compupdate_values:
            raise ValueError(
                f"Invalid COMPUPDATE value: '{compupdate_raw}'. "
                f"Must be one of: {', '.join(sorted(valid_compupdate_values))}"
            )

        self.auto_vacuum = config.get("auto_vacuum") if config.get("auto_vacuum") is not None else True
        self.auto_analyze = config.get("auto_analyze") if config.get("auto_analyze") is not None else True

        # Result cache control - disable by default for accurate benchmarking
        self.disable_result_cache = config.get("disable_result_cache", True)

        # Validation strictness - raise errors if cache control validation fails
        self.strict_validation = config.get("strict_validation", True)

        # Cache the deployment type for Serverless vs Provisioned branching
        self.deployment_type = self._detect_deployment_type(self.host or "")

        if not all([self.host, self.username, self.password]):
            missing = []
            if not self.host:
                missing.append("host (or REDSHIFT_HOST)")
            if not self.username:
                missing.append("username (or REDSHIFT_USER)")
            if not self.password:
                missing.append("password (or REDSHIFT_PASSWORD)")
            raise ConfigurationError(
                f"Redshift configuration is incomplete. Missing: {', '.join(missing)}\n"
                "Configure with one of:\n"
                "  1. CLI: benchbox platforms setup --platform redshift\n"
                "  2. Environment variables: REDSHIFT_HOST, REDSHIFT_USER, REDSHIFT_PASSWORD\n"
                "  3. CLI options: --platform-option host=<cluster>.redshift.amazonaws.com"
            )
    @property
    def platform_name(self) -> str:
        return "Redshift"

    def _build_ctas_sort_sql(self, table_name: str, sort_columns: list[TuningColumn]) -> str | list[str] | None:
        """Build opt-in sorted-ingestion SQL for Redshift."""
        mode, method = self.resolve_sorted_ingestion_strategy()
        if mode == "off":
            return None
        if method == "vacuum_sort":
            return f"VACUUM SORT ONLY {self.schema}.{table_name}"
        if method == "ctas":
            ordered_cols = ", ".join(column.name for column in sort_columns)
            temp_table = f"{table_name}__ctas_sort"
            return [
                f"CREATE TABLE {self.schema}.{temp_table} AS "
                f"SELECT * FROM {self.schema}.{table_name} ORDER BY {ordered_cols}",
                f"DROP TABLE {self.schema}.{table_name}",
                f"ALTER TABLE {self.schema}.{temp_table} RENAME TO {table_name}",
            ]
        raise ValueError(f"Sorted ingestion method '{method}' is not supported for Redshift.")
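    # For illustration, with method == "ctas" and sort columns (l_shipdate,
    # l_orderkey) on a hypothetical "lineitem" table in the default schema,
    # _build_ctas_sort_sql above returns three statements shaped like:
    #
    #     CREATE TABLE public.lineitem__ctas_sort AS
    #         SELECT * FROM public.lineitem ORDER BY l_shipdate, l_orderkey
    #     DROP TABLE public.lineitem
    #     ALTER TABLE public.lineitem__ctas_sort RENAME TO lineitem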
    @staticmethod
    def add_cli_arguments(parser) -> None:
        """Add Redshift-specific CLI arguments."""
        rs_group = parser.add_argument_group("Redshift Arguments")
        rs_group.add_argument("--host", type=str, help="Redshift cluster endpoint hostname")
        rs_group.add_argument("--port", type=int, default=5439, help="Redshift cluster port")
        rs_group.add_argument("--database", type=str, default="dev", help="Database name")
        rs_group.add_argument("--username", type=str, help="Database user with required privileges")
        rs_group.add_argument("--password", type=str, help="Password for the database user")
        rs_group.add_argument("--iam-role", type=str, help="IAM role ARN for COPY operations")
        rs_group.add_argument("--s3-bucket", type=str, help="S3 bucket for data staging")
        rs_group.add_argument("--s3-prefix", type=str, default="benchbox-data", help="Prefix within the staging bucket")
    @classmethod
    def from_config(cls, config: dict[str, Any]):
        """Create a Redshift adapter from unified configuration."""
        from benchbox.utils.database_naming import generate_database_name

        adapter_config: dict[str, Any] = {}

        # Generate a proper database name using benchmark characteristics
        # (unless explicitly overridden in config)
        if "database" in config and config["database"]:
            # User explicitly provided a database name - use it
            adapter_config["database"] = config["database"]
        else:
            # Generate a configuration-aware database name
            database_name = generate_database_name(
                benchmark_name=config["benchmark"],
                scale_factor=config["scale_factor"],
                platform="redshift",
                tuning_config=config.get("tuning_config"),
            )
            adapter_config["database"] = database_name

        # Core connection parameters (database handled above)
        for key in ["host", "port", "username", "password", "schema"]:
            if key in config:
                adapter_config[key] = config[key]

        # Optional staging/optimization parameters
        for key in [
            "iam_role",
            "s3_bucket",
            "s3_prefix",
            "staging_root",
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "aws_region",
            "cluster_identifier",
            "admin_database",
            "connect_timeout",
            "statement_timeout",
            "sslmode",
            "ssl_enabled",
            "ssl_insecure",
            "sslrootcert",
            "wlm_query_slot_count",
            "wlm_query_queue_name",
            "wlm_config",
            "compupdate",
            "auto_vacuum",
            "auto_analyze",
        ]:
            if key in config:
                adapter_config[key] = config[key]

        return cls(**adapter_config)
    def get_platform_info(self, connection: Any = None) -> dict[str, Any]:
        """Get Redshift platform information.

        Captures comprehensive Redshift configuration including:
        - Deployment type (serverless vs provisioned)
        - Capacity configuration (RPUs for serverless, node type/count for provisioned)
        - Redshift version
        - WLM (Workload Management) configuration
        - AWS region
        - Encryption and security settings

        Uses a fallback chain: AWS API → SQL queries → hostname parsing.
        Gracefully degrades if permissions are insufficient or AWS credentials
        are unavailable.
        """
        # Step 1: Use the cached deployment type
        deployment_type = self.deployment_type
        region = self._extract_region_from_hostname(self.host, deployment_type) or self.aws_region
        identifier = self._extract_identifier_from_hostname(self.host, deployment_type)

        platform_info = {
            "platform_type": "redshift",
            "platform_name": "Redshift",
            "connection_mode": "remote",
            "cloud_provider": "AWS",
            "host": self.host,
            "port": self.port,
            "configuration": {
                "database": self.database,
                "region": region,
                "s3_bucket": self.s3_bucket,
                "iam_role": self.iam_role,
                "compupdate": getattr(self, "compupdate", None),
                "result_cache_enabled": not self.disable_result_cache,
                "deployment_type": deployment_type,
            },
        }

        # Get the client library version
        if redshift_connector:
            try:
                platform_info["client_library_version"] = redshift_connector.__version__
            except AttributeError:
                platform_info["client_library_version"] = None
        else:
            try:
                import psycopg2

                platform_info["client_library_version"] = psycopg2.__version__
            except (ImportError, AttributeError):
                platform_info["client_library_version"] = None

        # Try to get the Redshift version and extended metadata from the connection
        if connection:
            cursor = None
            try:
                cursor = connection.cursor()

                # Get the Redshift version
                cursor.execute("SELECT version()")
                result = cursor.fetchone()
                platform_info["platform_version"] = result[0] if result else None
                platform_info["engine_version"] = platform_info["platform_version"]
                platform_info["engine_version_source"] = "sql_query"

                # Step 2: Collect deployment-specific metadata using the fallback chain
                deployment_metadata = {}

                if deployment_type == "serverless":
                    # Try the API first (most complete)
                    if identifier and region:
                        api_metadata = self._get_serverless_metadata_api(identifier, region)
                        if api_metadata:
                            deployment_metadata.update(api_metadata)
                            self.logger.debug("Using Serverless metadata from AWS API")

                    # Fall back to SQL if the API didn't provide data
                    if not deployment_metadata:
                        sql_metadata = self._get_serverless_metadata_sql(cursor)
                        if sql_metadata:
                            deployment_metadata.update(sql_metadata)
                            self.logger.debug("Using Serverless metadata from SQL queries")

                    # Add serverless-specific fields to the configuration
                    if deployment_metadata:
                        platform_info["configuration"]["workgroup_name"] = deployment_metadata.get(
                            "workgroup_name", identifier
                        )
                        platform_info["configuration"]["namespace_name"] = deployment_metadata.get("namespace_name")
                        platform_info["configuration"]["base_capacity_rpu"] = deployment_metadata.get(
                            "base_capacity_rpu"
                        )
                        platform_info["configuration"]["max_capacity_rpu"] = deployment_metadata.get("max_capacity_rpu")
                        platform_info["configuration"]["enhanced_vpc_routing"] = deployment_metadata.get(
                            "enhanced_vpc_routing"
                        )
                        platform_info["configuration"]["encrypted"] = deployment_metadata.get("encrypted")

                elif deployment_type == "provisioned":
                    # Try the API first (most complete)
                    if identifier and region:
                        api_metadata = self._get_provisioned_metadata_api(identifier, region)
                        if api_metadata:
                            deployment_metadata.update(api_metadata)
                            self.logger.debug("Using Provisioned metadata from AWS API")

                    # Fall back to SQL if the API didn't provide data
                    if not deployment_metadata:
                        sql_metadata = self._get_provisioned_metadata_sql(cursor)
                        if sql_metadata:
                            deployment_metadata.update(sql_metadata)
                            self.logger.debug("Using Provisioned metadata from SQL queries")

                    # Add provisioned-specific fields to the configuration
                    if deployment_metadata:
                        platform_info["configuration"]["cluster_identifier"] = deployment_metadata.get(
                            "cluster_identifier", identifier
                        )
                        platform_info["configuration"]["node_type"] = deployment_metadata.get("node_type")
                        platform_info["configuration"]["number_of_nodes"] = deployment_metadata.get("number_of_nodes")
                        platform_info["configuration"]["total_storage_capacity_mb"] = deployment_metadata.get(
                            "total_storage_capacity_mb"
                        )
                        platform_info["configuration"]["enhanced_vpc_routing"] = deployment_metadata.get(
                            "enhanced_vpc_routing"
                        )
                        platform_info["configuration"]["encrypted"] = deployment_metadata.get("encrypted")

                        # Legacy compute_configuration field for backward compatibility
                        platform_info["compute_configuration"] = {
                            "node_type": deployment_metadata.get("node_type"),
                            "cluster_version": deployment_metadata.get("cluster_version"),
                            "num_compute_nodes": deployment_metadata.get("number_of_nodes"),
                        }

                # Try to get the WLM (Workload Management) configuration
                try:
                    cursor.execute("""
                        SELECT service_class, num_query_tasks, query_working_mem,
                               max_execution_time, user_group_wild_card, query_group_wild_card
                        FROM stv_wlm_service_class_config
                        WHERE service_class >= 6
                        ORDER BY service_class
                        LIMIT 5
                    """)
                    wlm_results = cursor.fetchall()
                    if wlm_results:
                        if "compute_configuration" not in platform_info:
                            platform_info["compute_configuration"] = {}
                        platform_info["compute_configuration"]["wlm_queues"] = []
                        for row in wlm_results:
                            platform_info["compute_configuration"]["wlm_queues"].append(
                                {
                                    "service_class": row[0] if len(row) > 0 else None,
                                    "num_query_tasks": row[1] if len(row) > 1 else None,
                                    "query_working_mem_mb": row[2] if len(row) > 2 else None,
                                    "max_execution_time_ms": row[3] if len(row) > 3 else None,
                                }
                            )
                        self.logger.debug("Successfully captured Redshift WLM configuration")
                except Exception as e:
                    self.logger.debug(f"Could not query Redshift WLM configuration: {e}")

            except Exception as e:
                self.logger.debug(f"Error collecting Redshift platform info: {e}")
                # Ensure the key exists even if version collection failed
                platform_info.setdefault("platform_version", None)
            finally:
                if cursor:
                    cursor.close()
        else:
            platform_info["platform_version"] = None

        return platform_info
    def get_target_dialect(self) -> str:
        """Return the target SQL dialect for Redshift."""
        return "redshift"
    def _detect_deployment_type(self, hostname: str) -> str:
        """Detect the Redshift deployment type from the hostname pattern.

        Args:
            hostname: Redshift endpoint hostname

        Returns:
            "serverless", "provisioned", or "unknown"
        """
        if not hostname:
            return "unknown"
        if ".redshift-serverless.amazonaws.com" in hostname:
            return "serverless"
        elif ".redshift.amazonaws.com" in hostname:
            return "provisioned"
        else:
            return "unknown"

    def _extract_region_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
        """Extract the AWS region from a Redshift hostname.

        Args:
            hostname: Redshift endpoint hostname
            deployment_type: "serverless" or "provisioned"

        Returns:
            AWS region string or None if not found
        """
        if not hostname:
            return None
        parts = hostname.split(".")
        if deployment_type == "serverless":
            # Format: workgroup.account.region.redshift-serverless.amazonaws.com
            return parts[2] if len(parts) > 2 else None
        elif deployment_type == "provisioned":
            # Format: cluster.region.redshift.amazonaws.com
            return parts[1] if len(parts) > 1 else None
        return None

    def _extract_identifier_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
        """Extract the workgroup name or cluster identifier from a hostname.

        Args:
            hostname: Redshift endpoint hostname
            deployment_type: "serverless" or "provisioned"

        Returns:
            Workgroup name, cluster identifier, or None
        """
        if not hostname:
            return None
        parts = hostname.split(".")
        if deployment_type == "serverless":
            # Format: workgroup.account.region.redshift-serverless.amazonaws.com
            return parts[0] if len(parts) > 0 else None
        elif deployment_type == "provisioned":
            # Format: cluster.region.redshift.amazonaws.com
            return parts[0] if len(parts) > 0 else None
        return None

    def _get_serverless_metadata_sql(self, cursor: Any) -> dict[str, Any]:
        """Get Redshift Serverless metadata using SQL queries.

        Args:
            cursor: Active database cursor

        Returns:
            Dictionary with serverless metadata (empty if not serverless or queries fail)
        """
        metadata = {}
        try:
            # Try to query sys_serverless_usage (serverless-only table)
            cursor.execute("""
                SELECT compute_capacity
                FROM sys_serverless_usage
                ORDER BY start_time DESC
                LIMIT 1
            """)
            result = cursor.fetchone()
            if result:
                # The table exists and has data - this is serverless
                metadata["current_rpu_capacity"] = result[0] if result[0] is not None else None
                self.logger.debug("Detected Redshift Serverless via sys_serverless_usage table")
        except Exception as e:
            # The table doesn't exist or the query failed - likely not serverless
            self.logger.debug(f"sys_serverless_usage query failed (not serverless or no permissions): {e}")
        return metadata

    def _get_provisioned_metadata_sql(self, cursor: Any) -> dict[str, Any]:
        """Get Redshift Provisioned metadata using SQL queries.

        Args:
            cursor: Active database cursor

        Returns:
            Dictionary with provisioned metadata (empty if not provisioned or queries fail)
        """
        metadata = {}
        try:
            # Query stv_cluster_configuration (provisioned-only table)
            cursor.execute("""
                SELECT node_type, cluster_version, COUNT(*) as num_nodes
                FROM stv_cluster_configuration
                GROUP BY node_type, cluster_version
            """)
            result = cursor.fetchone()
            if result:
                metadata["node_type"] = result[0] if len(result) > 0 else None
                metadata["cluster_version"] = result[1] if len(result) > 1 else None
                metadata["number_of_nodes"] = result[2] if len(result) > 2 else None
                self.logger.debug(
                    f"Detected Redshift Provisioned: {metadata['node_type']} x{metadata['number_of_nodes']}"
                )
        except Exception as e:
            # The table doesn't exist or the query failed - likely not provisioned or no permissions
            self.logger.debug(f"stv_cluster_configuration query failed (not provisioned or no permissions): {e}")
        return metadata

    def _get_serverless_metadata_api(self, workgroup_name: str, region: str) -> dict[str, Any]:
        """Get Redshift Serverless metadata using the boto3 API.

        Args:
            workgroup_name: Workgroup name
            region: AWS region

        Returns:
            Dictionary with serverless metadata (empty if the API call fails)
        """
        if not boto3:
            self.logger.debug("boto3 not available - skipping Serverless API metadata")
            return {}

        metadata = {}
        try:
            client = boto3.client("redshift-serverless", region_name=region)

            # Get workgroup details
            response = client.get_workgroup(workgroupName=workgroup_name)
            workgroup = response.get("workgroup", {})

            # Extract essential sizing information
            metadata["workgroup_name"] = workgroup.get("workgroupName")
            metadata["base_capacity_rpu"] = workgroup.get("baseCapacity")
            metadata["max_capacity_rpu"] = workgroup.get("maxCapacity")
            metadata["namespace_name"] = workgroup.get("namespaceName")
            metadata["enhanced_vpc_routing"] = workgroup.get("enhancedVpcRouting", False)
            metadata["status"] = workgroup.get("status")

            # Get namespace details for encryption info
            if metadata.get("namespace_name"):
                try:
                    namespace_response = client.get_namespace(namespaceName=metadata["namespace_name"])
                    namespace = namespace_response.get("namespace", {})
                    metadata["kms_key_id"] = namespace.get("kmsKeyId")
                    metadata["encrypted"] = True  # Serverless is always encrypted
                except Exception as e:
                    self.logger.debug(f"Could not fetch namespace details: {e}")

            self.logger.debug(f"Retrieved Serverless metadata via API: {metadata['base_capacity_rpu']} RPUs")
        except NoCredentialsError:
            self.logger.debug("No AWS credentials found - skipping Serverless API metadata")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            self.logger.debug(f"AWS API error querying Serverless metadata: {error_code} - {e}")
        except Exception as e:
            self.logger.debug(f"Failed to query Serverless API metadata: {e}")
        return metadata

    def _get_provisioned_metadata_api(self, cluster_identifier: str, region: str) -> dict[str, Any]:
        """Get Redshift Provisioned metadata using the boto3 API.

        Args:
            cluster_identifier: Cluster identifier
            region: AWS region

        Returns:
            Dictionary with provisioned metadata (empty if the API call fails)
        """
        if not boto3:
            self.logger.debug("boto3 not available - skipping Provisioned API metadata")
            return {}

        metadata = {}
        try:
            client = boto3.client("redshift", region_name=region)

            # Get cluster details
            response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
            clusters = response.get("Clusters", [])
            if clusters:
                cluster = clusters[0]

                # Extract essential sizing information
                metadata["cluster_identifier"] = cluster.get("ClusterIdentifier")
                metadata["node_type"] = cluster.get("NodeType")
                metadata["number_of_nodes"] = cluster.get("NumberOfNodes")
                metadata["cluster_status"] = cluster.get("ClusterStatus")
                metadata["encrypted"] = cluster.get("Encrypted", False)
                metadata["kms_key_id"] = cluster.get("KmsKeyId")
                metadata["enhanced_vpc_routing"] = cluster.get("EnhancedVpcRouting", False)

                # Storage capacity
                total_storage_mb = cluster.get("TotalStorageCapacityInMegaBytes")
                if total_storage_mb:
                    metadata["total_storage_capacity_mb"] = total_storage_mb

                self.logger.debug(
                    f"Retrieved Provisioned metadata via API: {metadata['node_type']} x{metadata['number_of_nodes']}"
                )
        except NoCredentialsError:
            self.logger.debug("No AWS credentials found - skipping Provisioned API metadata")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            self.logger.debug(f"AWS API error querying Provisioned metadata: {error_code} - {e}")
        except Exception as e:
            self.logger.debug(f"Failed to query Provisioned API metadata: {e}")
        return metadata

    def _get_connection_params(self, **connection_config) -> dict[str, Any]:
        """Get standardized connection parameters."""
        return {
            "host": connection_config.get("host", self.host),
            "port": connection_config.get("port", self.port),
            "database": connection_config.get("database", self.database),
            "user": connection_config.get("username", self.username),
            "password": connection_config.get("password", self.password),
            "sslmode": connection_config.get("sslmode", self.sslmode),
        }

    def _connect_with_driver(
        self,
        *,
        application_name: str,
        connect_timeout: int | None = None,
        tcp_keepalive: bool = False,
        tcp_keepalive_idle: int | None = None,
        tcp_keepalive_interval: int | None = None,
        tcp_keepalive_count: int | None = None,
        **connection_config,
    ) -> Any:
        """Create a Redshift connection using the configured driver."""
        params = self._get_connection_params(**connection_config)
        if redshift_connector:
            connect_kwargs: dict[str, Any] = {
                "host": params["host"],
                "port": params["port"],
                "database": params["database"],
                "user": params["user"],
                "password": params["password"],
                "ssl": self.ssl_enabled,
                "application_name": application_name,
            }
            if connect_timeout is not None:
                connect_kwargs["timeout"] = connect_timeout
            # redshift_connector only implements certificate verification modes explicitly.
            if self.ssl_enabled and params["sslmode"] in {"verify-ca", "verify-full"}:
                connect_kwargs["sslmode"] = params["sslmode"]
            if self.ssl_insecure:
                connect_kwargs["ssl_insecure"] = True
            if tcp_keepalive:
                connect_kwargs["tcp_keepalive"] = True
            if tcp_keepalive_idle is not None:
                connect_kwargs["tcp_keepalive_idle"] = tcp_keepalive_idle
            if tcp_keepalive_interval is not None:
                connect_kwargs["tcp_keepalive_interval"] = tcp_keepalive_interval
            if tcp_keepalive_count is not None:
                connect_kwargs["tcp_keepalive_count"] = tcp_keepalive_count
            return redshift_connector.connect(**connect_kwargs)

        # psycopg2 supports keepalives/keepalives_idle/keepalives_interval/keepalives_count,
        # but they are not passed here; the tcp_keepalive parameters are silently
        # ignored on the psycopg2 path.
        connect_kwargs = {
            "host": params["host"],
            "port": params["port"],
            "database": params["database"],
            "user": params["user"],
            "password": params["password"],
            "sslmode": params["sslmode"],
            "application_name": application_name,
        }
        if connect_timeout is not None:
            connect_kwargs["connect_timeout"] = connect_timeout
        if self.sslrootcert:
            connect_kwargs["sslrootcert"] = self.sslrootcert
        return psycopg2.connect(**connect_kwargs)

    def _resolve_connect_timeout(self) -> int:
        """Determine an appropriate connection timeout based on cluster state.

        For provisioned clusters, queries the AWS API to detect paused/resuming
        status and extends the timeout accordingly. For serverless clusters, the
        compute warmth cannot be queried via API (the workgroup status is always
        AVAILABLE even when compute is cold), so a fixed extended timeout is used
        to cover cold-start latency (~15-30s for Serverless).

        The result is cached on the instance so the API is only hit once per adapter.

        Returns:
            Connection timeout in seconds.
        """
        if hasattr(self, "_cached_connect_timeout"):
            return self._cached_connect_timeout
        timeout = self._compute_connect_timeout()
        self._cached_connect_timeout = timeout
        return timeout

    def _compute_connect_timeout(self) -> int:
        """Internal: compute the timeout by checking cluster state via the AWS API."""
        if not boto3 or not self.host:
            return self.connect_timeout

        deployment_type = self.deployment_type
        region = self._extract_region_from_hostname(self.host, deployment_type)
        identifier = self._extract_identifier_from_hostname(self.host, deployment_type)
        if not region or not identifier:
            return self.connect_timeout

        if deployment_type == "provisioned":
            try:
                client = boto3.client("redshift", region_name=region)
                response = client.describe_clusters(ClusterIdentifier=identifier)
                clusters = response.get("Clusters", [])
                if clusters:
                    status = clusters[0].get("ClusterStatus", "unknown").lower()
                    if status == "available":
                        return self.connect_timeout
                    elif status in ("resuming", "rebooting"):
                        self.log_verbose(f"Provisioned cluster is {status} - using extended timeout (120s)")
                        return max(self.connect_timeout, 120)
                    elif status == "paused":
                        self.logger.warning(
                            f"Redshift cluster '{identifier}' is paused. "
                            "Resume it via the AWS console or: "
                            f"aws redshift resume-cluster --cluster-identifier {identifier}"
                        )
                        return max(self.connect_timeout, 120)
            except Exception as e:
                self.logger.debug(f"Could not check provisioned cluster status: {e}")
        elif deployment_type == "serverless":
            # Serverless compute warmth is not queryable via the AWS API.
            # The workgroup is always AVAILABLE even when compute is cold (auto-paused).
            # Use an extended timeout to accommodate cold-start latency.
            self.logger.debug("Serverless cluster: using extended timeout (60s) to cover potential cold-start")
            return max(self.connect_timeout, 60)

        return self.connect_timeout

    def _long_running_timeout(self, floor: int = 300) -> int:
        """Timeout for long-running DDL operations (DROP DATABASE, VACUUM, ANALYZE).

        Returns the adaptive timeout with a minimum floor, since these operations
        routinely exceed normal connection timeouts. Uses the adaptive timeout as
        a base so that cluster-state awareness (paused, resuming, cold-start) is
        preserved.
        """
        return max(self._resolve_connect_timeout(), floor)

    def _create_admin_connection(self, **connection_config) -> Any:
        """Create a Redshift connection for admin operations.

        Admin operations (CREATE DATABASE, DROP DATABASE, checking database
        existence) require connecting to an existing database. This uses
        self.admin_database (default: "dev") instead of the target database
        to avoid circular dependencies.
        """
        connect_timeout = connection_config.pop("connect_timeout", self._resolve_connect_timeout())
        admin_config = connection_config.copy()
        admin_config["database"] = self.admin_database
        return self._connect_with_driver(
            application_name="BenchBox-Admin",
            connect_timeout=connect_timeout,
            **admin_config,
        )

    def _create_direct_connection(self, **connection_config) -> Any:
        """Create a direct connection to the target database for validation.

        Connects directly to the specified database without:
        - Calling handle_existing_database()
        - Creating the database if missing
        - Setting the database_was_reused flag

        Used by the validation framework to check existing database compatibility.

        Args:
            **connection_config: Connection configuration including database name

        Returns:
            Database connection object

        Raises:
            Exception: If the connection fails (database doesn't exist, auth fails, etc.)
        """
        connect_timeout = connection_config.pop("connect_timeout", self._resolve_connect_timeout())
        connection = self._connect_with_driver(
            application_name="BenchBox-Validation",
            connect_timeout=connect_timeout,
            **connection_config,
        )

        # Apply WLM queue settings if configured
        if self.wlm_query_queue_name:
            cursor = connection.cursor()
            try:
                # Escape single quotes in the queue name for SQL safety
                queue_name_escaped = self.wlm_query_queue_name.replace("'", "''")
                cursor.execute(f"SET query_group TO '{queue_name_escaped}'")
            finally:
                cursor.close()
        return connection
    def check_server_database_exists(self, **connection_config) -> bool:
        """Check whether a database exists in the Redshift cluster.

        Connects to the admin database to query pg_database for the target database.
        """
        connection = None
        try:
            # Connect to the admin database (not the target database)
            connection = self._create_admin_connection()
            cursor = connection.cursor()
            database = connection_config.get("database", self.database)

            # Check whether the database exists
            cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (database,))
            result = cursor.fetchone()
            return result is not None
        except Exception:
            # If we can't connect or check, assume the database doesn't exist
            return False
        finally:
            if connection:
                connection.close()
    def drop_database(self, **connection_config) -> None:
        """Drop a database in the Redshift cluster.

        Connects to the admin database to drop the target database.

        Note: DROP DATABASE must run with autocommit enabled.
        Note: Redshift doesn't support IF EXISTS for DROP DATABASE, so we check first.
        Note: If DROP DATABASE fails with SQLSTATE 55006 (database still has active
        connections), the method terminates backends again and retries on a *fresh*
        connection. redshift_connector v2.1.x enters an aborted transaction state
        after a failed DDL even with autocommit=True, causing any subsequent DDL
        on the same connection to fail with error 25001. Opening a new connection
        guarantees clean driver state.
        """
        database = connection_config.get("database", self.database)

        # Check that the database exists first (Redshift doesn't support IF EXISTS)
        if not self.check_server_database_exists(database=database):
            self.log_verbose(f"Database {database} does not exist - nothing to drop")
            return

        try:
            # Connect to the admin database with an extended timeout - DROP DATABASE
            # can take 30-120+ seconds on Redshift depending on cluster state.
            connection = self._create_admin_connection(connect_timeout=self._long_running_timeout())
            connection.autocommit = True  # Enable autocommit for DROP DATABASE

            # Terminate any existing connections first to avoid "being accessed" errors
            # and to speed up the DROP. This is safe because we only kill sessions
            # attached to the *target* database, not the admin database we're on.
            cursor = connection.cursor()
            self.log_verbose(f"Terminating existing connections to {database}...")
            cursor.execute(
                """
                SELECT pg_terminate_backend(procpid)
                FROM pg_stat_activity
                WHERE datname = %s AND procpid <> pg_backend_pid()
                """,
                (database,),
            )

            # Drop the database (quote the identifier for SQL safety)
            try:
                cursor.execute(f'DROP DATABASE "{database}"')
            except Exception as drop_error:
                # Re-raise if not SQLSTATE 55006 (object_in_use / database has active
                # connections). Check the structured error code first (redshift_connector
                # stores the server dict in args[0] with key 'C'; psycopg2 exposes it as
                # pgcode). Fall back to substring matching only for unknown exception
                # types so future driver versions don't silently skip the retry.
                sqlstate = None
                first_arg = drop_error.args[0] if drop_error.args else None
                if isinstance(first_arg, dict):
                    sqlstate = first_arg.get("C")  # redshift_connector wire-protocol dict
                if sqlstate is None:
                    sqlstate = getattr(drop_error, "pgcode", None)  # psycopg2
                if sqlstate is not None:
                    if sqlstate != "55006":
                        raise
                else:
                    # Unknown driver - fall back to message text
                    error_msg = str(drop_error).lower()
                    if "active connection" not in error_msg and "being accessed" not in error_msg:
                        raise

                # Retry with a fresh connection - redshift_connector can enter an aborted
                # transaction state after a failed DDL, causing subsequent DDL to fail with
                # error 25001 even when autocommit=True. A fresh connection guarantees
                # clean state.
                self.log_verbose("Database still has active connections after terminate, retrying...")
                cursor.close()
                try:
                    connection.close()
                except Exception:
                    pass  # best-effort; a fresh connection is the goal
                connection = self._create_admin_connection(connect_timeout=self._long_running_timeout())
                connection.autocommit = True
                cursor = connection.cursor()
                cursor.execute(
                    """
                    SELECT pg_terminate_backend(procpid)
                    FROM pg_stat_activity
                    WHERE datname = %s AND procpid <> pg_backend_pid()
                    """,
                    (database,),
                )
                cursor.execute(f'DROP DATABASE "{database}"')
        except Exception as e:
            raise RuntimeError(f"Failed to drop Redshift database {database}: {e}") from e
        finally:
            if "cursor" in locals():
                try:
                    cursor.close()
                except Exception:
                    pass
            if "connection" in locals() and connection:
                try:
                    connection.close()
                except Exception:
                    pass
    def create_connection(self, **connection_config) -> Any:
        """Create an optimized Redshift connection."""
        self.log_operation_start("Redshift connection")

        # Handle an existing database using the base class method
        self.handle_existing_database(**connection_config)

        # Get connection parameters
        params = self._get_connection_params(**connection_config)
        target_database = params.get("database")

        # Create the database if needed (before connecting to it).
        # Redshift requires connecting to an admin database to create new databases.
        if not self.database_was_reused:
            # Check whether the target database exists
            database_exists = self.check_server_database_exists(database=target_database)
            if not database_exists:
                self.log_verbose(f"Creating database: {target_database}")
                # Create the database using an admin connection (connects to self.admin_database).
                # Note: CREATE DATABASE must run with autocommit enabled (cannot run
                # in a transaction block).
                try:
                    admin_conn = self._create_admin_connection()
                    admin_conn.autocommit = True  # Enable autocommit for CREATE DATABASE
                    admin_cursor = admin_conn.cursor()
                    try:
                        # Quote the identifier for SQL safety
                        admin_cursor.execute(f'CREATE DATABASE "{target_database}"')
                        # No commit() needed - autocommit handles it automatically
                        self.logger.info(f"Created database {target_database}")
                    finally:
                        admin_cursor.close()
                        admin_conn.close()
                except Exception as e:
                    self.logger.error(f"Failed to create database {target_database}: {e}")
                    raise

        self.log_very_verbose(f"Redshift connection params: host={params.get('host')}, database={target_database}")

        try:
            connect_timeout = connection_config.get("connect_timeout", self._resolve_connect_timeout())
            connect_config = {k: v for k, v in connection_config.items() if k != "connect_timeout"}
            connection = self._connect_with_driver(
                application_name="BenchBox",
                connect_timeout=connect_timeout,
                tcp_keepalive=True,
                tcp_keepalive_idle=600,
                tcp_keepalive_interval=30,
                tcp_keepalive_count=3,
                **connect_config,
            )

            # Enable autocommit immediately after connection creation (before any SQL operations)
            connection.autocommit = True

            # Apply WLM settings and the schema search path
            cursor = connection.cursor()

            # Integer settings validated in __init__, safe to interpolate
            if self.wlm_query_slot_count > 1:
                cursor.execute(f"SET wlm_query_slot_count = {int(self.wlm_query_slot_count)}")
            if self.statement_timeout > 0:
                cursor.execute(f"SET statement_timeout = {int(self.statement_timeout)}")

            # Set search_path so all unqualified table references use the correct schema.
            # Critical for database reuse when the schema already exists but the
            # connection is new. Quote the identifier for SQL safety.
            cursor.execute(f'SET search_path TO "{self.schema}"')

            # Test the connection
            cursor.execute("SELECT version()")
            cursor.fetchone()
            cursor.close()

            self.logger.info(f"Connected to Redshift cluster at {params['host']}:{params['port']}")
            self.log_operation_complete(
                "Redshift connection", details=f"Connected to {params['host']}:{params['port']}"
            )
            return connection
        except Exception as e:
            self.logger.error(f"Failed to connect to Redshift: {e}")
            raise
    def create_schema(self, benchmark, connection: Any) -> float:
        """Create the schema using Redshift-optimized table definitions."""
        start_time = mono_time()
        cursor = connection.cursor()
        try:
            # Create the schema if needed (if not using the default "public").
            # Quote identifiers for SQL safety.
            if self.schema and self.schema.lower() != "public":
                self.log_verbose(f"Creating schema: {self.schema}")
                cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.schema}"')
                self.logger.info(f"Created schema {self.schema}")

            # Set search_path to use the correct schema
            self.log_very_verbose(f"Setting search_path to: {self.schema}")
            cursor.execute(f'SET search_path TO "{self.schema}"')

            # Use the common schema creation helper
            schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")

            # Split the schema into individual statements and execute them
            statements = [stmt.strip() for stmt in schema_sql.split(";") if stmt.strip()]
            for statement in statements:
                # Normalize table names to lowercase for Redshift consistency.
                # This ensures CREATE, COPY, and SELECT all use the same case.
                statement = self._normalize_table_name_in_sql(statement)

                # Ensure idempotency with DROP TABLE IF EXISTS
                # (Redshift doesn't support CREATE OR REPLACE TABLE)
                if statement.upper().startswith("CREATE TABLE"):
                    # Extract the table name from the CREATE TABLE statement
                    table_name = self._extract_table_name(statement)
                    if table_name:
                        # Ensure the table name is lowercase
                        table_name_lower = table_name.strip('"').lower()
                        drop_statement = f"DROP TABLE IF EXISTS {table_name_lower}"
                        cursor.execute(drop_statement)
                        self.logger.debug(f"Executed: {drop_statement}")

                    # Optimize the table definition for Redshift
                    statement = self._optimize_table_definition(statement)

                cursor.execute(statement)
                self.logger.debug(f"Executed schema statement: {statement[:100]}...")

            self.logger.info("Schema created")
        except Exception as e:
            self.logger.error(f"Schema creation failed: {e}")
            raise
        finally:
            cursor.close()
        return elapsed_seconds(start_time)
    def _resolve_data_files(self, benchmark, data_dir: Path) -> dict:
        """Resolve data files from benchmark tables or the manifest fallback."""
        resolver = DataSourceResolver(
            platform_name=self.platform_name,
            table_mode=getattr(self, "table_mode", "native"),
            platform_config=getattr(self, "__dict__", None),
        )
        data_source = resolver.resolve(benchmark, data_dir)

        if self._should_use_manifest_selected_files(data_source):
            manifest_source = next((p for p in resolver.providers if isinstance(p, ManifestFileSource)), None)
            if manifest_source is not None:
                manifest_data_source = manifest_source.get_data_source(benchmark, data_dir)
                if manifest_data_source and manifest_data_source.tables:
                    self.log_verbose(
                        "Redshift native mode detected directory-valued benchmark tables; "
                        "using manifest-selected native files instead"
                    )
                    data_source = manifest_data_source

        if not data_source or not data_source.tables:
            raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")
        return data_source.tables

    def _should_use_manifest_selected_files(self, data_source: Any) -> bool:
        """Use manifest-selected files when native benchmark tables only point at directories."""
        if getattr(self, "table_mode", "native") == "external":
            return False
        if not data_source or getattr(data_source, "source_type", None) not in {
            "benchmark_tables",
            "benchmark_impl_tables",
        }:
            return False
        for table_paths in data_source.tables.values():
            normalized_paths = table_paths if isinstance(table_paths, list) else [table_paths]
            for path_like in normalized_paths:
                path = Path(path_like)
                if path.exists() and path.is_dir():
                    return True
        return False

    def _create_s3_client(self):
        """Create an S3 client with explicit error handling."""
        try:
            # XOR check: reject partial credentials - both key ID and secret must be
            # provided together, or both omitted (to fall back to environment/IAM).
            if bool(self.aws_access_key_id) != bool(self.aws_secret_access_key):
                raise ValueError(
                    "Explicit S3 upload credentials require both aws_access_key_id and aws_secret_access_key."
                )
            session = boto3.Session(
                aws_access_key_id=self.aws_access_key_id,
                aws_secret_access_key=self.aws_secret_access_key,
                aws_session_token=self.aws_session_token,
                region_name=self.aws_region,
            )
            if session.get_credentials() is None:
                raise ValueError(
                    "AWS credentials not found. Configure credentials via:\n"
                    "  1. Explicit S3 upload credentials in config\n"
                    "  2. AWS CLI (aws configure)\n"
                    "  3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
                )
            # request_checksum_calculation="when_required" prevents botocore from wrapping
            # upload streams in AwsChunkedWrapper (introduced in botocore 1.35+). Without this,
            # retried uploads fail with "Need to rewind the stream... but stream is not seekable".
            return session.client(
                "s3",
                config=BotocoreConfig(request_checksum_calculation="when_required"),
            )
        except NoCredentialsError as e:
            raise ValueError(
                "AWS credentials not found. Configure credentials via:\n"
                "  1. Explicit credentials (aws_access_key_id/aws_secret_access_key in config)\n"
                "  2. AWS CLI (aws configure)\n"
                "  3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
            ) from e
        except ClientError as e:
            raise ValueError(f"Failed to create AWS S3 client: {e}") from e

    def _upload_file_to_s3(self, s3_client, file_path: Path, table_name: str, file_idx: int) -> str:
        """Upload a single file to S3 and return the S3 URI."""
        # Preserve the full multi-part suffix for chunked files (e.g., .tbl.1.zst)
        file_stem = file_path.stem
        original_suffix = file_path.suffix
        if "." in file_stem:
            parts = file_path.name.split(".", 1)
            full_suffix = "." + parts[1] if len(parts) > 1 else original_suffix
        else:
            full_suffix = original_suffix
        s3_key = f"{self.s3_prefix}/{table_name}_{file_idx}{full_suffix}"
        try:
            s3_client.upload_file(str(file_path), self.s3_bucket, s3_key)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            if error_code == "NoSuchBucket":
                raise ValueError(
                    f"S3 bucket '{self.s3_bucket}' does not exist. Create the bucket or update your configuration."
                ) from e
            elif error_code == "AccessDenied":
                raise ValueError(
                    f"Access denied to S3 bucket '{self.s3_bucket}'. Check IAM permissions for s3:PutObject."
                ) from e
            else:
                raise ValueError(
                    f"Failed to upload {file_path.name} to s3://{self.s3_bucket}/{s3_key}: {error_code} - {e}"
                ) from e
        return f"s3://{self.s3_bucket}/{s3_key}"

    def _build_s3_copy_source(self, s3_client, s3_uris: list[str], table_name: str) -> tuple[str, str]:
        """Build the COPY source path and manifest option from S3 URIs.

        Returns:
            Tuple of (copy_from_path, manifest_option)
        """
        if len(s3_uris) > 1:
            manifest = {"entries": [{"url": uri, "mandatory": True} for uri in s3_uris]}
            manifest_key = f"{self.s3_prefix}/{table_name}_manifest.json"
            try:
                s3_client.put_object(
                    Bucket=self.s3_bucket,
                    Key=manifest_key,
                    Body=json.dumps(manifest),
                )
            except ClientError as e:
                error_code = e.response.get("Error", {}).get("Code", "Unknown")
                raise ValueError(
                    f"Failed to upload manifest file to s3://{self.s3_bucket}/{manifest_key}: {error_code} - {e}"
                ) from e
            return f"s3://{self.s3_bucket}/{manifest_key}", "manifest"
        else:
            return s3_uris[0], ""

    @staticmethod
    def _sanitize_copy_credential(value: str, name: str) -> str:
        """Reject credential values containing single quotes to prevent COPY SQL injection."""
        if "'" in value:
            raise ValueError(
                f"Credential '{name}' contains a single-quote character, which is not allowed in COPY SQL."
            )
        return value

    def _get_copy_credentials_clause(self) -> str:
        """Build the credentials clause for the Redshift COPY command."""
        if self.iam_role:
            return f"IAM_ROLE '{self._sanitize_copy_credential(self.iam_role, 'iam_role')}'"
        elif self.aws_access_key_id and self.aws_secret_access_key:
            key_id = self._sanitize_copy_credential(self.aws_access_key_id, "aws_access_key_id")
            secret = self._sanitize_copy_credential(self.aws_secret_access_key, "aws_secret_access_key")
            if self.aws_session_token:
                token = self._sanitize_copy_credential(self.aws_session_token, "aws_session_token")
                return f"ACCESS_KEY_ID '{key_id}' SECRET_ACCESS_KEY '{secret}' SESSION_TOKEN '{token}'"
            return f"ACCESS_KEY_ID '{key_id}' SECRET_ACCESS_KEY '{secret}'"
        else:
            self.log_verbose("No explicit credentials configured for COPY; using cluster default IAM role")
            return ""

    def _load_table_via_s3(self, cursor, s3_client, table_name: str, valid_files: list[Path], connection: Any) -> int:
        """Upload files to S3 and load a single table via the COPY command. Returns the row count."""
        table_name_lower = table_name.lower()

        parquet_modes = {is_parquet_format(file_path) for file_path in valid_files}
        if len(parquet_modes) > 1:
            raise ValueError(
                f"Redshift COPY requires a uniform source format per table; received mixed parquet and delimited "
                f"files for '{table_name_lower}'."
            )
        is_parquet_copy = parquet_modes == {True}

        # Upload all files to S3 and collect the S3 URIs
        s3_uris = []
        for file_idx, file_path in enumerate(valid_files):
            s3_uris.append(self._upload_file_to_s3(s3_client, file_path, table_name, file_idx))

        copy_from_path, manifest_option = self._build_s3_copy_source(s3_client, s3_uris, table_name)

        # Detect the compression format from file extensions
        compressions = {detect_compression(f) for f in valid_files}
        if "zstd" in compressions:
            compression_option = "ZSTD"
        elif "gzip" in compressions:
            compression_option = "GZIP"
        else:
            compression_option = ""

        credentials_clause = self._get_copy_credentials_clause()

        # Build and execute the COPY command
        qualified_table = f"{self.schema}.{table_name_lower}"
        if is_parquet_copy:
            copy_sql = f"""
                COPY {qualified_table}
                FROM '{copy_from_path}'
                {credentials_clause}
                {manifest_option}
                FORMAT AS PARQUET
            """
        else:
            delimiter = get_delimiter_for_file(valid_files[0])
            copy_sql = f"""
                COPY {qualified_table}
                FROM '{copy_from_path}'
                {credentials_clause}
                {manifest_option}
                {compression_option}
                DELIMITER '{delimiter}'
                IGNOREHEADER 0
                COMPUPDATE {self.compupdate}
            """
        cursor.execute(copy_sql)

        effective_tuning = self.get_effective_tuning_configuration()
        if effective_tuning is not None:
            self.apply_ctas_sort(table_name_lower, effective_tuning, connection)

        # Get the row count
        cursor.execute(f"SELECT COUNT(*) FROM {qualified_table}")
        row_count = cursor.fetchone()[0]

        # Run ANALYZE if configured
        if self.auto_analyze:
            cursor.execute(f"ANALYZE {qualified_table}")

        return row_count

    def _load_table_via_insert(self, cursor, table_name: str, valid_files: list[Path], connection: Any) -> int:
        """Load a single table via INSERT statements. Returns total rows loaded."""
        table_name_lower = table_name.lower()
        total_rows_loaded = 0

        for file_idx, file_path in enumerate(valid_files):
            chunk_info = f" (chunk {file_idx + 1}/{len(valid_files)})" if len(valid_files) > 1 else ""
            self.log_very_verbose(f"Loading {table_name}{chunk_info} from {file_path.name}")

            delimiter = get_delimiter_for_file(file_path)
            compression_handler = FileFormatRegistry.get_compression_handler(file_path)
            with compression_handler.open(file_path) as f:
                rows_loaded = 0
                batch_size = 1000
                batch_data = []
                for line in f:
                    line = line.strip()
                    if line and line.endswith(delimiter):
                        line = line[:-1]
                    values = line.split(delimiter)
                    escaped_values = ["'" + str(v).replace("'", "''") + "'" for v in values]
                    batch_data.append(f"({', '.join(escaped_values)})")
                    if len(batch_data) >= batch_size:
                        insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
                        cursor.execute(insert_sql)
                        rows_loaded += len(batch_data)
                        total_rows_loaded += len(batch_data)
                        batch_data = []

                # Insert the remaining batch
                if batch_data:
                    insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
                    cursor.execute(insert_sql)
                    rows_loaded += len(batch_data)
                    total_rows_loaded += len(batch_data)

            self.log_very_verbose(f"Loaded {rows_loaded:,} rows from {file_path.name}")

        effective_tuning = self.get_effective_tuning_configuration()
        if effective_tuning is not None:
            self.apply_ctas_sort(table_name_lower, effective_tuning, connection)

        return total_rows_loaded

    def _filter_valid_files(self, file_paths) -> list[Path]:
        """Normalize to a list and filter out non-existent or empty files."""
        if not isinstance(file_paths, list):
            file_paths = [file_paths]
        valid_files = []
        for file_path in file_paths:
            file_path = Path(file_path)
            if file_path.exists() and file_path.stat().st_size > 0:
                valid_files.append(file_path)
        return valid_files
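    # For reference, the delimited branch of _load_table_via_s3 above renders
    # COPY SQL shaped like the following (bucket, role, and table names are
    # illustrative placeholders, not values from a real run):
    #
    #     COPY public.lineitem
    #     FROM 's3://example-bucket/benchbox-data/lineitem_manifest.json'
    #     IAM_ROLE 'arn:aws:iam::123456789012:role/ExampleCopyRole'
    #     manifest
    #     ZSTD
    #     DELIMITER '|'
    #     IGNOREHEADER 0
    #     COMPUPDATE PRESET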
    def validate_external_table_requirements(self) -> None:
        """Validate prerequisites for Redshift external table mode."""
        if not self.s3_bucket:
            raise ValueError(
                "Redshift external mode requires S3 staging (set --platform-option s3_bucket=<bucket> "
                "or provide --platform-option staging_root=s3://bucket/path)."
            )
        if not self.iam_role:
            raise ValueError(
                "Redshift external mode requires IAM role credentials for Spectrum "
                "(set --platform-option iam_role=<arn>)."
            )
    @staticmethod
    def _map_external_column_type(column_type: str) -> str:
        """Map benchmark schema types to Redshift Spectrum-compatible types."""
        normalized = str(column_type).strip().upper()
        if not normalized:
            return "VARCHAR(65535)"
        if normalized.startswith(("DECIMAL", "NUMERIC", "VARCHAR", "CHAR", "TIMESTAMP")):
            return normalized
        if "BIGINT" in normalized:
            return "BIGINT"
        if "SMALLINT" in normalized:
            return "SMALLINT"
        if "INT" in normalized:
            return "INTEGER"
        if "DOUBLE" in normalized:
            return "DOUBLE PRECISION"
        if "REAL" in normalized or "FLOAT" in normalized:
            return "REAL"
        if "DATE" in normalized:
            return "DATE"
        if "BOOLEAN" in normalized or normalized == "BOOL":
            return "BOOLEAN"
        return "VARCHAR(65535)"

    def _build_external_column_definitions(self, benchmark: Any, table_name: str) -> str:
        """Build external table column definitions from the benchmark schema."""
        if not hasattr(benchmark, "get_schema"):
            raise ValueError(
                f"Benchmark schema metadata unavailable for '{table_name}'. "
                "Redshift external mode requires benchmark.get_schema()."
            )
        schema = benchmark.get_schema() or {}
        table_schema = schema.get(table_name) or schema.get(table_name.lower()) or schema.get(table_name.upper())
        if not table_schema:
            raise ValueError(f"No schema definition found for table '{table_name}'.")
        columns = table_schema.get("columns", [])
        if not columns:
            raise ValueError(f"No columns found in schema definition for table '{table_name}'.")
        column_defs = []
        for column in columns:
            column_name = str(column.get("name", "")).strip().lower()
            if not column_name:
                continue
            mapped_type = self._map_external_column_type(str(column.get("type", "")))
            column_defs.append(f"{column_name} {mapped_type}")
        if not column_defs:
            raise ValueError(f"No valid column definitions were generated for table '{table_name}'.")
        return ", ".join(column_defs)

    def _upload_external_parquet_files_to_s3(self, s3_client: Any, table_name: str, parquet_files: list[Path]) -> str:
        """Upload table Parquet files and return the Spectrum LOCATION prefix."""
        table_name_lower = table_name.lower()
        s3_prefix = f"{self.s3_prefix}/{self.database.lower()}_external/{table_name_lower}"
        for file_path in parquet_files:
            s3_key = f"{s3_prefix}/{file_path.name}"
            s3_client.upload_file(str(file_path), self.s3_bucket, s3_key)
        return f"s3://{self.s3_bucket}/{s3_prefix}/"

    def _upload_external_directory_to_s3(self, s3_client: Any, table_name: str, directory: Path) -> str:
        """Upload a directory tree for external Delta-style registration."""
        table_name_lower = table_name.lower()
        s3_prefix = f"{self.s3_prefix}/{self.database.lower()}_external/{table_name_lower}"
        for file_path in directory.rglob("*"):
            if not file_path.is_file():
                continue
            relative = file_path.relative_to(directory).as_posix()
            s3_client.upload_file(str(file_path), self.s3_bucket, f"{s3_prefix}/{relative}")
        return f"s3://{self.s3_bucket}/{s3_prefix}/"
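    # Example mappings produced by _map_external_column_type above (derived
    # from its rules):
    #
    #     "DECIMAL(12,2)" -> "DECIMAL(12,2)"   # prefix match: passed through
    #     "BIGINT"        -> "BIGINT"
    #     "DOUBLE"        -> "DOUBLE PRECISION"
    #     "FLOAT"         -> "REAL"
    #     ""              -> "VARCHAR(65535)"  # fallback for unknown/empty types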
    def create_external_tables(
        self, benchmark: Any, connection: Any, data_dir: Path
    ) -> tuple[dict[str, int], float, dict[str, Any] | None]:
        """Upload external-table sources to S3 and register Redshift Spectrum tables."""
        self.validate_external_table_requirements()
        assert self.iam_role is not None

        start_time = mono_time()
        table_stats: dict[str, int] = {}

        # Derive a unique external schema name; warn if the base name already
        # looks like a previous external-mode artifact to avoid cascading suffixes.
        if self.schema.endswith("_external"):
            self.logger.warning(
                f"Schema '{self.schema}' already ends with '_external'. "
                f"The derived external schema will be '{self.schema}_external'."
            )
        external_schema = f"{self.schema}_external"
        escaped_iam_role = self.iam_role.replace("'", "''")
        escaped_db = self.database.lower().replace("'", "''")

        cursor = connection.cursor()
        try:
            data_files = self._resolve_data_files(benchmark, data_dir)
            s3_client = self._create_s3_client()

            cursor.execute(
                f"""
                CREATE EXTERNAL SCHEMA IF NOT EXISTS {external_schema}
                FROM DATA CATALOG
                DATABASE '{escaped_db}_external'
                IAM_ROLE '{escaped_iam_role}'
                CREATE EXTERNAL DATABASE IF NOT EXISTS
                """
            )

            for table_name, file_paths in data_files.items():
                table_name_lower = table_name.lower()
                valid_files = self._filter_valid_files(file_paths)
                delta_dirs = [path for path in valid_files if path.is_dir() and (path / "_delta_log").is_dir()]
                parquet_files = [path for path in valid_files if path.suffix.lower() == ".parquet"]

                if delta_dirs:
                    location = self._upload_external_directory_to_s3(s3_client, table_name_lower, delta_dirs[0])
                    source_format = "delta"
                elif parquet_files:
                    location = self._upload_external_parquet_files_to_s3(s3_client, table_name_lower, parquet_files)
                    source_format = "parquet"
                else:
                    raise ValueError(
                        f"Redshift external mode requires Parquet files or Delta directories for table "
                        f"'{table_name_lower}'. No supported sources were found."
                    )

                column_defs = self._build_external_column_definitions(benchmark, table_name_lower)
                cursor.execute(f"DROP TABLE IF EXISTS {external_schema}.{table_name_lower}")
                if source_format == "delta":
                    cursor.execute(
                        f"""
                        CREATE EXTERNAL TABLE {external_schema}.{table_name_lower} ({column_defs})
                        STORED AS PARQUET
                        LOCATION '{location}'
                        TABLE PROPERTIES ('table_type'='DELTA')
                        """
                    )
                else:
                    cursor.execute(
                        f"""
                        CREATE EXTERNAL TABLE {external_schema}.{table_name_lower} ({column_defs})
                        STORED AS PARQUET
                        LOCATION '{location}'
                        """
                    )

                cursor.execute(f"SELECT COUNT(*) FROM {external_schema}.{table_name_lower}")
                result = cursor.fetchone()
                table_stats[table_name_lower] = int(result[0]) if result else 0
        finally:
            cursor.close()

        total_time = elapsed_seconds(start_time)
        return table_stats, total_time, None
    def load_data(
        self, benchmark, connection: Any, data_dir: Path
    ) -> tuple[dict[str, int], float, dict[str, Any] | None]:
        """Load data using the Redshift COPY command with S3 integration."""
        start_time = mono_time()
        table_stats = {}
        cursor = connection.cursor()
        try:
            data_files = self._resolve_data_files(benchmark, data_dir)

            # Upload files to S3 and load via the COPY command
            if self.s3_bucket and boto3:
                s3_client = self._create_s3_client()
                for table_name, file_paths in data_files.items():
                    valid_files = self._filter_valid_files(file_paths)
                    if not valid_files:
                        self.logger.warning(f"Skipping {table_name} - no valid data files")
                        table_stats[table_name.lower()] = 0
                        continue

                    chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                    self.log_verbose(f"Loading data for table: {table_name}{chunk_info}")
                    try:
                        load_start = mono_time()
                        row_count = self._load_table_via_s3(cursor, s3_client, table_name, valid_files, connection)
                        table_stats[table_name.lower()] = row_count
                        load_time = elapsed_seconds(load_start)
                        self.logger.info(
                            f"✅ Loaded {row_count:,} rows into {table_name.lower()}{chunk_info} in {load_time:.2f}s"
                        )
                    except Exception as e:
                        self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                        table_stats[table_name.lower()] = 0
            else:
                # Direct loading without S3 (less efficient)
                self.logger.warning("No S3 bucket configured, using direct INSERT loading")
                for table_name, file_paths in data_files.items():
                    valid_files = self._filter_valid_files(file_paths)
                    if not valid_files:
                        self.logger.warning(f"Skipping {table_name} - no valid data files")
                        table_stats[table_name.lower()] = 0
                        continue
                    try:
                        self.log_verbose(f"Direct loading data for table: {table_name}")
                        load_start = mono_time()
                        total_rows_loaded = self._load_table_via_insert(cursor, table_name, valid_files, connection)
                        table_stats[table_name.lower()] = total_rows_loaded
                        load_time = elapsed_seconds(load_start)
                        chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                        self.logger.info(
                            f"✅ Loaded {total_rows_loaded:,} rows into {table_name.lower()}{chunk_info} "
                            f"in {load_time:.2f}s"
                        )
                    except Exception as e:
                        self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                        table_stats[table_name.lower()] = 0

            total_time = elapsed_seconds(start_time)
            total_rows = sum(table_stats.values())
            self.logger.info(f"✅ Loaded {total_rows:,} total rows in {total_time:.2f}s")
        except Exception as e:
            self.logger.error(f"Data loading failed: {e}")
            raise
        finally:
            cursor.close()

        # Redshift doesn't provide detailed per-table timings yet
        return table_stats, total_time, None
[docs] def configure_for_benchmark(self, connection: Any, benchmark_type: str) -> None: """Apply Redshift-specific optimizations based on benchmark type.""" cursor = connection.cursor() try: # Set session-level optimizations # Use OFF for result cache to ensure accurate benchmark measurements cache_setting = "OFF" if self.disable_result_cache else "ON" optimization_settings = [ f"SET enable_result_cache_for_session TO {cache_setting}", "SET query_group TO 'benchbox'", "SET statement_timeout TO '1800000'", # 30 minutes ] if benchmark_type.lower() in ["olap", "analytics", "tpch", "tpcds"]: # OLAP-specific optimizations optimization_settings.extend( [ "SET enable_case_sensitive_identifier TO OFF", "SET datestyle TO 'ISO, MDY'", "SET extra_float_digits TO 0", ] ) critical_failures = [] for setting in optimization_settings: try: cursor.execute(setting) self.logger.debug(f"Applied setting: {setting}") except Exception as e: # Track if critical cache control setting failed if "enable_result_cache_for_session" in setting: critical_failures.append(setting) self.logger.warning(f"Failed to apply setting {setting}: {e}") # Validate cache control settings were successfully applied if self.disable_result_cache or critical_failures: self.logger.debug("Validating cache control settings...") validation_result = self.validate_session_cache_control(connection) if not validation_result["validated"]: self.logger.warning(f"Cache control validation failed: {validation_result.get('errors', [])}") else: self.logger.info( f"Cache control validated successfully: cache_disabled={validation_result['cache_disabled']}" ) # Run VACUUM and ANALYZE on all tables if configured. # These operations use a **separate connection** because they are # long-running DDL operations that can trigger socket-level timeouts # (e.g. on paused serverless clusters). A socket timeout permanently # breaks the underlying TCP connection — rollback() cannot recover it # because the socket itself is dead. By isolating VACUUM/ANALYZE on # their own connection, a timeout only destroys the disposable # connection while the main benchmark connection stays healthy. if self.auto_vacuum or self.auto_analyze: self._run_vacuum_analyze_isolated(connection) finally: cursor.close()
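A quick way to confirm the cache toggle took effect is the same session query the validation helper below relies on:

# Manual spot check of the session-level result cache setting.
cursor.execute("SELECT current_setting('enable_result_cache_for_session')")
print(cursor.fetchone()[0])  # expected 'off' when disable_result_cache is True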
def _run_vacuum_analyze_isolated(self, main_connection: Any) -> None: """Run VACUUM/ANALYZE on a dedicated connection to protect the main one. VACUUM and ANALYZE are long-running DDL operations that can trigger socket-level timeouts, especially on paused serverless clusters. A socket timeout permanently breaks the TCP connection, so these operations must be isolated from the benchmark connection. Args: main_connection: The main benchmark connection (used only to query the table list; VACUUM/ANALYZE run on a separate conn). """ # Query table list using the main connection (fast metadata query) main_cursor = main_connection.cursor() try: main_cursor.execute(f""" SELECT schemaname, tablename FROM pg_tables WHERE schemaname = '{self.schema}' """) tables = main_cursor.fetchall() finally: main_cursor.close() if not tables: return # Create a separate connection for VACUUM/ANALYZE. # The finally block is the single owner of cleanup — error paths # just return and let finally close the connection. maint_conn = None try: maint_conn = self._connect_with_driver( application_name="BenchBox-Maintenance", connect_timeout=self._long_running_timeout(), ) maint_conn.autocommit = True maint_cursor = maint_conn.cursor() # Set search_path to match the main connection maint_cursor.execute(f'SET search_path TO "{self.schema}"') for _schema, table in tables: ops = [] if self.auto_vacuum: ops.append(("VACUUM", f'VACUUM "{table}"')) if self.auto_analyze: ops.append(("ANALYZE", f'ANALYZE "{table}"')) for op_name, sql in ops: try: maint_cursor.execute(sql) except Exception as e: self.logger.warning(f"{op_name} failed for {table}: {e}") self.logger.warning( f"Maintenance connection lost during {op_name} — skipping remaining VACUUM/ANALYZE" ) return except Exception as e: self.logger.warning(f"Could not create maintenance connection for VACUUM/ANALYZE: {e}") finally: if maint_conn is not None: try: maint_conn.close() except Exception: pass
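Stripped of the table loop, the isolation pattern is: open a disposable autocommit connection, run the maintenance statement, and close it in a finally block so a socket timeout never touches the benchmark connection. A sketch (the table name is a placeholder):

# Minimal form of the isolate-and-dispose pattern used above.
maint_conn = None
try:
    maint_conn = self._connect_with_driver(application_name="BenchBox-Maintenance")
    maint_conn.autocommit = True  # VACUUM cannot run inside an open transaction block
    maint_conn.cursor().execute('VACUUM "lineitem"')
except Exception as e:
    self.logger.warning(f"Maintenance pass failed: {e}")  # main connection is unaffected
finally:
    if maint_conn is not None:
        maint_conn.close()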
[docs] def validate_session_cache_control(self, connection: Any) -> dict[str, Any]: """Validate that session-level cache control settings were successfully applied. Args: connection: Active Redshift database connection Returns: dict with: - validated: bool - Whether validation passed - cache_disabled: bool - Whether cache is actually disabled - settings: dict - Actual session settings - warnings: list[str] - Any validation warnings - errors: list[str] - Any validation errors Raises: ConfigurationError: If cache control validation fails and strict_validation=True """ from benchbox.platforms.cloud_shared import validate_session_cache_control return validate_session_cache_control( connection=connection, query="SELECT current_setting('enable_result_cache_for_session') as value", setting_key="enable_result_cache_for_session", disabled_value="off", enabled_value="on", normalize="lower", platform_name="Redshift", disable_result_cache=self.disable_result_cache, strict_validation=self.strict_validation, adapter_logger=self.logger, )
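A passing validation with the cache disabled returns a dictionary along these lines (values illustrative, keys as documented above):

validation = {
    "validated": True,
    "cache_disabled": True,
    "settings": {"enable_result_cache_for_session": "off"},
    "warnings": [],
    "errors": [],
}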
[docs] def execute_query( self, connection: Any, query: str, query_id: str, benchmark_type: str | None = None, scale_factor: float | None = None, validate_row_count: bool = True, stream_id: int | None = None, ) -> dict[str, Any]: """Execute query with detailed timing and performance tracking.""" start_time = mono_time() cursor = connection.cursor() try: # Execute the query # Note: Query dialect translation is now handled automatically by the base adapter cursor.execute(query) result = cursor.fetchall() execution_time = elapsed_seconds(start_time) actual_row_count = len(result) if result else 0 # Get query statistics try: query_stats = self._get_query_statistics(connection, query_id) # Add execution time for cost calculation query_stats["execution_time_seconds"] = execution_time except Exception: query_stats = {"execution_time_seconds": execution_time} # Validate row count if enabled and benchmark type is provided validation_result = None if validate_row_count and benchmark_type: from benchbox.core.validation.query_validation import QueryValidator validator = QueryValidator() validation_result = validator.validate_query_result( benchmark_type=benchmark_type, query_id=query_id, actual_row_count=actual_row_count, scale_factor=scale_factor, stream_id=stream_id, ) # Log validation result if validation_result.warning_message: self.log_verbose(f"Row count validation: {validation_result.warning_message}") elif not validation_result.is_valid: self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}") else: self.log_very_verbose( f"Row count validation PASSED: {actual_row_count} rows " f"(expected: {validation_result.expected_row_count})" ) # Use base helper to build result with consistent validation field mapping result_dict = self._build_query_result_with_validation( query_id=query_id, execution_time=execution_time, actual_row_count=actual_row_count, first_row=result[0] if result else None, validation_result=validation_result, ) # Include Redshift-specific fields result_dict["translated_query"] = None # Translation handled by base adapter result_dict["query_statistics"] = query_stats # Map query_statistics to resource_usage for cost calculation result_dict["resource_usage"] = query_stats return result_dict except Exception as e: execution_time = elapsed_seconds(start_time) return { "query_id": query_id, "status": "FAILED", "execution_time_seconds": execution_time, "rows_returned": 0, "error": str(e), "error_type": type(e).__name__, } finally: cursor.close()
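On success the base helper's result dictionary is extended with translated_query, query_statistics, and resource_usage; on failure the method returns the literal dictionary built in the except branch. An illustrative failure result (values are placeholders):

failure_result = {
    "query_id": "q1",
    "status": "FAILED",
    "execution_time_seconds": 1.42,
    "rows_returned": 0,
    "error": 'relation "lineitem" does not exist',
    "error_type": "ProgrammingError",
}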
def _extract_table_name(self, statement: str) -> str | None: """Extract table name from CREATE TABLE statement.""" from benchbox.core.sql_utils import extract_table_name return extract_table_name(statement) def _normalize_table_name_in_sql(self, sql: str) -> str: """Normalize table names in SQL to lowercase for Redshift.""" return normalize_table_name_in_sql(sql) def _optimize_table_definition(self, statement: str) -> str: """Optimize table definition for Redshift.""" if not statement.upper().startswith("CREATE TABLE"): return statement # Include distribution and sort keys for better performance if "DISTSTYLE" not in statement.upper(): # Include AUTO distribution style (Redshift will choose appropriate distribution) statement += " DISTSTYLE AUTO" if "SORTKEY" not in statement.upper(): # Include sort key on first column (simple heuristic) # In production, this would be more sophisticated statement += " SORTKEY AUTO" return statement def _get_platform_metadata(self, connection: Any) -> dict[str, Any]: """Get Redshift-specific metadata and system information.""" metadata = { "platform": self.platform_name, "host": self.host, "port": self.port, "database": self.database, "result_cache_enabled": not self.disable_result_cache, } cursor = connection.cursor() try: # Get Redshift version cursor.execute("SELECT version()") result = cursor.fetchone() metadata["redshift_version"] = result[0] if result else "unknown" # Get cluster information (STV tables unavailable on Serverless) try: if self.deployment_type == "serverless": cursor.execute(""" SELECT compute_capacity FROM sys_serverless_usage ORDER BY start_time DESC LIMIT 1 """) result = cursor.fetchone() if result: metadata["cluster_info"] = { "deployment_type": "serverless", "compute_capacity_rpu": result[0], } else: cursor.execute(""" SELECT node_type, num_nodes, cluster_version, publicly_accessible FROM stv_cluster_configuration LIMIT 1 """) result = cursor.fetchone() if result: metadata["cluster_info"] = { "deployment_type": "provisioned", "node_type": result[0], "num_nodes": result[1], "cluster_version": result[2], "publicly_accessible": result[3], } except Exception as e: self.logger.debug(f"Could not query cluster information: {e}") # Get current session information cursor.execute(""" SELECT current_user, current_database(), current_schema(), inet_client_addr(), inet_client_port() """) result = cursor.fetchone() if result: metadata["session_info"] = { "current_user": result[0], "current_database": result[1], "current_schema": result[2], "client_addr": result[3], "client_port": result[4], } # Get table information cursor.execute(f""" SELECT schemaname, tablename, tableowner, tablespace, hasindexes, hasrules, hastriggers FROM pg_tables WHERE schemaname = '{self.schema}' """) tables = cursor.fetchall() metadata["tables"] = [ { "schema": row[0], "table": row[1], "owner": row[2], "tablespace": row[3], "has_indexes": row[4], "has_rules": row[5], "has_triggers": row[6], } for row in tables ] except Exception as e: metadata["metadata_error"] = str(e) finally: cursor.close() return metadata def _get_existing_tables(self, connection: Any) -> list[str]: """Get list of existing tables from Redshift (normalized to lowercase). Queries information_schema for tables in the current schema and returns them as lowercase names to match Redshift's identifier normalization. 
Args: connection: Database connection Returns: List of table names (lowercase) """ cursor = connection.cursor() try: # Query information_schema for tables in current schema cursor.execute( """ SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_type = 'BASE TABLE' """, (self.schema,), ) # Return lowercase table names return [row[0].lower() for row in cursor.fetchall()] finally: cursor.close() def _get_query_statistics(self, connection: Any, query_id: str) -> dict[str, Any]: """Get query statistics from Redshift system tables. Note: This retrieves performance telemetry only - not used for validation. Query row count validation uses the actual result set from cursor.fetchall(). """ cursor = connection.cursor() try: # Get the actual Redshift query ID for the most recent query # pg_last_query_id() returns the query ID of the last executed query in this session cursor.execute("SELECT pg_last_query_id()") result = cursor.fetchone() if not result or result[0] == -1: # No queries executed yet, or query ran only on leader node return {} redshift_query_id = result[0] # Query system tables for statistics (STL unavailable on Serverless) if self.deployment_type == "serverless": cursor.execute( """ SELECT query_id, DATEDIFF('microseconds', start_time, end_time) as duration_microsecs, DATEDIFF('microseconds', start_time, end_time) as cpu_time_microsecs, 0 as bytes_scanned, 0 as bytes_returned, 1 as slots, 1 as wlm_slots, status FROM sys_query_history WHERE query_id = %s ORDER BY start_time DESC LIMIT 1 """, (redshift_query_id,), ) else: cursor.execute( """ SELECT query, DATEDIFF('microseconds', starttime, endtime) as duration_microsecs, DATEDIFF('microseconds', starttime, endtime) as cpu_time_microsecs, 0 as bytes_scanned, 0 as bytes_returned, 1 as slots, 1 as wlm_slots, aborted FROM stl_query WHERE query = %s ORDER BY starttime DESC LIMIT 1 """, (redshift_query_id,), ) result = cursor.fetchone() if result: # sys_query_history returns status string; stl_query returns aborted int raw_aborted = result[7] if isinstance(raw_aborted, str): aborted = raw_aborted not in ("success", "running") else: aborted = bool(raw_aborted) if raw_aborted is not None else False return { "query_id": str(result[0]), "duration_microsecs": result[1] or 0, "cpu_time_microsecs": result[2] or 0, "bytes_scanned": result[3] or 0, "bytes_returned": result[4] or 0, "slots": result[5] or 1, "wlm_slots": result[6] or 1, "aborted": aborted, } else: return {} except Exception: return {} finally: cursor.close()
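When a matching row is found in sys_query_history (serverless) or stl_query (provisioned), the statistics dictionary has a fixed shape; the byte and slot counts are placeholders (0 and 1) because neither query currently retrieves them. Illustrative output:

stats = {
    "query_id": "1234567",
    "duration_microsecs": 815_000,
    "cpu_time_microsecs": 815_000,
    "bytes_scanned": 0,
    "bytes_returned": 0,
    "slots": 1,
    "wlm_slots": 1,
    "aborted": False,
}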
[docs] def analyze_table(self, connection: Any, table_name: str) -> None: """Run ANALYZE on table for query optimization.""" cursor = connection.cursor() try: cursor.execute(f"ANALYZE {table_name.lower()}") except Exception as e: self.logger.warning(f"Failed to analyze table {table_name}: {e}") finally: cursor.close()
[docs] def vacuum_table(self, connection: Any, table_name: str) -> None: """Run VACUUM on table for space reclamation.""" cursor = connection.cursor() try: cursor.execute(f"VACUUM {table_name.lower()}") except Exception as e: self.logger.warning(f"Failed to vacuum table {table_name}: {e}") finally: cursor.close()
[docs] def get_query_plan(self, connection: Any, query: str) -> str: """Get query execution plan for analysis.""" cursor = connection.cursor() try: # Note: Query dialect translation is now handled automatically by the base adapter cursor.execute(f"EXPLAIN {query}") plan_rows = cursor.fetchall() return "\n".join([row[0] for row in plan_rows]) except Exception as e: return f"Could not get query plan: {e}" finally: cursor.close()
[docs] def close_connection(self, connection: Any) -> None: """Close Redshift connection.""" try: if connection and hasattr(connection, "close"): connection.close() except Exception as e: self.logger.warning(f"Error closing connection: {e}")
[docs] def supports_tuning_type(self, tuning_type) -> bool: """Check if Redshift supports a specific tuning type. Redshift supports: - DISTRIBUTION: Via DISTSTYLE and DISTKEY clauses - SORTING: Via SORTKEY clause (compound and interleaved) - PARTITIONING: Through table design patterns and date partitioning Args: tuning_type: The type of tuning to check support for Returns: True if the tuning type is supported by Redshift """ # Import here to avoid circular imports try: from benchbox.core.tuning.interface import TuningType return tuning_type in { TuningType.DISTRIBUTION, TuningType.SORTING, TuningType.PARTITIONING, } except ImportError: return False
[docs] def generate_tuning_clause(self, table_tuning) -> str: """Generate Redshift-specific tuning clauses for CREATE TABLE statements. Redshift supports: - DISTSTYLE (EVEN | KEY | ALL) DISTKEY (column) - SORTKEY (column1, column2, ...) or INTERLEAVED SORTKEY (column1, column2, ...) Args: table_tuning: The tuning configuration for the table Returns: SQL clause string to be appended to CREATE TABLE statement """ if not table_tuning or not table_tuning.has_any_tuning(): return "" clauses = [] try: # Import here to avoid circular imports from benchbox.core.tuning.interface import TuningType # Handle distribution strategy distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION) if distribution_columns: # Sort by order and use first column as distribution key sorted_cols = sorted(distribution_columns, key=lambda col: col.order) dist_col = sorted_cols[0] # Use KEY distribution style with the specified column clauses.append("DISTSTYLE KEY") clauses.append(f"DISTKEY ({dist_col.name})") else: # Default to EVEN distribution if no distribution columns specified clauses.append("DISTSTYLE EVEN") # Handle sorting sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING) if sort_columns: # Sort by order for sortkey sorted_cols = sorted(sort_columns, key=lambda col: col.order) column_names = [col.name for col in sorted_cols] # Use compound sort key by default (better for most OLAP workloads) # Could be made configurable to choose between COMPOUND and INTERLEAVED sortkey_clause = f"SORTKEY ({', '.join(column_names)})" clauses.append(sortkey_clause) # Handle partitioning through table naming/organization (logged but not in CREATE TABLE) partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING) if partition_columns: # Redshift partitioning is typically handled through table design patterns # We'll log the strategy but not add SQL clauses pass # Clustering not directly supported in Redshift CREATE TABLE except ImportError: # If tuning interface not available, return empty string pass return " ".join(clauses)
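For a hypothetical orders table distributed on o_orderkey and sorted on (o_orderdate, o_orderkey), the generated clause is appended verbatim to the CREATE TABLE statement. In the sketch below, adapter and orders_tuning are stand-ins for an initialized adapter and a table tuning object with those columns:

clause = adapter.generate_tuning_clause(orders_tuning)
# clause == "DISTSTYLE KEY DISTKEY (o_orderkey) SORTKEY (o_orderdate, o_orderkey)"
create_sql = f"CREATE TABLE orders (o_orderkey BIGINT, o_orderdate DATE) {clause}"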
[docs] def apply_table_tunings(self, table_tuning, connection: Any) -> None:
        """Apply tuning configurations to a Redshift table.

        Redshift tuning approach:
        - DISTRIBUTION: Handled via DISTSTYLE/DISTKEY in CREATE TABLE
        - SORTING: Handled via SORTKEY in CREATE TABLE
        - Post-creation optimizations via ANALYZE and VACUUM

        Args:
            table_tuning: The tuning configuration to apply
            connection: Redshift connection

        Raises:
            ValueError: If the tuning configuration is invalid for Redshift
        """
        if not table_tuning or not table_tuning.has_any_tuning():
            return

        table_name = table_tuning.table_name.lower()
        self.logger.info(f"Applying Redshift tunings for table: {table_name}")

        cursor = connection.cursor()
        try:
            # Import here to avoid circular imports
            from benchbox.core.tuning.interface import TuningType

            # Redshift tuning is primarily handled at table creation time;
            # post-creation optimizations are limited.

            # Verify the table exists and read its current key configuration.
            # PG_TABLE_DEF exposes per-column distkey/sortkey flags (and only lists
            # schemas on the search_path), so distribution style and sort-key order
            # are derived from those flags rather than read from a single row.
            cursor.execute(f"""
                SELECT "column", distkey, sortkey
                FROM pg_table_def
                WHERE schemaname = '{self.schema}'
                  AND tablename = '{table_name}'
            """)
            columns = cursor.fetchall()

            if columns:
                current_distkey = next((col for col, is_dist, _ in columns if is_dist), None)
                current_diststyle = "KEY" if current_distkey else None
                # Order columns by their position within the sort key.
                sort_rows = sorted((row for row in columns if row[2]), key=lambda row: abs(row[2]))
                current_sortkeys = [row[0] for row in sort_rows]

                self.logger.info(f"Current configuration for {table_name}:")
                self.logger.info(f"  Distribution style: {current_diststyle}")
                self.logger.info(f"  Distribution key: {current_distkey}")
                self.logger.info(f"  Sort keys: {current_sortkeys}")

                # Check whether the current configuration matches the desired tuning
                distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
                sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)

                needs_recreation = False

                # Check distribution configuration
                if distribution_columns:
                    sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
                    desired_distkey = sorted_cols[0].name
                    if current_distkey != desired_distkey or current_diststyle != "KEY":
                        needs_recreation = True
                        self.logger.info(
                            f"Distribution key mismatch: current='{current_distkey}', desired='{desired_distkey}'"
                        )

                # Check sort key configuration
                if sort_columns:
                    sorted_cols = sorted(sort_columns, key=lambda col: col.order)
                    desired_sortkeys = [col.name for col in sorted_cols]
                    if current_sortkeys != desired_sortkeys:
                        needs_recreation = True
                        self.logger.info(f"Sort keys mismatch: current={current_sortkeys}, desired={desired_sortkeys}")

                if needs_recreation:
                    self.logger.warning(
                        f"Table {table_name} configuration differs from desired tuning. "
                        "Redshift requires table recreation to change distribution/sort keys."
                    )
            else:
                self.logger.warning(f"Could not find table configuration for {table_name}")

            # Perform maintenance operations that can help with performance
            try:
                # Run ANALYZE to update table statistics
                cursor.execute(f"ANALYZE {table_name}")
                self.logger.info(f"Analyzed table statistics for {table_name}")

                # Run VACUUM to reclaim space and re-sort data
                if self.auto_vacuum:
                    cursor.execute(f"VACUUM {table_name}")
                    self.logger.info(f"Vacuumed table {table_name}")
            except Exception as e:
                self.logger.warning(f"Failed to perform maintenance operations on {table_name}: {e}")

            # Handle partitioning strategy
            partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
            if partition_columns:
                sorted_cols = sorted(partition_columns, key=lambda col: col.order)
                column_names = [col.name for col in sorted_cols]
                self.logger.info(
                    f"Partitioning strategy for {table_name}: {', '.join(column_names)} "
                    "(handled via table design patterns)"
                )

            # Clustering not directly supported in Redshift
            cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
            if cluster_columns:
                sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
                column_names = [col.name for col in sorted_cols]
                self.logger.info(
                    f"Clustering strategy for {table_name} achieved via sort keys: {', '.join(column_names)}"
                )

        except ImportError:
            self.logger.warning("Tuning interface not available - skipping tuning application")
        except Exception as e:
            raise ValueError(f"Failed to apply tunings to Redshift table {table_name}: {e}") from e
        finally:
            cursor.close()
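When the mismatch warning fires, one conventional remedy outside the adapter is a deep copy: create a new table with the desired keys, reload it, and swap names (newer Redshift releases also offer ALTER TABLE ... ALTER DISTKEY / ALTER SORTKEY for some cases). A hedged sketch with placeholder names; the adapter itself only logs the mismatch:

# Deep-copy pattern for changing distribution/sort keys; not executed by BenchBox.
cursor.execute("""
    CREATE TABLE orders_retuned (o_orderkey BIGINT, o_orderdate DATE)
    DISTSTYLE KEY DISTKEY (o_orderkey)
    SORTKEY (o_orderdate, o_orderkey)
""")
cursor.execute("INSERT INTO orders_retuned SELECT * FROM orders")
cursor.execute("DROP TABLE orders")
cursor.execute("ALTER TABLE orders_retuned RENAME TO orders")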
[docs] def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None: """Apply unified tuning configuration to Redshift. Args: unified_config: Unified tuning configuration to apply connection: Redshift connection """ if not unified_config: return # Apply constraint configurations self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection) # Apply platform optimizations if unified_config.platform_optimizations: self.apply_platform_optimizations(unified_config.platform_optimizations, connection) # Apply table-level tunings for _table_name, table_tuning in unified_config.table_tunings.items(): self.apply_table_tunings(table_tuning, connection)
[docs] def apply_platform_optimizations(self, platform_config: PlatformOptimizationConfiguration, connection: Any) -> None: """Apply Redshift-specific platform optimizations. Redshift optimizations include: - Workload Management (WLM) queue configuration - Query group settings for resource allocation - Compression encoding optimization - Statistics collection and maintenance Args: platform_config: Platform optimization configuration connection: Redshift connection """ if not platform_config: return # Redshift optimizations are typically applied at session or workload level # Store optimizations for use during query execution and maintenance operations self.logger.info("Redshift platform optimizations stored for session and workload management")
[docs] def apply_constraint_configuration( self, primary_key_config: PrimaryKeyConfiguration, foreign_key_config: ForeignKeyConfiguration, connection: Any, ) -> None: """Apply constraint configurations to Redshift. Note: Redshift supports PRIMARY KEY and FOREIGN KEY constraints for query optimization, but they are informational only (not enforced). Constraints must be applied during table creation time. Args: primary_key_config: Primary key constraint configuration foreign_key_config: Foreign key constraint configuration connection: Redshift connection """ # Redshift constraints are applied at table creation time for query optimization # This method is called after tables are created, so log the configurations if primary_key_config and primary_key_config.enabled: self.logger.info( "Primary key constraints enabled for Redshift (informational only, applied during table creation)" ) if foreign_key_config and foreign_key_config.enabled: self.logger.info( "Foreign key constraints enabled for Redshift (informational only, applied during table creation)" )
# Redshift constraints are informational and used for query optimization # No additional work to do here as they're applied during CREATE TABLE def _build_redshift_config( platform: str, options: dict[str, Any], overrides: dict[str, Any], info: Any, ) -> Any: """Build Redshift database configuration with credential loading. This function loads saved credentials from the CredentialManager and merges them with CLI options and runtime overrides. Args: platform: Platform name (should be 'redshift') options: CLI platform options from --platform-option flags overrides: Runtime overrides from orchestrator info: Platform info from registry Returns: DatabaseConfig with credentials loaded """ from benchbox.core.schemas import DatabaseConfig from benchbox.security.credentials import CredentialManager # Load saved credentials cred_manager = CredentialManager() saved_creds = cred_manager.get_platform_credentials("redshift") or {} # Build merged options: saved_creds < options < overrides merged_options = {} merged_options.update(saved_creds) merged_options.update(options) merged_options.update(overrides) # Extract credential fields for DatabaseConfig name = info.display_name if info else "Amazon Redshift" driver_package = info.driver_package if info else "redshift-connector" # Build config dict with platform-specific fields at top-level # This allows RedshiftAdapter.__init__() and from_config() to access them via config.get() config_dict = { "type": "redshift", "name": name, "options": merged_options or {}, # Ensure options is never None (Pydantic v2 uses None if explicitly passed) "driver_package": driver_package, "driver_version": overrides.get("driver_version") or options.get("driver_version"), "driver_auto_install": bool(overrides.get("driver_auto_install", options.get("driver_auto_install", False))), # Platform-specific fields at top-level (adapters expect these here) "host": merged_options.get("host"), "port": merged_options.get("port"), # NOTE: database is NOT included here - from_config() generates it from benchmark context # Only explicit overrides (via --platform-option database=...) 
should bypass generation "username": merged_options.get("username"), "password": merged_options.get("password"), "schema": merged_options.get("schema"), # S3 and AWS configuration "s3_bucket": merged_options.get("s3_bucket"), "s3_prefix": merged_options.get("s3_prefix"), "staging_root": merged_options.get("staging_root"), "iam_role": merged_options.get("iam_role"), "aws_access_key_id": merged_options.get("aws_access_key_id"), "aws_secret_access_key": merged_options.get("aws_secret_access_key"), "aws_session_token": merged_options.get("aws_session_token"), "aws_region": merged_options.get("aws_region"), # Optional settings "cluster_identifier": merged_options.get("cluster_identifier"), "admin_database": merged_options.get("admin_database", "dev"), "connect_timeout": merged_options.get("connect_timeout"), "statement_timeout": merged_options.get("statement_timeout"), "sslmode": merged_options.get("sslmode"), # Benchmark context for config-aware database naming (from overrides) "benchmark": overrides.get("benchmark"), "scale_factor": overrides.get("scale_factor"), "tuning_config": overrides.get("tuning_config"), } # Only include explicit database override if provided via CLI or overrides # Saved credentials should NOT override generated database names if "database" in overrides and overrides["database"]: config_dict["database"] = overrides["database"] return DatabaseConfig(**config_dict) # Register the config builder with the platform hook registry # This must happen when the module is imported try: from benchbox.cli.platform_hooks import PlatformHookRegistry PlatformHookRegistry.register_config_builder("redshift", _build_redshift_config) except ImportError: # Platform hooks may not be available in all contexts (e.g., core-only usage) pass
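Because the merge above is a sequence of dict.update() calls, the last writer wins for any shared key: runtime overrides beat CLI --platform-option values, which beat saved credentials. A small illustration with placeholder values:

saved_creds = {"host": "saved.example.com", "username": "bench"}
options = {"host": "cli.example.com", "schema": "public"}
overrides = {"schema": "tpch_sf1"}

merged = {}
merged.update(saved_creds)
merged.update(options)
merged.update(overrides)
# merged == {"host": "cli.example.com", "username": "bench", "schema": "tpch_sf1"}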