Source code for benchbox.platforms.redshift

"""Amazon Redshift platform adapter with S3 integration and data warehouse optimizations.

Provides Redshift-specific optimizations for analytical workloads,
including the COPY command for efficient data loading and distribution key optimization.

Copyright 2026 Joe Harris / BenchBox Project

Licensed under the MIT License. See LICENSE file in the project root for details.
"""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from benchbox.core.tuning.interface import (
        ForeignKeyConfiguration,
        PlatformOptimizationConfiguration,
        PrimaryKeyConfiguration,
        UnifiedTuningConfiguration,
    )

from ..core.exceptions import ConfigurationError
from ..utils.dependencies import (
    check_platform_dependencies,
    get_dependency_error_message,
    get_dependency_group_packages,
)
from .base import PlatformAdapter
from .base.data_loading import FileFormatRegistry

try:
    import redshift_connector

    # Bind psycopg2 here too so both driver names are always defined at module level
    psycopg2 = None
except ImportError:
    redshift_connector = None
    try:
        import psycopg2
    except ImportError:
        psycopg2 = None

try:
    import boto3
    from botocore.exceptions import ClientError, NoCredentialsError
except ImportError:
    boto3 = None


class RedshiftAdapter(PlatformAdapter):
    """Amazon Redshift platform adapter with S3 integration."""

    def __init__(self, **config):
        super().__init__(**config)

        dependency_packages = get_dependency_group_packages("redshift")

        # Check dependencies - prefer redshift-connector, fallback to psycopg2
        if not redshift_connector and not psycopg2:
            available, missing = check_platform_dependencies("redshift")
            if not available:
                error_msg = get_dependency_error_message("redshift", missing)
                raise ImportError(error_msg)
        else:
            # Ensure shared helper libraries (e.g., boto3, cloudpathlib) are available
            shared_packages = [pkg for pkg in dependency_packages if pkg != "redshift-connector"]
            if shared_packages:
                available_shared, missing_shared = check_platform_dependencies("redshift", shared_packages)
                if not available_shared:
                    error_msg = get_dependency_error_message("redshift", missing_shared)
                    raise ImportError(error_msg)

        self._dialect = "redshift"

        # Redshift connection configuration
        self.host = config.get("host")
        self.port = config.get("port") if config.get("port") is not None else 5439
        self.database = config.get("database") or "dev"
        self.username = config.get("username")
        self.password = config.get("password")
        self.cluster_identifier = config.get("cluster_identifier")

        # Admin database for metadata operations (CREATE/DROP DATABASE, checking database existence)
        # Redshift requires connecting to an existing database for admin operations
        # Default: "dev" (Redshift Serverless default database)
        self.admin_database = config.get("admin_database") or "dev"

        # Schema configuration
        self.schema = config.get("schema") or "public"

        # Connection settings
        self.connect_timeout = config.get("connect_timeout") if config.get("connect_timeout") is not None else 10
        self.statement_timeout = config.get("statement_timeout") if config.get("statement_timeout") is not None else 0
        self.sslmode = config.get("sslmode") or "require"

        # WLM settings
        self.wlm_query_slot_count = (
            config.get("wlm_query_slot_count") if config.get("wlm_query_slot_count") is not None else 1
        )
        self.wlm_query_queue_name = config.get("wlm_query_queue_name")

        # SSL configuration (legacy compatibility)
        self.ssl_enabled = config.get("ssl_enabled") if config.get("ssl_enabled") is not None else True
        self.ssl_insecure = config.get("ssl_insecure") if config.get("ssl_insecure") is not None else False
        self.sslrootcert = config.get("sslrootcert")

        # S3 configuration for data loading
        # Check for staging_root first (set by orchestrator for CloudStagingPath)
        staging_root = config.get("staging_root")
        if staging_root:
            # Parse s3://bucket/path format to extract bucket and prefix
            from benchbox.utils.cloud_storage import get_cloud_path_info

            path_info = get_cloud_path_info(staging_root)
            if path_info["provider"] == "s3":
                self.s3_bucket = path_info["bucket"]
                # Use the path component if provided, otherwise use default
                self.s3_prefix = path_info["path"].strip("/") if path_info["path"] else "benchbox-data"
                self.logger.info(f"Using staging location from config: s3://{self.s3_bucket}/{self.s3_prefix}")
            else:
                raise ValueError(f"Redshift requires S3 (s3://) staging location, got: {path_info['provider']}://")
        else:
            # Fall back to explicit s3_bucket configuration
            self.s3_bucket = config.get("s3_bucket")
            self.s3_prefix = config.get("s3_prefix") or "benchbox-data"

        self.iam_role = config.get("iam_role")
        self.aws_access_key_id = config.get("aws_access_key_id")
        self.aws_secret_access_key = config.get("aws_secret_access_key")
        self.aws_region = config.get("aws_region") or "us-east-1"

        # Redshift optimization settings
        self.workload_management_config = config.get("wlm_config")

        # COMPUPDATE controls automatic compression during COPY (PRESET | ON | OFF)
        # PRESET: Apply compression based on column data types (no sampling)
        # ON: Apply compression based on data sampling
        # OFF: Disable automatic compression
        compupdate_raw = config.get("compupdate") or "PRESET"
        self.compupdate = compupdate_raw.upper()  # Normalize to uppercase

        # Validate COMPUPDATE value
        valid_compupdate_values = {"ON", "OFF", "PRESET"}
        if self.compupdate not in valid_compupdate_values:
            raise ValueError(
                f"Invalid COMPUPDATE value: '{compupdate_raw}'. "
                f"Must be one of: {', '.join(sorted(valid_compupdate_values))}"
            )

        self.auto_vacuum = config.get("auto_vacuum") if config.get("auto_vacuum") is not None else True
        self.auto_analyze = config.get("auto_analyze") if config.get("auto_analyze") is not None else True

        # Result cache control - disable by default for accurate benchmarking
        self.disable_result_cache = config.get("disable_result_cache", True)

        # Validation strictness - raise errors if cache control validation fails
        self.strict_validation = config.get("strict_validation", True)

        if not all([self.host, self.username, self.password]):
            missing = []
            if not self.host:
                missing.append("host (or REDSHIFT_HOST)")
            if not self.username:
                missing.append("username (or REDSHIFT_USER)")
            if not self.password:
                missing.append("password (or REDSHIFT_PASSWORD)")
            raise ConfigurationError(
                f"Redshift configuration is incomplete. Missing: {', '.join(missing)}\n"
                "Configure with one of:\n"
                "  1. CLI: benchbox platforms setup --platform redshift\n"
                "  2. Environment variables: REDSHIFT_HOST, REDSHIFT_USER, REDSHIFT_PASSWORD\n"
                "  3. CLI options: --platform-option host=<cluster>.redshift.amazonaws.com"
            )
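
    # A minimal construction sketch (endpoint, credentials, and bucket are
    # hypothetical). With no "staging_root", the explicit s3_bucket/s3_prefix
    # path above is used, and COMPUPDATE defaults to PRESET:
    #
    #     adapter = RedshiftAdapter(
    #         host="my-cluster.abc123xyz.us-east-1.redshift.amazonaws.com",
    #         username="awsuser",
    #         password="...",
    #         database="benchbox",
    #         s3_bucket="my-staging-bucket",
    #         iam_role="arn:aws:iam::123456789012:role/RedshiftCopyRole",
    #     )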

    @property
    def platform_name(self) -> str:
        return "Redshift"

    @staticmethod
    def add_cli_arguments(parser) -> None:
        """Add Redshift-specific CLI arguments."""
        rs_group = parser.add_argument_group("Redshift Arguments")
        rs_group.add_argument("--host", type=str, help="Redshift cluster endpoint hostname")
        rs_group.add_argument("--port", type=int, default=5439, help="Redshift cluster port")
        rs_group.add_argument("--database", type=str, default="dev", help="Database name")
        rs_group.add_argument("--username", type=str, help="Database user with required privileges")
        rs_group.add_argument("--password", type=str, help="Password for the database user")
        rs_group.add_argument("--iam-role", type=str, help="IAM role ARN for COPY operations")
        rs_group.add_argument("--s3-bucket", type=str, help="S3 bucket for data staging")
        rs_group.add_argument("--s3-prefix", type=str, default="benchbox-data", help="Prefix within the staging bucket")
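
    # How these flags might look on an invocation (the `benchbox run` entry
    # point and all values are assumptions; the flag names come from the
    # parser group above):
    #
    #     benchbox run --platform redshift \
    #         --host my-cluster.abc123xyz.us-east-1.redshift.amazonaws.com \
    #         --username awsuser --database dev \
    #         --iam-role arn:aws:iam::123456789012:role/RedshiftCopyRole \
    #         --s3-bucket my-staging-bucket --s3-prefix benchbox-data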

    @classmethod
    def from_config(cls, config: dict[str, Any]):
        """Create Redshift adapter from unified configuration."""
        from benchbox.utils.database_naming import generate_database_name

        adapter_config: dict[str, Any] = {}

        # Generate proper database name using benchmark characteristics
        # (unless explicitly overridden in config)
        if "database" in config and config["database"]:
            # User explicitly provided database name - use it
            adapter_config["database"] = config["database"]
        else:
            # Generate configuration-aware database name
            database_name = generate_database_name(
                benchmark_name=config["benchmark"],
                scale_factor=config["scale_factor"],
                platform="redshift",
                tuning_config=config.get("tuning_config"),
            )
            adapter_config["database"] = database_name

        # Core connection parameters (database handled above)
        for key in ["host", "port", "username", "password", "schema"]:
            if key in config:
                adapter_config[key] = config[key]

        # Optional staging/optimization parameters
        for key in [
            "iam_role",
            "s3_bucket",
            "s3_prefix",
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_region",
            "cluster_identifier",
            "admin_database",
            "connect_timeout",
            "statement_timeout",
            "sslmode",
            "ssl_enabled",
            "ssl_insecure",
            "sslrootcert",
            "wlm_query_slot_count",
            "wlm_query_queue_name",
            "wlm_config",
            "compupdate",
            "auto_vacuum",
            "auto_analyze",
        ]:
            if key in config:
                adapter_config[key] = config[key]

        return cls(**adapter_config)
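
    # A sketch of the unified config dict from_config() consumes (values are
    # hypothetical; "benchmark" and "scale_factor" feed database-name
    # generation when no explicit "database" is given):
    #
    #     config = {
    #         "benchmark": "tpch",
    #         "scale_factor": 1.0,
    #         "host": "my-cluster.abc123xyz.us-east-1.redshift.amazonaws.com",
    #         "username": "awsuser",
    #         "password": "...",
    #         "s3_bucket": "my-staging-bucket",
    #         "iam_role": "arn:aws:iam::123456789012:role/RedshiftCopyRole",
    #     }
    #     adapter = RedshiftAdapter.from_config(config)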

    def get_platform_info(self, connection: Any = None) -> dict[str, Any]:
        """Get Redshift platform information.

        Captures comprehensive Redshift configuration including:
        - Deployment type (serverless vs provisioned)
        - Capacity configuration (RPUs for serverless, node type/count for provisioned)
        - Redshift version
        - WLM (Workload Management) configuration
        - AWS region
        - Encryption and security settings

        Uses fallback chain: AWS API → SQL queries → hostname parsing
        Gracefully degrades if permissions are insufficient or AWS credentials unavailable.
        """
        # Step 1: Detect deployment type from hostname
        deployment_type = self._detect_deployment_type(self.host)
        region = self._extract_region_from_hostname(self.host, deployment_type) or self.aws_region
        identifier = self._extract_identifier_from_hostname(self.host, deployment_type)

        platform_info = {
            "platform_type": "redshift",
            "platform_name": "Redshift",
            "connection_mode": "remote",
            "cloud_provider": "AWS",
            "host": self.host,
            "port": self.port,
            "configuration": {
                "database": self.database,
                "region": region,
                "s3_bucket": self.s3_bucket,
                "iam_role": self.iam_role,
                "compupdate": getattr(self, "compupdate", None),
                "result_cache_enabled": not self.disable_result_cache,
                "deployment_type": deployment_type,
            },
        }

        # Get client library version
        if redshift_connector:
            try:
                platform_info["client_library_version"] = redshift_connector.__version__
            except AttributeError:
                platform_info["client_library_version"] = None
        else:
            try:
                import psycopg2

                platform_info["client_library_version"] = psycopg2.__version__
            except (ImportError, AttributeError):
                platform_info["client_library_version"] = None

        # Try to get Redshift version and extended metadata from connection
        if connection:
            cursor = None
            try:
                cursor = connection.cursor()

                # Get Redshift version
                cursor.execute("SELECT version()")
                result = cursor.fetchone()
                platform_info["platform_version"] = result[0] if result else None

                # Step 2: Collect deployment-specific metadata using fallback chain
                deployment_metadata = {}

                if deployment_type == "serverless":
                    # Try API first (most complete)
                    if identifier and region:
                        api_metadata = self._get_serverless_metadata_api(identifier, region)
                        if api_metadata:
                            deployment_metadata.update(api_metadata)
                            self.logger.debug("Using Serverless metadata from AWS API")

                    # Fall back to SQL if API didn't provide data
                    if not deployment_metadata:
                        sql_metadata = self._get_serverless_metadata_sql(cursor)
                        if sql_metadata:
                            deployment_metadata.update(sql_metadata)
                            self.logger.debug("Using Serverless metadata from SQL queries")

                    # Add serverless-specific fields to configuration
                    if deployment_metadata:
                        platform_info["configuration"]["workgroup_name"] = deployment_metadata.get(
                            "workgroup_name", identifier
                        )
                        platform_info["configuration"]["namespace_name"] = deployment_metadata.get("namespace_name")
                        platform_info["configuration"]["base_capacity_rpu"] = deployment_metadata.get(
                            "base_capacity_rpu"
                        )
                        platform_info["configuration"]["max_capacity_rpu"] = deployment_metadata.get("max_capacity_rpu")
                        platform_info["configuration"]["enhanced_vpc_routing"] = deployment_metadata.get(
                            "enhanced_vpc_routing"
                        )
                        platform_info["configuration"]["encrypted"] = deployment_metadata.get("encrypted")

                elif deployment_type == "provisioned":
                    # Try API first (most complete)
                    if identifier and region:
                        api_metadata = self._get_provisioned_metadata_api(identifier, region)
                        if api_metadata:
                            deployment_metadata.update(api_metadata)
                            self.logger.debug("Using Provisioned metadata from AWS API")

                    # Fall back to SQL if API didn't provide data
                    if not deployment_metadata:
                        sql_metadata = self._get_provisioned_metadata_sql(cursor)
                        if sql_metadata:
                            deployment_metadata.update(sql_metadata)
                            self.logger.debug("Using Provisioned metadata from SQL queries")

                    # Add provisioned-specific fields to configuration
                    if deployment_metadata:
                        platform_info["configuration"]["cluster_identifier"] = deployment_metadata.get(
                            "cluster_identifier", identifier
                        )
                        platform_info["configuration"]["node_type"] = deployment_metadata.get("node_type")
                        platform_info["configuration"]["number_of_nodes"] = deployment_metadata.get("number_of_nodes")
                        platform_info["configuration"]["total_storage_capacity_mb"] = deployment_metadata.get(
                            "total_storage_capacity_mb"
                        )
                        platform_info["configuration"]["enhanced_vpc_routing"] = deployment_metadata.get(
                            "enhanced_vpc_routing"
                        )
                        platform_info["configuration"]["encrypted"] = deployment_metadata.get("encrypted")

                        # Legacy compute_configuration field for backward compatibility
                        platform_info["compute_configuration"] = {
                            "node_type": deployment_metadata.get("node_type"),
                            "cluster_version": deployment_metadata.get("cluster_version"),
                            "num_compute_nodes": deployment_metadata.get("number_of_nodes"),
                        }

                # Try to get WLM (Workload Management) configuration
                try:
                    cursor.execute("""
                        SELECT service_class, num_query_tasks, query_working_mem,
                               max_execution_time, user_group_wild_card, query_group_wild_card
                        FROM stv_wlm_service_class_config
                        WHERE service_class >= 6
                        ORDER BY service_class
                        LIMIT 5
                    """)
                    wlm_results = cursor.fetchall()
                    if wlm_results:
                        if "compute_configuration" not in platform_info:
                            platform_info["compute_configuration"] = {}
                        platform_info["compute_configuration"]["wlm_queues"] = []
                        for row in wlm_results:
                            platform_info["compute_configuration"]["wlm_queues"].append(
                                {
                                    "service_class": row[0] if len(row) > 0 else None,
                                    "num_query_tasks": row[1] if len(row) > 1 else None,
                                    "query_working_mem_mb": row[2] if len(row) > 2 else None,
                                    "max_execution_time_ms": row[3] if len(row) > 3 else None,
                                }
                            )
                        self.logger.debug("Successfully captured Redshift WLM configuration")
                except Exception as e:
                    self.logger.debug(f"Could not query Redshift WLM configuration: {e}")

            except Exception as e:
                self.logger.debug(f"Error collecting Redshift platform info: {e}")
                if platform_info.get("platform_version") is None:
                    platform_info["platform_version"] = None
            finally:
                if cursor:
                    cursor.close()
        else:
            platform_info["platform_version"] = None

        return platform_info
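
    # Shape of the returned dict for a Serverless workgroup, trimmed to the
    # fields populated above (values are hypothetical):
    #
    #     {
    #         "platform_type": "redshift",
    #         "platform_name": "Redshift",
    #         "cloud_provider": "AWS",
    #         "configuration": {
    #             "deployment_type": "serverless",
    #             "workgroup_name": "wg1",
    #             "base_capacity_rpu": 32,
    #             ...
    #         },
    #     }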

    def get_target_dialect(self) -> str:
        """Return the target SQL dialect for Redshift."""
        return "redshift"

    def _detect_deployment_type(self, hostname: str) -> str:
        """Detect Redshift deployment type from hostname pattern.

        Args:
            hostname: Redshift endpoint hostname

        Returns:
            "serverless", "provisioned", or "unknown"
        """
        if not hostname:
            return "unknown"

        if ".redshift-serverless.amazonaws.com" in hostname:
            return "serverless"
        elif ".redshift.amazonaws.com" in hostname:
            return "provisioned"
        else:
            return "unknown"

    def _extract_region_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
        """Extract AWS region from Redshift hostname.

        Args:
            hostname: Redshift endpoint hostname
            deployment_type: "serverless" or "provisioned"

        Returns:
            AWS region string or None if not found
        """
        if not hostname:
            return None

        parts = hostname.split(".")
        if deployment_type == "serverless":
            # Format: workgroup.account.region.redshift-serverless.amazonaws.com
            return parts[2] if len(parts) > 2 else None
        elif deployment_type == "provisioned":
            # Format: cluster.region.redshift.amazonaws.com
            return parts[1] if len(parts) > 1 else None
        return None

    def _extract_identifier_from_hostname(self, hostname: str, deployment_type: str) -> str | None:
        """Extract workgroup name or cluster identifier from hostname.

        Args:
            hostname: Redshift endpoint hostname
            deployment_type: "serverless" or "provisioned"

        Returns:
            Workgroup name, cluster identifier, or None
        """
        if not hostname:
            return None

        parts = hostname.split(".")
        if deployment_type == "serverless":
            # Format: workgroup.account.region.redshift-serverless.amazonaws.com
            return parts[0] if len(parts) > 0 else None
        elif deployment_type == "provisioned":
            # Format: cluster.region.redshift.amazonaws.com
            return parts[0] if len(parts) > 0 else None
        return None

    def _get_serverless_metadata_sql(self, cursor: Any) -> dict[str, Any]:
        """Get Redshift Serverless metadata using SQL queries.

        Args:
            cursor: Active database cursor

        Returns:
            Dictionary with serverless metadata (empty if not serverless or queries fail)
        """
        metadata = {}
        try:
            # Try to query sys_serverless_usage (serverless-only table)
            cursor.execute("""
                SELECT compute_capacity
                FROM sys_serverless_usage
                ORDER BY start_time DESC
                LIMIT 1
            """)
            result = cursor.fetchone()
            if result:
                # Table exists and has data - this is serverless
                metadata["current_rpu_capacity"] = result[0] if result[0] is not None else None
                self.logger.debug("Detected Redshift Serverless via sys_serverless_usage table")
        except Exception as e:
            # Table doesn't exist or query failed - likely not serverless
            self.logger.debug(f"sys_serverless_usage query failed (not serverless or no permissions): {e}")

        return metadata

    def _get_provisioned_metadata_sql(self, cursor: Any) -> dict[str, Any]:
        """Get Redshift Provisioned metadata using SQL queries.

        Args:
            cursor: Active database cursor

        Returns:
            Dictionary with provisioned metadata (empty if not provisioned or queries fail)
        """
        metadata = {}
        try:
            # Query stv_cluster_configuration (provisioned-only table)
            cursor.execute("""
                SELECT node_type, cluster_version, COUNT(*) as num_nodes
                FROM stv_cluster_configuration
                GROUP BY node_type, cluster_version
            """)
            result = cursor.fetchone()
            if result:
                metadata["node_type"] = result[0] if len(result) > 0 else None
                metadata["cluster_version"] = result[1] if len(result) > 1 else None
                metadata["number_of_nodes"] = result[2] if len(result) > 2 else None
                self.logger.debug(
                    f"Detected Redshift Provisioned: {metadata['node_type']} x{metadata['number_of_nodes']}"
                )
        except Exception as e:
            # Table doesn't exist or query failed - likely not provisioned or no permissions
            self.logger.debug(f"stv_cluster_configuration query failed (not provisioned or no permissions): {e}")

        return metadata

    def _get_serverless_metadata_api(self, workgroup_name: str, region: str) -> dict[str, Any]:
        """Get Redshift Serverless metadata using boto3 API.

        Args:
            workgroup_name: Workgroup name
            region: AWS region

        Returns:
            Dictionary with serverless metadata (empty if API call fails)
        """
        if not boto3:
            self.logger.debug("boto3 not available - skipping Serverless API metadata")
            return {}

        metadata = {}
        try:
            client = boto3.client("redshift-serverless", region_name=region)

            # Get workgroup details
            response = client.get_workgroup(workgroupName=workgroup_name)
            workgroup = response.get("workgroup", {})

            # Extract essential sizing information
            metadata["workgroup_name"] = workgroup.get("workgroupName")
            metadata["base_capacity_rpu"] = workgroup.get("baseCapacity")
            metadata["max_capacity_rpu"] = workgroup.get("maxCapacity")
            metadata["namespace_name"] = workgroup.get("namespaceName")
            metadata["enhanced_vpc_routing"] = workgroup.get("enhancedVpcRouting", False)
            metadata["status"] = workgroup.get("status")

            # Get namespace details for encryption info
            if metadata.get("namespace_name"):
                try:
                    namespace_response = client.get_namespace(namespaceName=metadata["namespace_name"])
                    namespace = namespace_response.get("namespace", {})
                    metadata["kms_key_id"] = namespace.get("kmsKeyId")
                    metadata["encrypted"] = True  # Serverless is always encrypted
                except Exception as e:
                    self.logger.debug(f"Could not fetch namespace details: {e}")

            self.logger.debug(f"Retrieved Serverless metadata via API: {metadata['base_capacity_rpu']} RPUs")

        except NoCredentialsError:
            self.logger.debug("No AWS credentials found - skipping Serverless API metadata")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            self.logger.debug(f"AWS API error querying Serverless metadata: {error_code} - {e}")
        except Exception as e:
            self.logger.debug(f"Failed to query Serverless API metadata: {e}")

        return metadata

    def _get_provisioned_metadata_api(self, cluster_identifier: str, region: str) -> dict[str, Any]:
        """Get Redshift Provisioned metadata using boto3 API.

        Args:
            cluster_identifier: Cluster identifier
            region: AWS region

        Returns:
            Dictionary with provisioned metadata (empty if API call fails)
        """
        if not boto3:
            self.logger.debug("boto3 not available - skipping Provisioned API metadata")
            return {}

        metadata = {}
        try:
            client = boto3.client("redshift", region_name=region)

            # Get cluster details
            response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
            clusters = response.get("Clusters", [])

            if clusters:
                cluster = clusters[0]

                # Extract essential sizing information
                metadata["cluster_identifier"] = cluster.get("ClusterIdentifier")
                metadata["node_type"] = cluster.get("NodeType")
                metadata["number_of_nodes"] = cluster.get("NumberOfNodes")
                metadata["cluster_status"] = cluster.get("ClusterStatus")
                metadata["encrypted"] = cluster.get("Encrypted", False)
                metadata["kms_key_id"] = cluster.get("KmsKeyId")
                metadata["enhanced_vpc_routing"] = cluster.get("EnhancedVpcRouting", False)

                # Storage capacity
                total_storage_mb = cluster.get("TotalStorageCapacityInMegaBytes")
                if total_storage_mb:
                    metadata["total_storage_capacity_mb"] = total_storage_mb

                self.logger.debug(
                    f"Retrieved Provisioned metadata via API: {metadata['node_type']} x{metadata['number_of_nodes']}"
                )

        except NoCredentialsError:
            self.logger.debug("No AWS credentials found - skipping Provisioned API metadata")
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            self.logger.debug(f"AWS API error querying Provisioned metadata: {error_code} - {e}")
        except Exception as e:
            self.logger.debug(f"Failed to query Provisioned API metadata: {e}")

        return metadata

    def _get_connection_params(self, **connection_config) -> dict[str, Any]:
        """Get standardized connection parameters."""
        return {
            "host": connection_config.get("host", self.host),
            "port": connection_config.get("port", self.port),
            "database": connection_config.get("database", self.database),
            "user": connection_config.get("username", self.username),
            "password": connection_config.get("password", self.password),
            "sslmode": connection_config.get("sslmode", self.sslmode),
        }

    def _create_admin_connection(self, **connection_config) -> Any:
        """Create Redshift connection for admin operations.

        Admin operations (CREATE DATABASE, DROP DATABASE, checking database existence)
        require connecting to an existing database. This uses self.admin_database
        (default: "dev") instead of the target database to avoid circular dependencies.
        """
        params = self._get_connection_params(**connection_config)

        # Override database with admin database for admin operations
        # This prevents trying to connect to the target database to check if it exists
        admin_params = params.copy()
        admin_params["database"] = self.admin_database

        # Use same driver as main connection
        if redshift_connector:
            # Use redshift_connector (preferred)
            # Note: redshift_connector doesn't support connect_timeout or sslmode parameters
            # It only accepts ssl=True/False (not sslmode like psycopg2)
            return redshift_connector.connect(
                host=admin_params["host"],
                port=admin_params["port"],
                database=admin_params["database"],
                user=admin_params["user"],
                password=admin_params["password"],
                ssl=self.ssl_enabled,
                application_name="BenchBox-Admin",
            )
        else:
            # Fall back to psycopg2
            admin_params["connect_timeout"] = 30
            return psycopg2.connect(**admin_params)

    def _create_direct_connection(self, **connection_config) -> Any:
        """Create direct connection to target database for validation.

        Connects directly to the specified database without:
        - Calling handle_existing_database()
        - Creating database if missing
        - Setting database_was_reused flag

        Used by validation framework to check existing database compatibility.

        Args:
            **connection_config: Connection configuration including database name

        Returns:
            Database connection object

        Raises:
            Exception: If connection fails (database doesn't exist, auth fails, etc.)
        """
        # Get connection parameters for target database
        params = self._get_connection_params(**connection_config)

        # Use redshift_connector if available, otherwise fall back to psycopg2
        if redshift_connector:
            # Note: redshift_connector only accepts ssl=True/False (not sslmode like psycopg2)
            connection = redshift_connector.connect(
                host=params["host"],
                database=params["database"],
                port=params["port"],
                user=params["user"],
                password=params["password"],
                ssl=self.ssl_enabled,
                application_name="BenchBox-Validation",
            )
        else:
            # Fall back to psycopg2
            params["connect_timeout"] = 30
            connection = psycopg2.connect(**params)

        # Apply WLM queue settings if configured
        if self.wlm_query_queue_name:
            cursor = connection.cursor()
            try:
                # Escape single quotes in queue name for SQL safety
                queue_name_escaped = self.wlm_query_queue_name.replace("'", "''")
                cursor.execute(f"SET query_group TO '{queue_name_escaped}'")
            finally:
                cursor.close()

        return connection
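
    # How the hostname helpers above compose (the hostname is a made-up
    # Serverless example; these are instance methods, so this runs against a
    # constructed adapter):
    #
    #     host = "wg1.123456789012.us-west-2.redshift-serverless.amazonaws.com"
    #     self._detect_deployment_type(host)                          # "serverless"
    #     self._extract_region_from_hostname(host, "serverless")      # "us-west-2"
    #     self._extract_identifier_from_hostname(host, "serverless")  # "wg1"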

    def check_server_database_exists(self, **connection_config) -> bool:
        """Check if database exists in Redshift cluster.

        Connects to admin database to query pg_database for the target database.
        """
        try:
            # Connect to admin database (not target database)
            connection = self._create_admin_connection()
            cursor = connection.cursor()

            database = connection_config.get("database", self.database)

            # Check if database exists
            cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (database,))
            result = cursor.fetchone()

            return result is not None
        except Exception:
            # If we can't connect or check, assume database doesn't exist
            return False
        finally:
            if "connection" in locals() and connection:
                connection.close()

    def drop_database(self, **connection_config) -> None:
        """Drop database in Redshift cluster.

        Connects to admin database to drop the target database.
        Note: DROP DATABASE must run with autocommit enabled.
        Note: Redshift doesn't support IF EXISTS for DROP DATABASE, so we check first.
        """
        database = connection_config.get("database", self.database)

        # Check if database exists first (Redshift doesn't support IF EXISTS)
        if not self.check_server_database_exists(database=database):
            self.log_verbose(f"Database {database} does not exist - nothing to drop")
            return

        try:
            # Connect to admin database (not target database)
            connection = self._create_admin_connection()
            connection.autocommit = True  # Enable autocommit for DROP DATABASE
            cursor = connection.cursor()

            # Try to drop database first (graceful approach)
            # Quote identifier for SQL safety
            try:
                cursor.execute(f'DROP DATABASE "{database}"')
            except Exception as drop_error:
                # If drop fails due to active connections, terminate them and retry
                error_msg = str(drop_error).lower()
                if "active connection" in error_msg or "being accessed" in error_msg:
                    self.log_verbose("Database has active connections, terminating them...")
                    # Terminate existing connections as fallback
                    # Note: Use 'pid' column (not deprecated 'procpid')
                    cursor.execute(
                        """
                        SELECT pg_terminate_backend(pid)
                        FROM pg_stat_activity
                        WHERE datname = %s AND pid <> pg_backend_pid()
                        """,
                        (database,),
                    )
                    # Retry drop after terminating connections (quoted for SQL safety)
                    cursor.execute(f'DROP DATABASE "{database}"')
                else:
                    # Re-raise if not a connection issue
                    raise
        except Exception as e:
            raise RuntimeError(f"Failed to drop Redshift database {database}: {e}") from e
        finally:
            if "connection" in locals() and connection:
                connection.close()

    def create_connection(self, **connection_config) -> Any:
        """Create optimized Redshift connection."""
        self.log_operation_start("Redshift connection")

        # Handle existing database using base class method
        self.handle_existing_database(**connection_config)

        # Get connection parameters
        params = self._get_connection_params(**connection_config)
        target_database = params.get("database")

        # Create database if needed (before connecting to it)
        # Redshift requires connecting to an admin database to create new databases
        if not self.database_was_reused:
            # Check if target database exists
            database_exists = self.check_server_database_exists(database=target_database)

            if not database_exists:
                self.log_verbose(f"Creating database: {target_database}")
                # Create database using admin connection (connects to self.admin_database)
                # Note: CREATE DATABASE must run with autocommit enabled (cannot run in transaction block)
                try:
                    admin_conn = self._create_admin_connection()
                    admin_conn.autocommit = True  # Enable autocommit for CREATE DATABASE
                    admin_cursor = admin_conn.cursor()
                    try:
                        # Quote identifier for SQL safety
                        admin_cursor.execute(f'CREATE DATABASE "{target_database}"')
                        # No commit() needed - autocommit handles it automatically
                        self.logger.info(f"Created database {target_database}")
                    finally:
                        admin_cursor.close()
                        admin_conn.close()
                except Exception as e:
                    self.logger.error(f"Failed to create database {target_database}: {e}")
                    raise

        self.log_very_verbose(f"Redshift connection params: host={params.get('host')}, database={target_database}")

        try:
            if redshift_connector:
                # Use redshift_connector (preferred)
                # Note: redshift_connector only accepts ssl=True/False (not sslmode like psycopg2)
                connection = redshift_connector.connect(
                    host=params["host"],
                    port=params["port"],
                    database=params["database"],
                    user=params["user"],
                    password=params["password"],
                    ssl=self.ssl_enabled,
                    # Connection optimization
                    application_name="BenchBox",
                    tcp_keepalive=True,
                    tcp_keepalive_idle=600,
                    tcp_keepalive_interval=30,
                    tcp_keepalive_count=3,
                )
                # Enable autocommit immediately after connection creation (before any SQL operations)
                connection.autocommit = True
            else:
                # Fall back to psycopg2 (already imported at top of file)
                connection = psycopg2.connect(
                    host=params["host"],
                    port=params["port"],
                    database=params["database"],
                    user=params["user"],
                    password=params["password"],
                    sslmode=params["sslmode"],
                    application_name="BenchBox",
                    connect_timeout=self.connect_timeout,
                )
                # Enable autocommit for benchmark workloads (no transactions needed)
                connection.autocommit = True

            # Apply WLM settings and schema search path
            cursor = connection.cursor()

            # Integer settings validated in __init__, safe to interpolate
            if self.wlm_query_slot_count > 1:
                cursor.execute(f"SET wlm_query_slot_count = {int(self.wlm_query_slot_count)}")
            if self.statement_timeout > 0:
                cursor.execute(f"SET statement_timeout = {int(self.statement_timeout)}")

            # Set search_path to ensure all unqualified table references use correct schema
            # Critical for database reuse when schema already exists but connection is new
            # Quote identifier for SQL safety
            cursor.execute(f'SET search_path TO "{self.schema}"')

            # Test connection
            cursor.execute("SELECT version()")
            cursor.fetchone()
            cursor.close()

            self.logger.info(f"Connected to Redshift cluster at {params['host']}:{params['port']}")
            self.log_operation_complete(
                "Redshift connection", details=f"Connected to {params['host']}:{params['port']}"
            )
            return connection

        except Exception as e:
            self.logger.error(f"Failed to connect to Redshift: {e}")
            raise

    def create_schema(self, benchmark, connection: Any) -> float:
        """Create schema using Redshift-optimized table definitions."""
        start_time = time.time()
        cursor = connection.cursor()

        try:
            # Create schema if needed (if not using default "public")
            # Quote identifiers for SQL safety
            if self.schema and self.schema.lower() != "public":
                self.log_verbose(f"Creating schema: {self.schema}")
                cursor.execute(f'CREATE SCHEMA IF NOT EXISTS "{self.schema}"')
                self.logger.info(f"Created schema {self.schema}")

            # Set search_path to use the correct schema
            self.log_very_verbose(f"Setting search_path to: {self.schema}")
            cursor.execute(f'SET search_path TO "{self.schema}"')

            # Use common schema creation helper
            schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")

            # Split schema into individual statements and execute
            statements = [stmt.strip() for stmt in schema_sql.split(";") if stmt.strip()]

            for statement in statements:
                # Normalize table names to lowercase for Redshift consistency
                # This ensures CREATE, COPY, and SELECT all use the same case
                statement = self._normalize_table_name_in_sql(statement)

                # Ensure idempotency with DROP TABLE IF EXISTS
                # (Redshift doesn't support CREATE OR REPLACE TABLE)
                if statement.upper().startswith("CREATE TABLE"):
                    # Extract table name from CREATE TABLE statement
                    table_name = self._extract_table_name(statement)
                    if table_name:
                        # Ensure table name is lowercase
                        table_name_lower = table_name.strip('"').lower()
                        drop_statement = f"DROP TABLE IF EXISTS {table_name_lower}"
                        cursor.execute(drop_statement)
                        self.logger.debug(f"Executed: {drop_statement}")

                # Optimize table definition for Redshift (no-op for non-CREATE statements)
                statement = self._optimize_table_definition(statement)

                cursor.execute(statement)
                self.logger.debug(f"Executed schema statement: {statement[:100]}...")

            self.logger.info("Schema created")

        except Exception as e:
            self.logger.error(f"Schema creation failed: {e}")
            raise
        finally:
            cursor.close()

        return time.time() - start_time

    def load_data(
        self, benchmark, connection: Any, data_dir: Path
    ) -> tuple[dict[str, int], float, dict[str, Any] | None]:
        """Load data using Redshift COPY command with S3 integration."""
        start_time = time.time()
        table_stats = {}
        cursor = connection.cursor()

        try:
            # Get data files from benchmark or manifest fallback
            if hasattr(benchmark, "tables") and benchmark.tables:
                data_files = benchmark.tables
            else:
                data_files = None
                try:
                    manifest_path = Path(data_dir) / "_datagen_manifest.json"
                    if manifest_path.exists():
                        with open(manifest_path) as f:
                            manifest = json.load(f)
                        tables = manifest.get("tables") or {}
                        mapping = {}
                        for table, entries in tables.items():
                            if entries:
                                # Collect ALL chunk files, not just the first one
                                chunk_paths = []
                                for entry in entries:
                                    rel = entry.get("path")
                                    if rel:
                                        chunk_paths.append(Path(data_dir) / rel)
                                if chunk_paths:
                                    mapping[table] = chunk_paths
                        if mapping:
                            data_files = mapping
                            self.logger.debug("Using data files from _datagen_manifest.json")
                except Exception as e:
                    self.logger.debug(f"Manifest fallback failed: {e}")

            if not data_files:
                # No data files available - benchmark should have generated data first
                raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")

            # Upload files to S3 and load via COPY command
            if self.s3_bucket and boto3:
                # Create S3 client with explicit error handling
                try:
                    s3_client = boto3.client(
                        "s3",
                        aws_access_key_id=self.aws_access_key_id,
                        aws_secret_access_key=self.aws_secret_access_key,
                        region_name=self.aws_region,
                    )
                except NoCredentialsError as e:
                    raise ValueError(
                        "AWS credentials not found. Configure credentials via:\n"
                        "  1. IAM role (aws_access_key_id/aws_secret_access_key in config)\n"
                        "  2. AWS CLI (aws configure)\n"
                        "  3. Environment variables (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY)"
                    ) from e
                except ClientError as e:
                    raise ValueError(f"Failed to create AWS S3 client: {e}") from e

                for table_name, file_paths in data_files.items():
                    # Normalize to list (data resolver should always return lists now)
                    if not isinstance(file_paths, list):
                        file_paths = [file_paths]

                    # Filter out non-existent or empty files
                    valid_files = []
                    for file_path in file_paths:
                        file_path = Path(file_path)
                        if file_path.exists() and file_path.stat().st_size > 0:
                            valid_files.append(file_path)

                    if not valid_files:
                        self.logger.warning(f"Skipping {table_name} - no valid data files")
                        table_stats[table_name.lower()] = 0
                        continue

                    chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                    self.log_verbose(f"Loading data for table: {table_name}{chunk_info}")

                    try:
                        load_start = time.time()

                        # Normalize table name to lowercase for Redshift consistency
                        table_name_lower = table_name.lower()

                        # Upload all files to S3 and collect S3 URIs
                        # Redshift supports compressed files and any delimiter natively
                        s3_uris = []
                        for file_idx, file_path in enumerate(valid_files):
                            file_path = Path(file_path)

                            # Detect file format to determine delimiter
                            # TPC-H uses .tbl (pipe-delimited), TPC-DS uses .dat (pipe-delimited)
                            # Use substring check to handle chunked files like customer.tbl.1 or customer.tbl.1.zst
                            file_str = str(file_path.name)
                            delimiter = "|" if ".tbl" in file_str or ".dat" in file_str else ","

                            # Upload file directly with original compression and format
                            # Preserve full multi-part suffix for chunked files (e.g., .tbl.1.zst)
                            # Extract all suffixes after table name (e.g., "customer.tbl.1.zst" -> ".tbl.1.zst")
                            file_stem = file_path.stem  # e.g., "customer.tbl.1" or "customer"
                            # Get original suffix (e.g., ".zst")
                            original_suffix = file_path.suffix
                            # Check if stem has more suffixes (e.g., ".tbl.1" in "customer.tbl.1")
                            if "." in file_stem:
                                # Extract all suffixes: split at first dot and take the rest
                                parts = file_path.name.split(".", 1)  # e.g., ["customer", "tbl.1.zst"]
                                if len(parts) > 1:
                                    full_suffix = "." + parts[1]  # e.g., ".tbl.1.zst"
                                else:
                                    full_suffix = original_suffix
                            else:
                                full_suffix = original_suffix

                            s3_key = f"{self.s3_prefix}/{table_name}_{file_idx}{full_suffix}"

                            # Upload file with explicit error handling
                            try:
                                s3_client.upload_file(str(file_path), self.s3_bucket, s3_key)
                            except ClientError as e:
                                error_code = e.response.get("Error", {}).get("Code", "Unknown")
                                if error_code == "NoSuchBucket":
                                    raise ValueError(
                                        f"S3 bucket '{self.s3_bucket}' does not exist. "
                                        f"Create the bucket or update your configuration."
                                    ) from e
                                elif error_code == "AccessDenied":
                                    raise ValueError(
                                        f"Access denied to S3 bucket '{self.s3_bucket}'. "
                                        f"Check IAM permissions for s3:PutObject."
                                    ) from e
                                else:
                                    raise ValueError(
                                        f"Failed to upload {file_path.name} to s3://{self.s3_bucket}/{s3_key}: "
                                        f"{error_code} - {e}"
                                    ) from e

                            s3_uris.append(f"s3://{self.s3_bucket}/{s3_key}")

                        # For multi-file loads, create manifest file and use that
                        if len(s3_uris) > 1:
                            manifest = {"entries": [{"url": uri, "mandatory": True} for uri in s3_uris]}
                            manifest_key = f"{self.s3_prefix}/{table_name}_manifest.json"

                            # Upload manifest with explicit error handling
                            try:
                                s3_client.put_object(
                                    Bucket=self.s3_bucket,
                                    Key=manifest_key,
                                    Body=json.dumps(manifest),
                                )
                            except ClientError as e:
                                error_code = e.response.get("Error", {}).get("Code", "Unknown")
                                raise ValueError(
                                    f"Failed to upload manifest file to s3://{self.s3_bucket}/{manifest_key}: "
                                    f"{error_code} - {e}"
                                ) from e

                            copy_from_path = f"s3://{self.s3_bucket}/{manifest_key}"
                            manifest_option = "manifest"
                        else:
                            copy_from_path = s3_uris[0]
                            manifest_option = ""

                        # Detect compression format from file extension
                        # Redshift auto-detects gzip, but we should be explicit for zstd
                        if any(str(f).endswith(".zst") for f in valid_files):
                            compression_option = "ZSTD"
                        elif any(str(f).endswith(".gz") for f in valid_files):
                            compression_option = "GZIP"
                        else:
                            compression_option = ""

                        # Load from S3 using COPY command with three-way credential handling:
                        #   1. IAM role (preferred)
                        #   2. Explicit access keys
                        #   3. Cluster default IAM role (no credentials in SQL)
                        if self.iam_role:
                            # Use IAM role for authentication
                            credentials_clause = f"IAM_ROLE '{self.iam_role}'"
                        elif self.aws_access_key_id and self.aws_secret_access_key:
                            # Use explicit access keys
                            credentials_clause = (
                                f"ACCESS_KEY_ID '{self.aws_access_key_id}' "
                                f"SECRET_ACCESS_KEY '{self.aws_secret_access_key}'"
                            )
                        else:
                            # No explicit credentials - rely on cluster's default IAM role
                            credentials_clause = ""
                            self.log_verbose(
                                "No explicit credentials configured for COPY; using cluster default IAM role"
                            )

                        # Build COPY command with credentials clause (may be empty)
                        # Fully qualify table name with schema for clarity and to avoid ambiguity
                        qualified_table = f"{self.schema}.{table_name_lower}"
                        copy_sql = f"""
                            COPY {qualified_table}
                            FROM '{copy_from_path}'
                            {credentials_clause}
                            {manifest_option}
                            {compression_option}
                            DELIMITER '{delimiter}'
                            IGNOREHEADER 0
                            COMPUPDATE {self.compupdate}
                        """
                        cursor.execute(copy_sql)

                        # Get row count
                        cursor.execute(f"SELECT COUNT(*) FROM {qualified_table}")
                        row_count = cursor.fetchone()[0]
                        table_stats[table_name_lower] = row_count

                        # Run ANALYZE if configured
                        if self.auto_analyze:
                            cursor.execute(f"ANALYZE {qualified_table}")

                        load_time = time.time() - load_start
                        self.logger.info(
                            f"✅ Loaded {row_count:,} rows into {table_name_lower}{chunk_info} in {load_time:.2f}s"
                        )

                    except Exception as e:
                        self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                        table_stats[table_name.lower()] = 0

            else:
                # Direct loading without S3 (less efficient)
                self.logger.warning("No S3 bucket configured, using direct INSERT loading")

                for table_name, file_paths in data_files.items():
                    # Normalize to list (handle both single paths and lists for TPC-H vs TPC-DS)
                    if not isinstance(file_paths, list):
                        file_paths = [file_paths]

                    # Filter valid files
                    valid_files = []
                    for file_path in file_paths:
                        file_path = Path(file_path)
                        if file_path.exists() and file_path.stat().st_size > 0:
                            valid_files.append(file_path)

                    if not valid_files:
                        self.logger.warning(f"Skipping {table_name} - no valid data files")
                        table_stats[table_name.lower()] = 0
                        continue

                    try:
                        self.log_verbose(f"Direct loading data for table: {table_name}")
                        load_start = time.time()

                        # Normalize table name to lowercase for Redshift consistency
                        table_name_lower = table_name.lower()

                        # Load data row by row from all chunks (inefficient but works without S3)
                        total_rows_loaded = 0
                        for file_idx, file_path in enumerate(valid_files):
                            chunk_info = f" (chunk {file_idx + 1}/{len(valid_files)})" if len(valid_files) > 1 else ""
                            self.log_very_verbose(f"Loading {table_name}{chunk_info} from {file_path.name}")

                            # TPC-H uses .tbl files, TPC-DS uses .dat files - both are pipe-delimited
                            file_str = str(file_path.name)
                            delimiter = "|" if ".tbl" in file_str or ".dat" in file_str else ","

                            # Get compression handler (handles .zst, .gz, or uncompressed)
                            compression_handler = FileFormatRegistry.get_compression_handler(file_path)
                            with compression_handler.open(file_path) as f:
                                rows_loaded = 0
                                batch_size = 1000
                                batch_data = []

                                for line in f:
                                    line = line.strip()
                                    if line and line.endswith(delimiter):
                                        line = line[:-1]

                                    values = line.split(delimiter)
                                    # Simple escaping for SQL
                                    escaped_values = ["'" + str(v).replace("'", "''") + "'" for v in values]
                                    batch_data.append(f"({', '.join(escaped_values)})")

                                    if len(batch_data) >= batch_size:
                                        insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
                                        cursor.execute(insert_sql)
                                        rows_loaded += len(batch_data)
                                        total_rows_loaded += len(batch_data)
                                        batch_data = []

                                # Insert remaining batch
                                if batch_data:
                                    insert_sql = f"INSERT INTO {table_name_lower} VALUES " + ", ".join(batch_data)
                                    cursor.execute(insert_sql)
                                    rows_loaded += len(batch_data)
                                    total_rows_loaded += len(batch_data)

                            self.log_very_verbose(f"Loaded {rows_loaded:,} rows from {file_path.name}")

                        table_stats[table_name_lower] = total_rows_loaded
                        load_time = time.time() - load_start
                        chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
                        self.logger.info(
                            f"✅ Loaded {total_rows_loaded:,} rows into {table_name_lower}{chunk_info} "
                            f"in {load_time:.2f}s"
                        )

                    except Exception as e:
                        self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
                        table_stats[table_name.lower()] = 0

            total_time = time.time() - start_time
            total_rows = sum(table_stats.values())
            self.logger.info(f"✅ Loaded {total_rows:,} total rows in {total_time:.2f}s")

        except Exception as e:
            self.logger.error(f"Data loading failed: {e}")
            raise
        finally:
            cursor.close()

        # Redshift doesn't provide detailed per-table timings yet
        return table_stats, total_time, None
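
    # The multi-file manifest uploaded above uses Redshift's standard COPY
    # manifest shape (URIs are hypothetical):
    #
    #     {
    #         "entries": [
    #             {"url": "s3://my-staging-bucket/benchbox-data/lineitem_0.tbl", "mandatory": true},
    #             {"url": "s3://my-staging-bucket/benchbox-data/lineitem_1.tbl", "mandatory": true}
    #         ]
    #     }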

    def configure_for_benchmark(self, connection: Any, benchmark_type: str) -> None:
        """Apply Redshift-specific optimizations based on benchmark type."""
        cursor = connection.cursor()

        try:
            # Set session-level optimizations
            # Use OFF for result cache to ensure accurate benchmark measurements
            cache_setting = "OFF" if self.disable_result_cache else "ON"
            optimization_settings = [
                f"SET enable_result_cache_for_session TO {cache_setting}",
                "SET query_group TO 'benchbox'",
                "SET statement_timeout TO '1800000'",  # 30 minutes
            ]

            if benchmark_type.lower() in ["olap", "analytics", "tpch", "tpcds"]:
                # OLAP-specific optimizations
                optimization_settings.extend(
                    [
                        "SET enable_case_sensitive_identifier TO OFF",
                        "SET datestyle TO 'ISO, MDY'",
                        "SET extra_float_digits TO 0",
                    ]
                )

            critical_failures = []
            for setting in optimization_settings:
                try:
                    cursor.execute(setting)
                    self.logger.debug(f"Applied setting: {setting}")
                except Exception as e:
                    # Track if critical cache control setting failed
                    if "enable_result_cache_for_session" in setting:
                        critical_failures.append(setting)
                    self.logger.warning(f"Failed to apply setting {setting}: {e}")

            # Validate cache control settings were successfully applied
            if self.disable_result_cache or critical_failures:
                self.logger.debug("Validating cache control settings...")
                validation_result = self.validate_session_cache_control(connection)
                if not validation_result["validated"]:
                    self.logger.warning(f"Cache control validation failed: {validation_result.get('errors', [])}")
                else:
                    self.logger.info(
                        f"Cache control validated successfully: cache_disabled={validation_result['cache_disabled']}"
                    )

            # Run VACUUM and ANALYZE on all tables if configured
            if self.auto_vacuum or self.auto_analyze:
                cursor.execute(f"""
                    SELECT schemaname, tablename
                    FROM pg_tables
                    WHERE schemaname = '{self.schema}'
                """)
                tables = cursor.fetchall()

                for _schema, table in tables:
                    if self.auto_vacuum:
                        try:
                            cursor.execute(f"VACUUM {table}")
                        except Exception as e:
                            self.logger.warning(f"VACUUM failed for {table}: {e}")

                    if self.auto_analyze:
                        try:
                            cursor.execute(f"ANALYZE {table}")
                        except Exception as e:
                            self.logger.warning(f"ANALYZE failed for {table}: {e}")

        finally:
            cursor.close()

    def validate_session_cache_control(self, connection: Any) -> dict[str, Any]:
        """Validate that session-level cache control settings were successfully applied.

        Args:
            connection: Active Redshift database connection

        Returns:
            dict with:
            - validated: bool - Whether validation passed
            - cache_disabled: bool - Whether cache is actually disabled
            - settings: dict - Actual session settings
            - warnings: list[str] - Any validation warnings
            - errors: list[str] - Any validation errors

        Raises:
            ConfigurationError: If cache control validation fails and strict_validation=True
        """
        cursor = connection.cursor()
        result = {
            "validated": False,
            "cache_disabled": False,
            "settings": {},
            "warnings": [],
            "errors": [],
        }

        try:
            # Query current session setting using current_setting() function
            cursor.execute("SELECT current_setting('enable_result_cache_for_session') as value")
            row = cursor.fetchone()

            if row:
                actual_value = str(row[0]).lower()
                result["settings"]["enable_result_cache_for_session"] = actual_value

                # Determine expected value based on configuration
                expected_value = "off" if self.disable_result_cache else "on"

                if actual_value == expected_value:
                    result["validated"] = True
                    result["cache_disabled"] = actual_value == "off"
                    self.logger.debug(
                        f"Cache control validated: enable_result_cache_for_session={actual_value} "
                        f"(expected {expected_value})"
                    )
                else:
                    error_msg = (
                        f"Cache control validation failed: "
                        f"expected enable_result_cache_for_session={expected_value}, "
                        f"got {actual_value}"
                    )
                    result["errors"].append(error_msg)
                    self.logger.error(error_msg)

                    # Raise error if strict validation mode enabled
                    if self.strict_validation:
                        raise ConfigurationError(
                            "Redshift session cache control validation failed - "
                            "benchmark results may be incorrect due to cached query results",
                            details=result,
                        )
            else:
                warning_msg = "Could not retrieve enable_result_cache_for_session parameter from session"
                result["warnings"].append(warning_msg)
                self.logger.warning(warning_msg)

        except Exception as e:
            # If this is our ConfigurationError, re-raise it
            if isinstance(e, ConfigurationError):
                raise

            # Otherwise log validation error
            error_msg = f"Validation query failed: {e}"
            result["errors"].append(error_msg)
            self.logger.error(f"Cache control validation error: {e}")

            # Raise if strict mode and query failed
            if self.strict_validation:
                raise ConfigurationError(
                    "Failed to validate Redshift cache control settings",
                    details={"original_error": str(e), "validation_result": result},
                ) from e
        finally:
            cursor.close()

        return result
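
    # The same check can be run by hand in a SQL session against the cluster:
    #
    #     SELECT current_setting('enable_result_cache_for_session');
    #     -- expected: "off" when disable_result_cache=True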

    def execute_query(
        self,
        connection: Any,
        query: str,
        query_id: str,
        benchmark_type: str | None = None,
        scale_factor: float | None = None,
        validate_row_count: bool = True,
        stream_id: int | None = None,
    ) -> dict[str, Any]:
        """Execute query with detailed timing and performance tracking."""
        start_time = time.time()
        cursor = connection.cursor()

        try:
            # Execute the query
            # Note: Query dialect translation is now handled automatically by the base adapter
            cursor.execute(query)
            result = cursor.fetchall()

            execution_time = time.time() - start_time
            actual_row_count = len(result) if result else 0

            # Get query statistics
            try:
                query_stats = self._get_query_statistics(connection, query_id)
                # Add execution time for cost calculation
                query_stats["execution_time_seconds"] = execution_time
            except Exception:
                query_stats = {"execution_time_seconds": execution_time}

            # Validate row count if enabled and benchmark type is provided
            validation_result = None
            if validate_row_count and benchmark_type:
                from benchbox.core.validation.query_validation import QueryValidator

                validator = QueryValidator()
                validation_result = validator.validate_query_result(
                    benchmark_type=benchmark_type,
                    query_id=query_id,
                    actual_row_count=actual_row_count,
                    scale_factor=scale_factor,
                    stream_id=stream_id,
                )

                # Log validation result
                if validation_result.warning_message:
                    self.log_verbose(f"Row count validation: {validation_result.warning_message}")
                elif not validation_result.is_valid:
                    self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
                else:
                    self.log_very_verbose(
                        f"Row count validation PASSED: {actual_row_count} rows "
                        f"(expected: {validation_result.expected_row_count})"
                    )

            # Use base helper to build result with consistent validation field mapping
            result_dict = self._build_query_result_with_validation(
                query_id=query_id,
                execution_time=execution_time,
                actual_row_count=actual_row_count,
                first_row=result[0] if result else None,
                validation_result=validation_result,
            )

            # Include Redshift-specific fields
            result_dict["translated_query"] = None  # Translation handled by base adapter
            result_dict["query_statistics"] = query_stats
            # Map query_statistics to resource_usage for cost calculation
            result_dict["resource_usage"] = query_stats

            return result_dict

        except Exception as e:
            execution_time = time.time() - start_time
            return {
                "query_id": query_id,
                "status": "FAILED",
                "execution_time": execution_time,
                "rows_returned": 0,
                "error": str(e),
                "error_type": type(e).__name__,
            }
        finally:
            cursor.close()

    def _extract_table_name(self, statement: str) -> str | None:
        """Extract table name from CREATE TABLE statement.

        Args:
            statement: CREATE TABLE SQL statement

        Returns:
            Table name or None if not found
        """
        try:
            # Simple extraction: find text between "CREATE TABLE" and "("
            import re

            match = re.search(r"CREATE\s+TABLE\s+([^\s(]+)", statement, re.IGNORECASE)
            if match:
                return match.group(1).strip()
        except Exception:
            pass
        return None

    def _normalize_table_name_in_sql(self, sql: str) -> str:
        """Normalize table names in SQL to lowercase for Redshift.

        Redshift converts unquoted identifiers to lowercase, so we normalize
        all table names to lowercase to ensure consistency across CREATE,
        COPY, and SELECT operations. This prevents case sensitivity issues.

        Args:
            sql: SQL statement

        Returns:
            SQL with normalized (lowercase) table names
        """
        import re

        # Match CREATE TABLE "TABLENAME" or CREATE TABLE TABLENAME
        # and convert to CREATE TABLE tablename (unquoted lowercase)
        sql = re.sub(
            r'CREATE\s+TABLE\s+"?([A-Za-z_][A-Za-z0-9_]*)"?',
            lambda m: f"CREATE TABLE {m.group(1).lower()}",
            sql,
            flags=re.IGNORECASE,
        )

        # Match foreign key references to quoted/unquoted table names
        # FOREIGN KEY ... REFERENCES "TABLENAME" → REFERENCES tablename
        sql = re.sub(
            r'REFERENCES\s+"?([A-Za-z_][A-Za-z0-9_]*)"?',
            lambda m: f"REFERENCES {m.group(1).lower()}",
            sql,
            flags=re.IGNORECASE,
        )

        return sql

    def _optimize_table_definition(self, statement: str) -> str:
        """Optimize table definition for Redshift."""
        if not statement.upper().startswith("CREATE TABLE"):
            return statement

        # Include distribution and sort keys for better performance
        if "DISTSTYLE" not in statement.upper():
            # Include AUTO distribution style (Redshift will choose appropriate distribution)
            statement += " DISTSTYLE AUTO"

        if "SORTKEY" not in statement.upper():
            # Include sort key on first column (simple heuristic)
            # In production, this would be more sophisticated
            statement += " SORTKEY AUTO"

        return statement

    def _get_platform_metadata(self, connection: Any) -> dict[str, Any]:
        """Get Redshift-specific metadata and system information."""
        metadata = {
            "platform": self.platform_name,
            "host": self.host,
            "port": self.port,
            "database": self.database,
            "result_cache_enabled": not self.disable_result_cache,
        }

        cursor = connection.cursor()
        try:
            # Get Redshift version
            cursor.execute("SELECT version()")
            result = cursor.fetchone()
            metadata["redshift_version"] = result[0] if result else "unknown"

            # Get cluster information
            cursor.execute("""
                SELECT node_type, num_nodes, cluster_version, publicly_accessible
                FROM stv_cluster_configuration
                LIMIT 1
            """)
            result = cursor.fetchone()
            if result:
                metadata["cluster_info"] = {
                    "node_type": result[0],
                    "num_nodes": result[1],
                    "cluster_version": result[2],
                    "publicly_accessible": result[3],
                }

            # Get current session information
            cursor.execute("""
                SELECT current_user, current_database(), current_schema(),
                       inet_client_addr(), inet_client_port()
            """)
            result = cursor.fetchone()
            if result:
                metadata["session_info"] = {
                    "current_user": result[0],
                    "current_database": result[1],
                    "current_schema": result[2],
                    "client_addr": result[3],
                    "client_port": result[4],
                }

            # Get table information
            cursor.execute(f"""
                SELECT schemaname, tablename, tableowner, tablespace,
                       hasindexes, hasrules, hastriggers
                FROM pg_tables
                WHERE schemaname = '{self.schema}'
            """)
            tables = cursor.fetchall()
            metadata["tables"] = [
                {
                    "schema": row[0],
                    "table": row[1],
                    "owner": row[2],
                    "tablespace": row[3],
                    "has_indexes": row[4],
                    "has_rules": row[5],
                    "has_triggers": row[6],
                }
                for row in tables
            ]

        except Exception as e:
            metadata["metadata_error"] = str(e)
        finally:
            cursor.close()

        return metadata

    def _get_existing_tables(self, connection: Any) -> list[str]:
        """Get list of existing tables from Redshift (normalized to lowercase).

        Queries information_schema for tables in the current schema and returns
        them as lowercase names to match Redshift's identifier normalization.

        Args:
            connection: Database connection

        Returns:
            List of table names (lowercase)
        """
        cursor = connection.cursor()
        try:
            # Query information_schema for tables in current schema
            cursor.execute(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = %s AND table_type = 'BASE TABLE'
                """,
                (self.schema,),
            )
            # Return lowercase table names
            return [row[0].lower() for row in cursor.fetchall()]
        finally:
            cursor.close()

    def _get_query_statistics(self, connection: Any, query_id: str) -> dict[str, Any]:
        """Get query statistics from Redshift system tables.

        Note: This retrieves performance telemetry only - not used for validation.
        Query row count validation uses the actual result set from cursor.fetchall().
        """
        cursor = connection.cursor()
        try:
            # Get the actual Redshift query ID for the most recent query
            # pg_last_query_id() returns the query ID of the last executed query in this session
            cursor.execute("SELECT pg_last_query_id()")
            result = cursor.fetchone()
            if not result or result[0] == -1:
                # No queries executed yet, or query ran only on leader node
                return {}

            redshift_query_id = result[0]

            # Query STL tables for query statistics using the actual Redshift query ID
            cursor.execute(
                """
                SELECT query,
                       DATEDIFF('microseconds', starttime, endtime) as duration_microsecs,
                       DATEDIFF('microseconds', starttime, endtime) as cpu_time_microsecs,
                       0 as bytes_scanned,
                       0 as bytes_returned,
                       1 as slots,
                       1 as wlm_slots,
                       aborted
                FROM stl_query
                WHERE query = %s
                ORDER BY starttime DESC
                LIMIT 1
                """,
                (redshift_query_id,),
            )
            result = cursor.fetchone()

            if result:
                return {
                    "query_id": str(result[0]),
                    "duration_microsecs": result[1] or 0,
                    "cpu_time_microsecs": result[2] or 0,
                    "bytes_scanned": result[3] or 0,
                    "bytes_returned": result[4] or 0,
                    "slots": result[5] or 1,
                    "wlm_slots": result[6] or 1,
                    "aborted": bool(result[7]) if result[7] is not None else False,
                }
            else:
                return {}
        except Exception:
            return {}
        finally:
            cursor.close()
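
    # What _optimize_table_definition() produces when a CREATE TABLE carries
    # no explicit tuning clauses (table and columns are hypothetical):
    #
    #     CREATE TABLE nation (n_nationkey INT, n_name CHAR(25)) DISTSTYLE AUTO SORTKEY AUTO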
[docs]
    def analyze_table(self, connection: Any, table_name: str) -> None:
        """Run ANALYZE on table for query optimization."""
        cursor = connection.cursor()
        try:
            cursor.execute(f"ANALYZE {table_name.lower()}")
        except Exception as e:
            self.logger.warning(f"Failed to analyze table {table_name}: {e}")
        finally:
            cursor.close()
[docs]
    def vacuum_table(self, connection: Any, table_name: str) -> None:
        """Run VACUUM on table for space reclamation."""
        cursor = connection.cursor()
        try:
            cursor.execute(f"VACUUM {table_name.lower()}")
        except Exception as e:
            self.logger.warning(f"Failed to vacuum table {table_name}: {e}")
        finally:
            cursor.close()
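    # Illustrative maintenance pass (hypothetical "adapter" and "conn"): after
    # a bulk COPY load, refresh planner statistics and re-sort/reclaim storage:
    #
    #   adapter.analyze_table(conn, "lineitem")  # update table statistics
    #   adapter.vacuum_table(conn, "lineitem")   # reclaim space and re-sort rows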
[docs]
    def get_query_plan(self, connection: Any, query: str) -> str:
        """Get query execution plan for analysis."""
        cursor = connection.cursor()
        try:
            # Note: Query dialect translation is now handled automatically by the base adapter
            cursor.execute(f"EXPLAIN {query}")
            plan_rows = cursor.fetchall()
            return "\n".join([row[0] for row in plan_rows])
        except Exception as e:
            return f"Could not get query plan: {e}"
        finally:
            cursor.close()
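    # Illustrative usage (hypothetical query): Redshift returns one plan row
    # per step, which get_query_plan joins into a single string, e.g.:
    #
    #   print(adapter.get_query_plan(conn, "SELECT count(*) FROM lineitem"))
    #   # XN Aggregate  (cost=... rows=1 width=0)
    #   #   ->  XN Seq Scan on lineitem  (cost=... rows=... width=0)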
[docs]
    def close_connection(self, connection: Any) -> None:
        """Close Redshift connection."""
        try:
            if connection and hasattr(connection, "close"):
                connection.close()
        except Exception as e:
            self.logger.warning(f"Error closing connection: {e}")
[docs]
    def supports_tuning_type(self, tuning_type) -> bool:
        """Check if Redshift supports a specific tuning type.

        Redshift supports:
        - DISTRIBUTION: Via DISTSTYLE and DISTKEY clauses
        - SORTING: Via SORTKEY clause (compound and interleaved)
        - PARTITIONING: Through table design patterns and date partitioning

        Args:
            tuning_type: The type of tuning to check support for

        Returns:
            True if the tuning type is supported by Redshift
        """
        # Import here to avoid circular imports
        try:
            from benchbox.core.tuning.interface import TuningType

            return tuning_type in {
                TuningType.DISTRIBUTION,
                TuningType.SORTING,
                TuningType.PARTITIONING,
            }
        except ImportError:
            return False
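    # Illustrative gate (sketch; "adapter", "ddl", and "table_tuning" are
    # hypothetical): callers can check support before generating DDL clauses:
    #
    #   from benchbox.core.tuning.interface import TuningType
    #   if adapter.supports_tuning_type(TuningType.SORTING):
    #       ddl += " " + adapter.generate_tuning_clause(table_tuning)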
[docs]
    def generate_tuning_clause(self, table_tuning) -> str:
        """Generate Redshift-specific tuning clauses for CREATE TABLE statements.

        Redshift supports:
        - DISTSTYLE (EVEN | KEY | ALL) DISTKEY (column)
        - SORTKEY (column1, column2, ...) or INTERLEAVED SORTKEY (column1, column2, ...)

        Args:
            table_tuning: The tuning configuration for the table

        Returns:
            SQL clause string to be appended to CREATE TABLE statement
        """
        if not table_tuning or not table_tuning.has_any_tuning():
            return ""

        clauses = []

        try:
            # Import here to avoid circular imports
            from benchbox.core.tuning.interface import TuningType

            # Handle distribution strategy
            distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
            if distribution_columns:
                # Sort by order and use first column as distribution key
                sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
                dist_col = sorted_cols[0]
                # Use KEY distribution style with the specified column
                clauses.append("DISTSTYLE KEY")
                clauses.append(f"DISTKEY ({dist_col.name})")
            else:
                # Default to EVEN distribution if no distribution columns specified
                clauses.append("DISTSTYLE EVEN")

            # Handle sorting
            sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
            if sort_columns:
                # Sort by order for sortkey
                sorted_cols = sorted(sort_columns, key=lambda col: col.order)
                column_names = [col.name for col in sorted_cols]
                # Use compound sort key by default (better for most OLAP workloads).
                # Could be made configurable to choose between COMPOUND and INTERLEAVED.
                sortkey_clause = f"SORTKEY ({', '.join(column_names)})"
                clauses.append(sortkey_clause)

            # Handle partitioning through table naming/organization
            # (logged but not added to the CREATE TABLE statement)
            partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
            if partition_columns:
                # Redshift partitioning is typically handled through table design
                # patterns; we log the strategy but add no SQL clauses
                pass

            # Clustering not directly supported in Redshift CREATE TABLE
        except ImportError:
            # If tuning interface not available, return empty string
            pass

        return " ".join(clauses)
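    # Illustrative output (hypothetical tuning config): a table tuned with
    # DISTRIBUTION on l_orderkey and SORTING on (l_shipdate, l_orderkey) yields
    # the clause below, appended verbatim to the CREATE TABLE statement:
    #
    #   DISTSTYLE KEY DISTKEY (l_orderkey) SORTKEY (l_shipdate, l_orderkey)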
[docs]
    def apply_table_tunings(self, table_tuning, connection: Any) -> None:
        """Apply tuning configurations to a Redshift table.

        Redshift tuning approach:
        - DISTRIBUTION: Handled via DISTSTYLE/DISTKEY in CREATE TABLE
        - SORTING: Handled via SORTKEY in CREATE TABLE
        - Post-creation optimizations via ANALYZE and VACUUM

        Args:
            table_tuning: The tuning configuration to apply
            connection: Redshift connection

        Raises:
            ValueError: If the tuning configuration is invalid for Redshift
        """
        if not table_tuning or not table_tuning.has_any_tuning():
            return

        table_name = table_tuning.table_name.lower()
        self.logger.info(f"Applying Redshift tunings for table: {table_name}")

        cursor = connection.cursor()
        try:
            # Import here to avoid circular imports
            from benchbox.core.tuning.interface import TuningType

            # Redshift tuning is primarily handled at table creation time;
            # post-creation optimizations are limited.

            # Verify table exists and get current configuration. pg_table_def
            # exposes per-column distkey/sortkey flags (and, unlike
            # svv_table_info, also lists empty tables); note it only covers
            # schemas on the current search_path. A table with a distkey
            # column implicitly has DISTSTYLE KEY.
            cursor.execute(
                """
                SELECT "column", distkey, sortkey
                FROM pg_table_def
                WHERE schemaname = %s AND tablename = %s
                """,
                (self.schema, table_name),
            )
            columns = cursor.fetchall()
            if columns:
                current_distkey = next((row[0] for row in columns if row[1]), None)
                current_diststyle = "KEY" if current_distkey else None
                # Positive sortkey values give the column's position in a
                # compound sort key; 0 means the column is not a sort key
                sort_cols = sorted((row[2], row[0]) for row in columns if row[2] and row[2] > 0)
                current_sortkeys = [name for _, name in sort_cols]

                self.logger.info(f"Current configuration for {table_name}:")
                self.logger.info(f"  Distribution style: {current_diststyle}")
                self.logger.info(f"  Distribution key: {current_distkey}")
                self.logger.info(f"  Sort keys: {current_sortkeys}")

                # Check if configuration matches desired tuning
                distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
                sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)

                needs_recreation = False

                # Check distribution configuration
                if distribution_columns:
                    sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
                    desired_distkey = sorted_cols[0].name
                    if current_distkey != desired_distkey or current_diststyle != "KEY":
                        needs_recreation = True
                        self.logger.info(
                            f"Distribution key mismatch: current='{current_distkey}', desired='{desired_distkey}'"
                        )

                # Check sort key configuration
                if sort_columns:
                    sorted_cols = sorted(sort_columns, key=lambda col: col.order)
                    desired_sortkeys = [col.name for col in sorted_cols]
                    if current_sortkeys != desired_sortkeys:
                        needs_recreation = True
                        self.logger.info(f"Sort keys mismatch: current={current_sortkeys}, desired={desired_sortkeys}")

                if needs_recreation:
                    self.logger.warning(
                        f"Table {table_name} configuration differs from desired tuning. "
                        "Changing distribution/sort keys requires ALTER TABLE ... "
                        "ALTER DISTKEY/SORTKEY or table recreation."
                    )
            else:
                self.logger.warning(f"Could not find table configuration for {table_name}")

            # Perform maintenance operations that can help with performance
            try:
                # Run ANALYZE to update table statistics
                cursor.execute(f"ANALYZE {table_name}")
                self.logger.info(f"Analyzed table statistics for {table_name}")

                # Run VACUUM to reclaim space and re-sort data
                if self.auto_vacuum:
                    cursor.execute(f"VACUUM {table_name}")
                    self.logger.info(f"Vacuumed table {table_name}")
            except Exception as e:
                self.logger.warning(f"Failed to perform maintenance operations on {table_name}: {e}")

            # Handle partitioning strategy
            partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
            if partition_columns:
                sorted_cols = sorted(partition_columns, key=lambda col: col.order)
                column_names = [col.name for col in sorted_cols]
                self.logger.info(
                    f"Partitioning strategy for {table_name}: {', '.join(column_names)} "
                    "(handled via table design patterns)"
                )

            # Clustering not directly supported in Redshift
            cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
            if cluster_columns:
                sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
                column_names = [col.name for col in sorted_cols]
                self.logger.info(
                    f"Clustering strategy for {table_name} achieved via sort keys: {', '.join(column_names)}"
                )

        except ImportError:
            self.logger.warning("Tuning interface not available - skipping tuning application")
        except Exception as e:
            raise ValueError(f"Failed to apply tunings to Redshift table {table_name}: {e}") from e
        finally:
            cursor.close()
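    # A minimal sketch (assuming the cluster supports these ALTER TABLE forms;
    # table and column names are hypothetical) of changing keys in place
    # instead of recreating the table:
    #
    #   cursor.execute("ALTER TABLE lineitem ALTER DISTSTYLE KEY DISTKEY l_orderkey")
    #   cursor.execute("ALTER TABLE lineitem ALTER COMPOUND SORTKEY (l_shipdate, l_orderkey)")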
[docs]
    def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None:
        """Apply unified tuning configuration to Redshift.

        Args:
            unified_config: Unified tuning configuration to apply
            connection: Redshift connection
        """
        if not unified_config:
            return

        # Apply constraint configurations
        self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection)

        # Apply platform optimizations
        if unified_config.platform_optimizations:
            self.apply_platform_optimizations(unified_config.platform_optimizations, connection)

        # Apply table-level tunings
        for _table_name, table_tuning in unified_config.table_tunings.items():
            self.apply_table_tunings(table_tuning, connection)
[docs]
    def apply_platform_optimizations(
        self, platform_config: PlatformOptimizationConfiguration, connection: Any
    ) -> None:
        """Apply Redshift-specific platform optimizations.

        Redshift optimizations include:
        - Workload Management (WLM) queue configuration
        - Query group settings for resource allocation
        - Compression encoding optimization
        - Statistics collection and maintenance

        Args:
            platform_config: Platform optimization configuration
            connection: Redshift connection
        """
        if not platform_config:
            return

        # Redshift optimizations are typically applied at session or workload level.
        # Store optimizations for use during query execution and maintenance operations.
        self.logger.info("Redshift platform optimizations stored for session and workload management")
[docs]
    def apply_constraint_configuration(
        self,
        primary_key_config: PrimaryKeyConfiguration,
        foreign_key_config: ForeignKeyConfiguration,
        connection: Any,
    ) -> None:
        """Apply constraint configurations to Redshift.

        Note: Redshift supports PRIMARY KEY and FOREIGN KEY constraints for
        query optimization, but they are informational only (not enforced).
        Constraints must be applied at table creation time.

        Args:
            primary_key_config: Primary key constraint configuration
            foreign_key_config: Foreign key constraint configuration
            connection: Redshift connection
        """
        # Redshift constraints are applied at table creation time for query
        # optimization. This method is called after tables are created, so
        # log the configurations.
        if primary_key_config and primary_key_config.enabled:
            self.logger.info(
                "Primary key constraints enabled for Redshift (informational only, applied during table creation)"
            )

        if foreign_key_config and foreign_key_config.enabled:
            self.logger.info(
                "Foreign key constraints enabled for Redshift (informational only, applied during table creation)"
            )
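    # Illustrative DDL (hypothetical tables): Redshift accepts the constraints
    # below and uses them for query planning, but does not enforce them; loads
    # must guarantee uniqueness and referential integrity themselves.
    #
    #   CREATE TABLE orders (
    #       o_orderkey BIGINT PRIMARY KEY,
    #       o_custkey  BIGINT REFERENCES customer (c_custkey)
    #   );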
        # Redshift constraints are informational and used for query optimization.
        # No additional work to do here as they're applied during CREATE TABLE.


def _build_redshift_config(
    platform: str,
    options: dict[str, Any],
    overrides: dict[str, Any],
    info: Any,
) -> Any:
    """Build Redshift database configuration with credential loading.

    This function loads saved credentials from the CredentialManager and
    merges them with CLI options and runtime overrides.

    Args:
        platform: Platform name (should be 'redshift')
        options: CLI platform options from --platform-option flags
        overrides: Runtime overrides from orchestrator
        info: Platform info from registry

    Returns:
        DatabaseConfig with credentials loaded
    """
    from benchbox.core.config import DatabaseConfig
    from benchbox.security.credentials import CredentialManager

    # Load saved credentials
    cred_manager = CredentialManager()
    saved_creds = cred_manager.get_platform_credentials("redshift") or {}

    # Build merged options; precedence: saved_creds < options < overrides
    merged_options = {}
    merged_options.update(saved_creds)
    merged_options.update(options)
    merged_options.update(overrides)

    # Extract display and driver fields for DatabaseConfig
    name = info.display_name if info else "Amazon Redshift"
    driver_package = info.driver_package if info else "redshift-connector"

    # Build config dict with platform-specific fields at top-level.
    # This allows RedshiftAdapter.__init__() and from_config() to access
    # them via config.get().
    config_dict = {
        "type": "redshift",
        "name": name,
        "options": merged_options or {},  # Ensure options is never None (Pydantic v2 uses None if explicitly passed)
        "driver_package": driver_package,
        "driver_version": overrides.get("driver_version") or options.get("driver_version"),
        "driver_auto_install": bool(overrides.get("driver_auto_install", options.get("driver_auto_install", False))),
        # Platform-specific fields at top-level (adapters expect these here)
        "host": merged_options.get("host"),
        "port": merged_options.get("port"),
        # NOTE: database is NOT included here - from_config() generates it from benchmark context.
        # Only explicit overrides (via --platform-option database=...) should bypass generation.
        "username": merged_options.get("username"),
        "password": merged_options.get("password"),
        "schema": merged_options.get("schema"),
        # S3 and AWS configuration
        "s3_bucket": merged_options.get("s3_bucket"),
        "s3_prefix": merged_options.get("s3_prefix"),
        "iam_role": merged_options.get("iam_role"),
        "aws_access_key_id": merged_options.get("aws_access_key_id"),
        "aws_secret_access_key": merged_options.get("aws_secret_access_key"),
        "aws_region": merged_options.get("aws_region"),
        # Optional settings
        "cluster_identifier": merged_options.get("cluster_identifier"),
        "admin_database": merged_options.get("admin_database", "dev"),
        "connect_timeout": merged_options.get("connect_timeout"),
        "statement_timeout": merged_options.get("statement_timeout"),
        "sslmode": merged_options.get("sslmode"),
        # Benchmark context for config-aware database naming (from overrides)
        "benchmark": overrides.get("benchmark"),
        "scale_factor": overrides.get("scale_factor"),
        "tuning_config": overrides.get("tuning_config"),
    }

    # Only include an explicit database override if provided via CLI or overrides;
    # saved credentials should NOT override generated database names
    if "database" in overrides and overrides["database"]:
        config_dict["database"] = overrides["database"]

    return DatabaseConfig(**config_dict)


# Register the config builder with the platform hook registry.
# This must happen when the module is imported.
try:
    from benchbox.cli.platform_hooks import PlatformHookRegistry

    PlatformHookRegistry.register_config_builder("redshift", _build_redshift_config)
except ImportError:
    # Platform hooks may not be available in all contexts (e.g., core-only usage)
    pass
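# Illustrative precedence (hypothetical values): with saved credentials
# {"host": "a.example.com", "port": 5439}, CLI options {"host": "b.example.com"},
# and no overrides, _build_redshift_config resolves merged options to
# {"host": "b.example.com", "port": 5439}; later sources win on conflicts.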