"""Snowflake platform adapter with native cloud data warehouse optimizations.
Provides Snowflake-specific optimizations for cloud-native analytics,
including multi-cluster warehouse support and automatic scaling.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from benchbox.core.tuning.interface import (
ForeignKeyConfiguration,
PlatformOptimizationConfiguration,
PrimaryKeyConfiguration,
UnifiedTuningConfiguration,
)
from ..core.exceptions import ConfigurationError
from ..utils.dependencies import check_platform_dependencies, get_dependency_error_message
from .base import PlatformAdapter
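# The Snowflake connector is an optional dependency. If it is missing, the fallback
# values below let this module import cleanly; __init__ then raises a clearer
# dependency error via check_platform_dependencies().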
try:
import snowflake.connector
from snowflake.connector import DictCursor
from snowflake.connector.errors import Error as SnowflakeError
except ImportError:
snowflake = None
DictCursor = None
SnowflakeError = Exception
class SnowflakeAdapter(PlatformAdapter):
"""Snowflake platform adapter with cloud data warehouse optimizations."""
def __init__(self, **config):
super().__init__(**config)
# Check dependencies with improved error message
available, missing = check_platform_dependencies("snowflake")
if not available:
error_msg = get_dependency_error_message("snowflake", missing)
raise ImportError(error_msg)
self._dialect = "snowflake"
# Snowflake connection configuration
self.account = config.get("account")
self.warehouse = config.get("warehouse") or "COMPUTE_WH"
self.database = config.get("database") or "BENCHBOX"
self.schema = config.get("schema") or "PUBLIC"
self.username = config.get("username")
self.password = config.get("password")
self.role = config.get("role")
# Authentication options
self.authenticator = config.get("authenticator") or "snowflake" # snowflake, oauth, etc.
self.private_key_path = config.get("private_key_path")
self.private_key_passphrase = config.get("private_key_passphrase")
# Warehouse settings
self.warehouse_size = config.get("warehouse_size") or "MEDIUM"
self.auto_suspend = config.get("auto_suspend") if config.get("auto_suspend") is not None else 300 # seconds
self.auto_resume = config.get("auto_resume") if config.get("auto_resume") is not None else True
self.multi_cluster_warehouse = (
config.get("multi_cluster_warehouse") if config.get("multi_cluster_warehouse") is not None else False
)
# Session settings
self.query_tag = config.get("query_tag") or "BenchBox"
self.timezone = config.get("timezone") or "UTC"
# Result cache control - disable by default for accurate benchmarking
self.disable_result_cache = config.get("disable_result_cache", True)
# Validation strictness - raise errors if cache control validation fails
self.strict_validation = config.get("strict_validation", True)
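# When disable_result_cache is True, the relevant session parameter is USE_CACHED_RESULT
# (checked in validate_session_cache_control below); disabling it is typically done with:
#     ALTER SESSION SET USE_CACHED_RESULT = FALSE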
# Nondeterministic error suppression - disabled by default to preserve Snowflake's
# default behavior of erroring on nondeterministic MERGE/UPDATE operations.
# Enable this for workloads that intentionally use nondeterministic operations.
self.suppress_nondeterministic_errors = config.get("suppress_nondeterministic_errors", False)
# Warehouse modification control - when True, BenchBox will ALTER WAREHOUSE settings
# (size, auto-suspend, scaling policy). These changes PERSIST beyond the benchmark run.
# Default is False to avoid unexpected infrastructure changes in governed environments.
# Set to True explicitly if you want BenchBox to configure your warehouse for benchmarking.
self.modify_warehouse_settings = config.get("modify_warehouse_settings", False)
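# When enabled, the warehouse changes are plain ALTER WAREHOUSE statements, for example
# (illustrative, using the defaults configured above):
#     ALTER WAREHOUSE COMPUTE_WH SET WAREHOUSE_SIZE = 'MEDIUM' AUTO_SUSPEND = 300 AUTO_RESUME = TRUE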
# File format settings
self.file_format = config.get("file_format") or "CSV"
self.compression = config.get("compression") or "AUTO"
# Cloud storage staging (optional - Snowflake uses internal stages by default)
# staging_root is passed by orchestrator when using CloudStagingPath
# For now, we log it but continue using internal stages (which work with local files)
staging_root = config.get("staging_root")
if staging_root:
from benchbox.utils.cloud_storage import get_cloud_path_info
path_info = get_cloud_path_info(staging_root)
self.logger.info(
f"Note: staging_root provided ({path_info['provider']}://{path_info['bucket']}), "
"but Snowflake adapter uses internal stages for data loading"
)
if not all([self.account, self.username, self.password, self.warehouse, self.database]):
missing = []
if not self.account:
missing.append("account (or SNOWFLAKE_ACCOUNT)")
if not self.username:
missing.append("username (or SNOWFLAKE_USER)")
if not self.password:
missing.append("password (or SNOWFLAKE_PASSWORD)")
if not self.warehouse:
missing.append("warehouse (or SNOWFLAKE_WAREHOUSE)")
if not self.database:
missing.append("database (or SNOWFLAKE_DATABASE)")
raise ConfigurationError(
f"Snowflake configuration is incomplete. Missing: {', '.join(missing)}\n"
"Configure with one of:\n"
" 1. CLI: benchbox platforms setup --platform snowflake\n"
" 2. Environment variables: SNOWFLAKE_ACCOUNT, SNOWFLAKE_USER, SNOWFLAKE_PASSWORD, etc.\n"
" 3. CLI options: --platform-option account=<account> --platform-option warehouse=<wh>"
)
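# Example environment-variable configuration matching the error message above
# (values are placeholders):
#     export SNOWFLAKE_ACCOUNT=xy12345.us-east-1
#     export SNOWFLAKE_USER=BENCH_USER
#     export SNOWFLAKE_PASSWORD=***
#     export SNOWFLAKE_WAREHOUSE=BENCH_WH
#     export SNOWFLAKE_DATABASE=BENCHBOX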
@property
def platform_name(self) -> str:
return "Snowflake"
@staticmethod
def add_cli_arguments(parser) -> None:
"""Add Snowflake-specific CLI arguments."""
sf_group = parser.add_argument_group("Snowflake Arguments")
sf_group.add_argument("--account", type=str, help="Snowflake account identifier")
sf_group.add_argument("--warehouse", type=str, default="COMPUTE_WH", help="Warehouse name")
sf_group.add_argument(
"--database", type=str, default=None, help="Database name (auto-generated if not specified)"
)
sf_group.add_argument("--schema", type=str, default="PUBLIC", help="Schema name")
sf_group.add_argument("--username", type=str, help="User name")
sf_group.add_argument("--password", type=str, help="User password")
sf_group.add_argument("--role", type=str, help="Role to assume for the session")
sf_group.add_argument("--authenticator", type=str, default="snowflake", help="Authentication method")
sf_group.add_argument("--private-key-path", type=str, help="Path to private key for key pair auth")
# Behavior control options
sf_group.add_argument(
"--modify-warehouse-settings",
action="store_true",
default=False,
help="Modify warehouse settings (size, auto-suspend, scaling). PERSISTENT changes - use with caution.",
)
sf_group.add_argument(
"--suppress-nondeterministic-errors",
action="store_true",
default=False,
help="Suppress errors on nondeterministic MERGE/UPDATE operations",
)
sf_group.add_argument(
"--no-disable-result-cache",
action="store_false",
dest="disable_result_cache",
help="Enable result cache (disabled by default for accurate benchmarking)",
)
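# Illustrative CLI configuration, using the option forms shown in the ConfigurationError
# message above (exact command wiring depends on the surrounding CLI; values are placeholders):
#     benchbox platforms setup --platform snowflake
#     benchbox ... --platform-option account=xy12345 --platform-option warehouse=BENCH_WH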
@classmethod
def from_config(cls, config: dict[str, Any]):
"""Create Snowflake adapter from unified configuration."""
from benchbox.utils.database_naming import generate_database_name
adapter_config: dict[str, Any] = {}
# Generate proper database name using benchmark characteristics
# (unless explicitly overridden in config)
if "database" in config and config["database"]:
# User explicitly provided database name - use it
adapter_config["database"] = config["database"]
else:
# Generate configuration-aware database name
database_name = generate_database_name(
benchmark_name=config["benchmark"],
scale_factor=config["scale_factor"],
platform="snowflake",
tuning_config=config.get("tuning_config"),
)
adapter_config["database"] = database_name
# Copy other config keys
for key in [
"account",
"warehouse",
# "database" - handled above with generation logic
"schema",
"username",
"password",
"role",
"authenticator",
"private_key_path",
"private_key_passphrase",
"warehouse_size",
"auto_suspend",
"auto_resume",
"multi_cluster_warehouse",
"query_tag",
"timezone",
"file_format",
"compression",
# Behavior control options
"disable_result_cache",
"strict_validation",
"suppress_nondeterministic_errors",
"modify_warehouse_settings",
]:
if key in config:
adapter_config[key] = config[key]
return cls(**adapter_config)
def get_target_dialect(self) -> str:
"""Return the target SQL dialect for Snowflake."""
return "snowflake"
def _get_connection_params(self, **connection_config) -> dict[str, Any]:
"""Get standardized connection parameters."""
return {
"account": connection_config.get("account", self.account),
"username": connection_config.get("username", self.username),
"password": connection_config.get("password", self.password),
"warehouse": connection_config.get("warehouse", self.warehouse),
"role": connection_config.get("role", self.role),
}
def _create_admin_connection(self, **connection_config) -> Any:
"""Create Snowflake connection for admin operations."""
params = self._get_connection_params(**connection_config)
return snowflake.connector.connect(
**params,
client_session_keep_alive=True,
login_timeout=30,
network_timeout=60,
# Don't specify database for admin operations
)
def check_server_database_exists(self, **connection_config) -> bool:
"""Check if database exists in Snowflake account.
Also checks for existing schemas and tables, since they may exist from a
previous run even if the database doesn't formally exist at account level.
"""
try:
connection = self._create_admin_connection(**connection_config)
cursor = connection.cursor()
database = connection_config.get("database", self.database)
schema = connection_config.get("schema", self.schema)
# Check if database exists at account level
cursor.execute("SHOW DATABASES")
databases = [row[1] for row in cursor.fetchall()] # Database name is in column 1
if database.upper() in [db.upper() for db in databases]:
return True
# Even if database doesn't formally exist, check if schema/tables exist
# (they might exist from previous run where database/schema were created)
try:
# Quote identifiers and escape LIKE patterns for SQL safety
cursor.execute(f'USE DATABASE "{database}"')
schema_escaped = schema.replace("'", "''")
cursor.execute(f"SHOW SCHEMAS LIKE '{schema_escaped}'")
schemas = cursor.fetchall()
if schemas:
# Schema exists - check for tables
cursor.execute(f'USE SCHEMA "{schema}"')
cursor.execute("SHOW TABLES")
tables = cursor.fetchall()
if tables:
# Tables exist - database should be considered as existing
return True
except Exception:
# Database or schema does not exist
pass
return False
except Exception:
# If we can't connect or check, assume database doesn't exist
return False
finally:
if "connection" in locals() and connection:
connection.close()
def drop_database(self, **connection_config) -> None:
"""Drop database in Snowflake account."""
# Resolve the database name before connecting so the error message below can
# reference it even if the connection attempt fails.
database = connection_config.get("database", self.database)
try:
connection = self._create_admin_connection(**connection_config)
cursor = connection.cursor()
# Drop database and all its schemas/tables (quote identifier for SQL safety)
cursor.execute(f'DROP DATABASE IF EXISTS "{database}"')
except Exception as e:
raise RuntimeError(f"Failed to drop Snowflake database {database}: {e}")
finally:
if "connection" in locals() and connection:
connection.close()
def create_connection(self, **connection_config) -> Any:
"""Create optimized Snowflake connection."""
self.log_operation_start("Snowflake connection")
self.log_verbose("Creating Snowflake connection")
# Handle existing database using base class method
self.handle_existing_database(**connection_config)
# Get connection parameters
params = self._get_connection_params(**connection_config)
self.log_very_verbose(
f"Snowflake connection params: account={params.get('account')}, database={connection_config.get('database', self.database)}"
)
database = connection_config.get("database", self.database)
schema = connection_config.get("schema", self.schema)
try:
# Prepare connection parameters
self.log_verbose(f"Connecting to Snowflake account: {params['account']}")
conn_params = {
"account": params["account"],
"user": params["username"], # Snowflake uses 'user' not 'username'
"password": params["password"],
"warehouse": params["warehouse"],
"database": database,
"schema": schema,
"application": "BenchBox",
"timezone": self.timezone,
"autocommit": True,
}
if params["role"]:
conn_params["role"] = params["role"]
self.log_very_verbose(f"Using role: {params['role']}")
# Handle different authentication methods
if self.authenticator != "snowflake":
self.log_very_verbose(f"Using authenticator: {self.authenticator}")
conn_params["authenticator"] = self.authenticator
if self.private_key_path:
# Key pair authentication
self.log_verbose("Using key pair authentication")
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import (
load_pem_private_key,
)
with open(self.private_key_path, "rb") as key_file:
private_key = load_pem_private_key(
key_file.read(),
password=self.private_key_passphrase.encode() if self.private_key_passphrase else None,
)
pkb = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)
conn_params["private_key"] = pkb
del conn_params["password"] # Exclude password when using key pair
# Create connection
connection = snowflake.connector.connect(**conn_params)
# Test connection
self.log_verbose("Testing Snowflake connection")
cursor = connection.cursor()
cursor.execute("SELECT CURRENT_VERSION()")
cursor.fetchall()
cursor.close()
self.log_verbose(f"Connected to Snowflake account: {self.account}")
self.log_verbose(f"Using warehouse: {self.warehouse}, database: {self.database}, schema: {self.schema}")
self.log_operation_complete(
"Snowflake connection", details=f"Connected to account {self.account}, database: {self.database}"
)
return connection
except Exception as e:
self.logger.error(f"Failed to connect to Snowflake: {e}")
raise
def _should_skip_schema_creation(self, benchmark, connection: Any) -> bool:
"""Check if schema already exists with data, allowing us to skip recreation.
This prevents dropping/recreating tables which would:
1. Remove internal stages (@%TABLE)
2. Delete uploaded files
3. Force expensive re-uploads
Args:
benchmark: Benchmark instance
connection: Snowflake connection
Returns:
True if all expected tables exist with data, False otherwise
"""
try:
cursor = connection.cursor()
# Get expected tables from benchmark
expected_tables = self._get_expected_tables(benchmark)
if not expected_tables:
return False # Can't determine, recreate to be safe
# Check each table exists and has data
for table_name in expected_tables:
table_upper = table_name.upper()
# Check if table exists
cursor.execute(f"SHOW TABLES LIKE '{table_upper}'")
if not cursor.fetchone():
self.log_verbose(f"Table {table_upper} missing - schema creation required")
return False
# Check if table has data
cursor.execute(f"SELECT COUNT(*) FROM {table_upper}")
row_count = cursor.fetchone()[0]
if row_count == 0:
self.log_verbose(f"Table {table_upper} empty - schema creation required")
return False
self.log_verbose(f"All {len(expected_tables)} tables exist with data - skipping schema creation")
return True
except Exception as e:
self.log_very_verbose(f"Schema check failed: {e} - proceeding with creation")
return False # If check fails, recreate to be safe
def create_schema(self, benchmark, connection: Any) -> float:
"""Create schema using Snowflake table definitions."""
self.log_operation_start("Snowflake schema creation")
start_time = time.time()
self.log_verbose(f"Creating schema for benchmark: {benchmark.__class__.__name__}")
self.log_very_verbose(f"Target database: {self.database}, schema: {self.schema}")
cursor = connection.cursor()
try:
# Ensure database and schema exist
self.log_verbose(f"Creating/using database: {self.database}")
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {self.database}")
cursor.execute(f"USE DATABASE {self.database}")
self.log_verbose(f"Creating/using schema: {self.schema}")
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {self.schema}")
cursor.execute(f"USE SCHEMA {self.schema}")
# Check if we can skip table creation (tables exist with data)
if self._should_skip_schema_creation(benchmark, connection):
elapsed_time = time.time() - start_time
self.log_operation_complete(
"Snowflake schema creation", details=f"Skipped (existing data is valid) in {elapsed_time:.2f}s"
)
return elapsed_time
# Set query tag for tracking
self.log_very_verbose(f"Setting query tag: {self.query_tag}_schema_creation")
cursor.execute(f"ALTER SESSION SET QUERY_TAG = '{self.query_tag}_schema_creation'")
# Use common schema creation helper
self.log_very_verbose("Retrieving schema SQL from benchmark")
schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")
# Split schema into individual statements and execute
statements = [stmt.strip() for stmt in schema_sql.split(";") if stmt.strip()]
self.log_verbose(f"Executing {len(statements)} schema statements")
for i, statement in enumerate(statements, 1):
# Optimize table definition for Snowflake
optimized_statement = self._optimize_table_definition(statement)
self.log_very_verbose(f"Executing statement {i}/{len(statements)}: {optimized_statement[:100]}...")
cursor.execute(optimized_statement)
self.log_verbose("Schema created successfully")
except Exception as e:
self.logger.error(f"Schema creation failed: {e}")
raise
finally:
cursor.close()
elapsed_time = time.time() - start_time
self.log_operation_complete("Snowflake schema creation", details=f"Completed in {elapsed_time:.2f}s")
return elapsed_time
def load_data(
self, benchmark, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
"""Load data using Snowflake PUT and COPY INTO commands."""
self.log_operation_start("Snowflake data loading")
self.log_verbose(f"Starting data loading for benchmark: {benchmark.__class__.__name__}")
self.log_very_verbose(f"Data directory: {data_dir}")
start_time = time.time()
table_stats = {}
cursor = connection.cursor()
try:
# Set query tag for tracking
self.log_very_verbose(f"Setting query tag: {self.query_tag}_data_loading")
cursor.execute(f"ALTER SESSION SET QUERY_TAG = '{self.query_tag}_data_loading'")
# Create file format for CSV loading if needed
self.log_verbose("Creating file formats for data loading")
cursor.execute(f"""
CREATE OR REPLACE FILE FORMAT {self.schema}.BENCHBOX_CSV_FORMAT
TYPE = 'CSV'
FIELD_DELIMITER = ','
RECORD_DELIMITER = '\\n'
SKIP_HEADER = 0
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
REPLACE_INVALID_CHARACTERS = TRUE
EMPTY_FIELD_AS_NULL = TRUE
COMPRESSION = '{self.compression}'
""")
# Create file format for TPC-H .tbl files (pipe delimited)
cursor.execute(f"""
CREATE OR REPLACE FILE FORMAT {self.schema}.BENCHBOX_TBL_FORMAT
TYPE = 'CSV'
FIELD_DELIMITER = '|'
RECORD_DELIMITER = '\\n'
SKIP_HEADER = 0
ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE
REPLACE_INVALID_CHARACTERS = TRUE
EMPTY_FIELD_AS_NULL = TRUE
COMPRESSION = '{self.compression}'
""")
# Get data files from benchmark or manifest fallback
if hasattr(benchmark, "tables") and benchmark.tables:
data_files = benchmark.tables
else:
data_files = None
try:
manifest_path = Path(data_dir) / "_datagen_manifest.json"
if manifest_path.exists():
with open(manifest_path) as f:
manifest = json.load(f)
tables = manifest.get("tables") or {}
mapping = {}
for table, entries in tables.items():
if entries:
# Collect ALL chunk files, not just the first one
chunk_paths = []
for entry in entries:
rel = entry.get("path")
if rel:
chunk_paths.append(Path(data_dir) / rel)
if chunk_paths:
mapping[table] = chunk_paths
if mapping:
data_files = mapping
self.log_very_verbose("Using data files from _datagen_manifest.json")
except Exception as e:
self.log_very_verbose(f"Manifest fallback failed: {e}")
if not data_files:
# No data files available - benchmark should have generated data first
raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")
# Load data for each table (handle multi-chunk files)
for table_name, file_paths in data_files.items():
# Normalize to list (data resolver should always return lists now)
if not isinstance(file_paths, list):
file_paths = [file_paths]
# Filter out non-existent or empty files
valid_files = []
for file_path in file_paths:
file_path = Path(file_path)
if file_path.exists() and file_path.stat().st_size > 0:
valid_files.append(file_path)
if not valid_files:
self.logger.warning(f"Skipping {table_name} - no valid data files")
table_stats[table_name] = 0
continue
chunk_info = f" from {len(valid_files)} file(s)" if len(valid_files) > 1 else ""
self.log_verbose(f"Loading data for table: {table_name}{chunk_info}")
try:
load_start = time.time()
table_name_upper = table_name.upper()
# Create stage for file upload
stage_name = f"@%{table_name_upper}"
self.log_very_verbose(f"Using stage: {stage_name}")
# Upload all files to Snowflake internal stage using PUT
for file_idx, file_path in enumerate(valid_files):
put_command = f"PUT file://{file_path.absolute()} {stage_name}"
chunk_msg = f" (chunk {file_idx + 1}/{len(valid_files)})" if len(valid_files) > 1 else ""
self.log_very_verbose(f"Uploading file{chunk_msg} with PUT: {file_path.name}")
cursor.execute(put_command)
# Determine file format based on first file name
# Check if filename contains .tbl anywhere (handles .tbl, .tbl.1, .tbl.gz, .tbl.1.gz, etc.)
first_file = valid_files[0]
file_str = str(first_file.name)
if ".tbl" in file_str:
file_format = f"{self.schema}.BENCHBOX_TBL_FORMAT"
self.log_very_verbose(f"Using TBL file format for {table_name}")
else:
file_format = f"{self.schema}.BENCHBOX_CSV_FORMAT"
self.log_very_verbose(f"Using CSV file format for {table_name}")
# Load data using COPY INTO (loads all files from stage)
copy_command = f"""
COPY INTO {table_name_upper}
FROM {stage_name}
FILE_FORMAT = (FORMAT_NAME = '{file_format}')
ON_ERROR = 'CONTINUE'
PURGE = TRUE
"""
self.log_very_verbose(f"Executing COPY INTO for {table_name_upper}")
cursor.execute(copy_command)
copy_results = cursor.fetchall()
# Parse copy results to get row count
# Snowflake COPY INTO result columns:
# 0: file, 1: status, 2: rows_parsed, 3: rows_loaded, 4: error_limit, 5: errors_seen, 6: first_error, 7: first_error_line, ...
rows_loaded = 0
for row in copy_results:
if len(row) > 3: # Ensure rows_loaded column exists
status = str(row[1]) if len(row) > 1 else "UNKNOWN"
try:
loaded = int(row[3]) # row[3] = rows_loaded (integer)
rows_loaded += loaded
# Log warning if file had issues, including error details
if status != "LOADED":
file_name = str(row[0]) if len(row) > 0 else "unknown"
error_msg = str(row[6]) if len(row) > 6 and row[6] else "No error message provided"
self.logger.warning(
f"File {file_name} status: {status}, loaded {loaded} rows. Error: {error_msg}"
)
except (ValueError, TypeError) as e:
self.logger.warning(f"Could not parse rows_loaded from COPY result: {e}")
continue
# Get actual row count from table
cursor.execute(f"SELECT COUNT(*) FROM {table_name_upper}")
actual_count = cursor.fetchone()[0]
table_stats[table_name_upper] = actual_count
load_time = time.time() - load_start
self.log_verbose(
f"✅ Loaded {actual_count:,} rows into {table_name_upper}{chunk_info} in {load_time:.2f}s"
)
except Exception as e:
self.logger.error(f"Failed to load {table_name}: {str(e)[:100]}...")
table_stats[table_name.upper()] = 0
total_time = time.time() - start_time
total_rows = sum(table_stats.values())
self.log_verbose(f"✅ Loaded {total_rows:,} total rows in {total_time:.2f}s")
self.log_operation_complete(
"Snowflake data loading", details=f"Loaded {total_rows:,} rows in {total_time:.2f}s"
)
except Exception as e:
self.logger.error(f"Data loading failed: {e}")
raise
finally:
cursor.close()
# Snowflake doesn't provide detailed per-table timings yet
return table_stats, total_time, None
def validate_session_cache_control(self, connection: Any) -> dict[str, Any]:
"""Validate that session-level cache control settings were successfully applied.
Args:
connection: Active Snowflake database connection
Returns:
dict with:
- validated: bool - Whether validation passed
- cache_disabled: bool - Whether cache is actually disabled
- settings: dict - Actual session settings
- warnings: list[str] - Any validation warnings
- errors: list[str] - Any validation errors
Raises:
ConfigurationError: If cache control validation fails and strict_validation=True
"""
cursor = connection.cursor()
result = {
"validated": False,
"cache_disabled": False,
"settings": {},
"warnings": [],
"errors": [],
}
try:
# Query current session parameter value
cursor.execute("SELECT SYSTEM$GET_SESSION_PARAMETER('USE_CACHED_RESULT') as value")
row = cursor.fetchone()
if row:
actual_value = str(row[0]).upper()
result["settings"]["USE_CACHED_RESULT"] = actual_value
# Determine expected value based on configuration
expected_value = "FALSE" if self.disable_result_cache else "TRUE"
if actual_value == expected_value:
result["validated"] = True
result["cache_disabled"] = actual_value == "FALSE"
self.logger.debug(
f"Cache control validated: USE_CACHED_RESULT={actual_value} (expected {expected_value})"
)
else:
error_msg = (
f"Cache control validation failed: "
f"expected USE_CACHED_RESULT={expected_value}, "
f"got {actual_value}"
)
result["errors"].append(error_msg)
self.logger.error(error_msg)
# Raise error if strict validation mode enabled
if self.strict_validation:
raise ConfigurationError(
"Snowflake session cache control validation failed - "
"benchmark results may be incorrect due to cached query results",
details=result,
)
else:
warning_msg = "Could not retrieve USE_CACHED_RESULT parameter from session"
result["warnings"].append(warning_msg)
self.logger.warning(warning_msg)
except Exception as e:
# If this is our ConfigurationError, re-raise it
if isinstance(e, ConfigurationError):
raise
# Otherwise log validation error
error_msg = f"Validation query failed: {e}"
result["errors"].append(error_msg)
self.logger.error(f"Cache control validation error: {e}")
# Raise if strict mode and query failed
if self.strict_validation:
raise ConfigurationError(
"Failed to validate Snowflake cache control settings",
details={"original_error": str(e), "validation_result": result},
)
finally:
cursor.close()
return result
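# A successful validation returns a dict shaped like (illustrative):
#     {"validated": True, "cache_disabled": True,
#      "settings": {"USE_CACHED_RESULT": "FALSE"}, "warnings": [], "errors": []}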
def execute_query(
self,
connection: Any,
query: str,
query_id: str,
benchmark_type: str | None = None,
scale_factor: float | None = None,
validate_row_count: bool = True,
stream_id: int | None = None,
) -> dict[str, Any]:
"""Execute query with detailed timing and performance tracking."""
self.log_operation_start("Snowflake query execution", query_id)
self.log_very_verbose(f"Executing query {query_id}: {query[:100]}...")
start_time = time.time()
cursor = connection.cursor()
try:
# Set query tag for tracking
self.log_very_verbose(f"Setting query tag: {self.query_tag}_{query_id}")
cursor.execute(f"ALTER SESSION SET QUERY_TAG = '{self.query_tag}_{query_id}'")
# Execute the query
# Note: Query dialect translation is now handled automatically by the base adapter
self.log_verbose(f"Executing query {query_id} on Snowflake")
cursor.execute(query)
result = cursor.fetchall()
execution_time = time.time() - start_time
actual_row_count = len(result) if result else 0
# Get query history for performance metrics
query_stats = self._get_query_statistics(connection, query_id)
# Validate row count if enabled and benchmark type is provided
validation_result = None
if validate_row_count and benchmark_type:
from benchbox.core.validation.query_validation import QueryValidator
validator = QueryValidator()
validation_result = validator.validate_query_result(
benchmark_type=benchmark_type,
query_id=query_id,
actual_row_count=actual_row_count,
scale_factor=scale_factor,
stream_id=stream_id,
)
# Log validation result
if validation_result.warning_message:
self.log_verbose(f"Row count validation: {validation_result.warning_message}")
elif not validation_result.is_valid:
self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
else:
self.log_very_verbose(
f"Row count validation PASSED: {actual_row_count} rows "
f"(expected: {validation_result.expected_row_count})"
)
# Use base helper to build result with consistent validation field mapping
result_dict = self._build_query_result_with_validation(
query_id=query_id,
execution_time=execution_time,
actual_row_count=actual_row_count,
first_row=result[0] if result else None,
validation_result=validation_result,
)
# Include Snowflake-specific fields
result_dict["translated_query"] = None # Translation handled by base adapter
result_dict["query_statistics"] = query_stats
# Map query_statistics to resource_usage for cost calculation
result_dict["resource_usage"] = query_stats
# Log completion based on final status
if result_dict["status"] == "FAILED":
self.log_operation_complete("Snowflake query execution", query_id, "FAILED: validation error")
else:
self.log_verbose(f"Query {query_id} completed: {actual_row_count} rows in {execution_time:.3f}s")
self.log_operation_complete("Snowflake query execution", query_id, f"returned {actual_row_count} rows")
return result_dict
except Exception as e:
execution_time = time.time() - start_time
self.log_verbose(f"Query {query_id} failed after {execution_time:.3f}s: {e}")
self.log_operation_complete("Snowflake query execution", query_id, f"FAILED: {e}")
return {
"query_id": query_id,
"status": "FAILED",
"execution_time": execution_time,
"rows_returned": 0,
"error": str(e),
"error_type": type(e).__name__,
}
finally:
cursor.close()
def _optimize_table_definition(self, statement: str) -> str:
"""Optimize table definition for Snowflake.
Makes tables idempotent by using CREATE OR REPLACE TABLE.
"""
if not statement.upper().startswith("CREATE TABLE"):
return statement
# Ensure idempotency with OR REPLACE (defense-in-depth). The startswith() check above
# is case-insensitive, so rebuild the prefix instead of relying on a case-sensitive
# str.replace().
if "OR REPLACE" not in statement.upper():
statement = "CREATE OR REPLACE TABLE" + statement[len("CREATE TABLE"):]
# Snowflake automatically optimizes most aspects, but clustering keys could be added here.
# This is a simplified heuristic - in production it would be more sophisticated.
if "CLUSTER BY" not in statement.upper():
# A clustering key on the first column could be added here (simple heuristic), but
# Snowflake's automatic micro-partitioning handles most cases, so no change is made.
pass
return statement
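# Example transformation performed above (table and column names are illustrative):
#     CREATE TABLE LINEITEM (L_ORDERKEY NUMBER, ...)
#       -> CREATE OR REPLACE TABLE LINEITEM (L_ORDERKEY NUMBER, ...)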
def _get_existing_tables(self, connection: Any) -> list[str]:
"""Get list of existing tables using Snowflake SHOW TABLES command.
Args:
connection: Snowflake connection
Returns:
List of table names (lowercase, normalized for case-insensitive comparison)
"""
try:
cursor = connection.cursor()
cursor.execute("SHOW TABLES")
# SHOW TABLES returns: created_on, name, database_name, schema_name, kind, ...
# Table name is in column index 1
# Normalize to lowercase since Snowflake is case-insensitive but stores uppercase,
# while benchmarks expect lowercase names
tables = [row[1].lower() for row in cursor.fetchall()]
return tables
except Exception:
# Fallback to base implementation if SHOW TABLES fails
return []
def _validate_data_integrity(
self, benchmark, connection: Any, table_stats: dict[str, int]
) -> tuple[str, dict[str, Any]]:
"""Validate basic data integrity checks using Snowflake cursor pattern.
Snowflake connections require cursor-based execution, unlike the base
adapter which assumes connection.execute() exists.
Args:
benchmark: Benchmark instance
connection: Snowflake connection
table_stats: Dictionary of table names to row counts
Returns:
Tuple of (status, validation_details)
"""
validation_details = {}
try:
# Verify tables are accessible using Snowflake cursor
accessible_tables = []
inaccessible_tables = []
cursor = connection.cursor()
for table_name in table_stats:
try:
# Try a simple SELECT to verify table is accessible
# table_stats has uppercase keys from Snowflake
cursor.execute(f"SELECT 1 FROM {table_name} LIMIT 1")
cursor.fetchone() # Consume the result to prevent resource leaks
accessible_tables.append(table_name)
except Exception:
inaccessible_tables.append(table_name)
if inaccessible_tables:
validation_details["inaccessible_tables"] = inaccessible_tables
validation_details["constraints_enabled"] = False
return "FAILED", validation_details
else:
validation_details["accessible_tables"] = accessible_tables
validation_details["constraints_enabled"] = True
return "PASSED", validation_details
except Exception as e:
validation_details["constraints_enabled"] = False
validation_details["integrity_error"] = str(e)
return "FAILED", validation_details
def _get_query_statistics(
self, connection: Any, query_id: str, max_retries: int = 3, initial_delay: float = 0.5
) -> dict[str, Any]:
"""Get detailed query statistics from Snowflake query history.
Snowflake query history may not be immediately available after query execution.
This method implements retry logic with exponential backoff to handle delayed
statistics availability.
Args:
connection: Snowflake connection
query_id: Query identifier to look up in history
max_retries: Maximum number of retry attempts (default: 3)
initial_delay: Initial delay in seconds between retries (default: 0.5s)
Returns:
Dictionary with query statistics or note if not available
"""
import time as time_module
cursor = connection.cursor()
last_error = None
for attempt in range(max_retries + 1):
try:
# Query the QUERY_HISTORY view for performance metrics
cursor.execute(f"""
SELECT
QUERY_ID,
QUERY_TEXT,
TOTAL_ELAPSED_TIME,
EXECUTION_TIME,
COMPILATION_TIME,
BYTES_SCANNED,
BYTES_WRITTEN,
BYTES_SPILLED_TO_LOCAL_STORAGE,
BYTES_SPILLED_TO_REMOTE_STORAGE,
ROWS_PRODUCED,
ROWS_EXAMINED,
CREDITS_USED_CLOUD_SERVICES,
WAREHOUSE_SIZE,
CLUSTER_NUMBER
FROM TABLE(INFORMATION_SCHEMA.QUERY_HISTORY(
END_TIME_RANGE_START => DATEADD('MINUTE', -5, CURRENT_TIMESTAMP()),
END_TIME_RANGE_END => CURRENT_TIMESTAMP()
))
WHERE QUERY_TAG LIKE '%{query_id}%'
ORDER BY START_TIME DESC
LIMIT 1
""")
result = cursor.fetchone()
if result:
# Statistics available
cursor.close()
return {
"snowflake_query_id": result[0],
"total_elapsed_time_ms": result[2],
"execution_time_ms": result[3],
"compilation_time_ms": result[4],
"bytes_scanned": result[5],
"bytes_written": result[6],
"bytes_spilled_local": result[7],
"bytes_spilled_remote": result[8],
"rows_produced": result[9],
"rows_examined": result[10],
"credits_used": result[11],
"warehouse_size": result[12],
"cluster_number": result[13],
"retrieval_attempts": attempt + 1,
}
else:
# Statistics not yet available
if attempt < max_retries:
# Retry with exponential backoff
delay = initial_delay * (2**attempt)
self.logger.debug(
f"Query statistics not yet available for {query_id}, "
f"retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries})"
)
time_module.sleep(delay)
else:
# Max retries reached
cursor.close()
return {
"note": f"Query statistics not available after {max_retries + 1} attempts. "
"Statistics may appear in query history later.",
"retrieval_attempts": max_retries + 1,
}
except Exception as e:
last_error = e
if attempt < max_retries:
delay = initial_delay * (2**attempt)
self.logger.debug(
f"Error retrieving query statistics for {query_id}: {e}, "
f"retrying in {delay:.1f}s (attempt {attempt + 1}/{max_retries})"
)
time_module.sleep(delay)
else:
cursor.close()
return {
"statistics_error": str(last_error),
"retrieval_attempts": max_retries + 1,
}
cursor.close()
return {"note": "Query statistics not yet available", "retrieval_attempts": max_retries + 1}
def _get_platform_metadata(self, connection: Any) -> dict[str, Any]:
"""Get Snowflake-specific metadata and system information."""
metadata = {
"platform": self.platform_name,
"account": self.account,
"warehouse": self.warehouse,
"database": self.database,
"schema": self.schema,
"result_cache_enabled": not self.disable_result_cache,
}
cursor = connection.cursor()
try:
# Get Snowflake version
cursor.execute("SELECT CURRENT_VERSION()")
result = cursor.fetchone()
metadata["snowflake_version"] = result[0] if result else "unknown"
# Get current session information
cursor.execute("""
SELECT
CURRENT_USER(),
CURRENT_ROLE(),
CURRENT_WAREHOUSE(),
CURRENT_DATABASE(),
CURRENT_SCHEMA(),
CURRENT_REGION(),
CURRENT_ACCOUNT()
""")
result = cursor.fetchone()
if result:
metadata["session_info"] = {
"current_user": result[0],
"current_role": result[1],
"current_warehouse": result[2],
"current_database": result[3],
"current_schema": result[4],
"current_region": result[5],
"current_account": result[6],
}
# Get warehouse information
cursor.execute(f"""
SHOW WAREHOUSES LIKE '{self.warehouse}'
""")
wh_result = cursor.fetchall()
if wh_result:
wh_info = wh_result[0]
metadata["warehouse_info"] = {
"name": wh_info[0],
"state": wh_info[1],
"type": wh_info[2],
"size": wh_info[3],
"min_cluster_count": wh_info[4],
"max_cluster_count": wh_info[5],
"started_clusters": wh_info[6],
"running": wh_info[7],
"queued": wh_info[8],
"auto_suspend": wh_info[12],
"auto_resume": wh_info[13],
"available": wh_info[14],
"provisioning": wh_info[15],
"quiescing": wh_info[16],
"other": wh_info[17],
"created_on": wh_info[18],
"resumed_on": wh_info[19],
"updated_on": wh_info[20],
"owner": wh_info[21],
"comment": wh_info[22],
"scaling_policy": wh_info[25] if len(wh_info) > 25 else None,
}
# Get table information
cursor.execute(f"""
SELECT
TABLE_NAME,
ROW_COUNT,
BYTES,
RETENTION_TIME,
CREATED,
LAST_ALTERED,
CLUSTERING_KEY
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = '{self.schema}'
AND TABLE_TYPE = 'BASE TABLE'
""")
tables = cursor.fetchall()
metadata["tables"] = [
{
"table_name": row[0],
"row_count": row[1],
"bytes": row[2],
"retention_time": row[3],
"created": row[4].isoformat() if row[4] else None,
"last_altered": row[5].isoformat() if row[5] else None,
"clustering_key": row[6],
}
for row in tables
]
except Exception as e:
metadata["metadata_error"] = str(e)
finally:
cursor.close()
return metadata
def analyze_table(self, connection: Any, table_name: str) -> None:
"""Trigger table analysis for better query optimization."""
cursor = connection.cursor()
try:
# Snowflake automatically maintains statistics, but we can trigger clustering
cursor.execute(f"ALTER TABLE {table_name.upper()} RECLUSTER")
self.logger.info(f"Triggered reclustering for table {table_name.upper()}")
except Exception as e:
self.logger.warning(f"Failed to recluster table {table_name}: {e}")
finally:
cursor.close()
def close_connection(self, connection: Any) -> None:
"""Close Snowflake connection."""
try:
if connection and hasattr(connection, "close"):
connection.close()
except Exception as e:
self.logger.warning(f"Error closing connection: {e}")
def supports_tuning_type(self, tuning_type) -> bool:
"""Check if Snowflake supports a specific tuning type.
Snowflake supports:
- CLUSTERING: Via CLUSTER BY clause and automatic clustering
- PARTITIONING: Via micro-partitions (automatic) and manual clustering keys
Args:
tuning_type: The type of tuning to check support for
Returns:
True if the tuning type is supported by Snowflake
"""
# Import here to avoid circular imports
try:
from benchbox.core.tuning.interface import TuningType
return tuning_type in {TuningType.CLUSTERING, TuningType.PARTITIONING}
except ImportError:
return False
def generate_tuning_clause(self, table_tuning) -> str:
"""Generate Snowflake-specific tuning clauses for CREATE TABLE statements.
Snowflake supports:
- CLUSTER BY (column1, column2, ...) for clustering keys
- Micro-partitions are automatic based on ingestion order and clustering
Args:
table_tuning: The tuning configuration for the table
Returns:
SQL clause string to be appended to CREATE TABLE statement
"""
if not table_tuning or not table_tuning.has_any_tuning():
return ""
clauses = []
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Handle clustering - primary tuning mechanism in Snowflake
cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
if cluster_columns:
# Sort by order and create clustering key
sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
cluster_clause = f"CLUSTER BY ({', '.join(column_names)})"
clauses.append(cluster_clause)
# Handle partitioning as clustering (Snowflake uses micro-partitions automatically)
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
if partition_columns and not cluster_columns:
# Use partitioning columns as clustering keys if no explicit clustering
sorted_cols = sorted(partition_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
cluster_clause = f"CLUSTER BY ({', '.join(column_names)})"
clauses.append(cluster_clause)
# Distribution and sorting handled through clustering in Snowflake
except ImportError:
# If tuning interface not available, return empty string
pass
return " ".join(clauses)
def apply_table_tunings(self, table_tuning, connection: Any) -> None:
"""Apply tuning configurations to a Snowflake table.
Snowflake tuning approach:
- CLUSTERING: Handled via CLUSTER BY in CREATE TABLE or ALTER TABLE
- PARTITIONING: Automatic micro-partitions with optional clustering keys
- Automatic clustering can be enabled for maintenance
Args:
table_tuning: The tuning configuration to apply
connection: Snowflake connection
Raises:
ValueError: If the tuning configuration is invalid for Snowflake
"""
if not table_tuning or not table_tuning.has_any_tuning():
return
table_name = table_tuning.table_name.upper()
self.logger.info(f"Applying Snowflake tunings for table: {table_name}")
cursor = connection.cursor()
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Handle clustering keys
cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
# Determine clustering strategy
clustering_columns = []
if cluster_columns:
sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
clustering_columns = [col.name for col in sorted_cols]
elif partition_columns:
# Use partition columns as clustering keys
sorted_cols = sorted(partition_columns, key=lambda col: col.order)
clustering_columns = [col.name for col in sorted_cols]
if clustering_columns:
# Check current clustering key
cursor.execute(f"""
SELECT CLUSTERING_KEY
FROM INFORMATION_SCHEMA.TABLES
WHERE TABLE_SCHEMA = '{self.schema}'
AND TABLE_NAME = '{table_name}'
""")
result = cursor.fetchone()
current_clustering = result[0] if result and result[0] else None
desired_clustering = f"({', '.join(clustering_columns)})"
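# NOTE: INFORMATION_SCHEMA.TABLES may report the key in a different form (e.g. "LINEAR(C1, C2)"),
# so this string comparison is conservative; re-applying the same clustering key is harmless.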
if current_clustering != desired_clustering:
# Apply clustering key
cluster_sql = f"ALTER TABLE {table_name} CLUSTER BY ({', '.join(clustering_columns)})"
try:
cursor.execute(cluster_sql)
self.logger.info(f"Applied clustering key to {table_name}: {', '.join(clustering_columns)}")
# Enable automatic clustering if desired
if len(clustering_columns) <= 4: # Snowflake recommendation
try:
cursor.execute(f"ALTER TABLE {table_name} RESUME RECLUSTER")
self.logger.info(f"Enabled automatic clustering for {table_name}")
except Exception as e:
self.logger.debug(f"Could not enable automatic clustering for {table_name}: {e}")
except Exception as e:
self.logger.warning(f"Failed to apply clustering key to {table_name}: {e}")
else:
self.logger.info(f"Table {table_name} already has desired clustering key: {current_clustering}")
# Handle sorting - in Snowflake, this is achieved through clustering
sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
if sort_columns and not clustering_columns:
sorted_cols = sorted(sort_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
self.logger.info(
f"Sorting in Snowflake achieved via clustering for table {table_name}: {', '.join(column_names)}"
)
# Distribution not applicable for Snowflake's architecture
distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
if distribution_columns:
self.logger.warning(
f"Distribution tuning is not applicable to Snowflake's multi-cluster shared data architecture; skipping for table: {table_name}"
)
except ImportError:
self.logger.warning("Tuning interface not available - skipping tuning application")
except Exception as e:
raise ValueError(f"Failed to apply tunings to Snowflake table {table_name}: {e}")
finally:
cursor.close()
def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None:
"""Apply unified tuning configuration to Snowflake.
Args:
unified_config: Unified tuning configuration to apply
connection: Snowflake connection
"""
if not unified_config:
return
# Apply constraint configurations
self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection)
# Apply platform optimizations
if unified_config.platform_optimizations:
self.apply_platform_optimizations(unified_config.platform_optimizations, connection)
# Apply table-level tunings
for _table_name, table_tuning in unified_config.table_tunings.items():
self.apply_table_tunings(table_tuning, connection)
def apply_constraint_configuration(
self,
primary_key_config: PrimaryKeyConfiguration,
foreign_key_config: ForeignKeyConfiguration,
connection: Any,
) -> None:
"""Apply constraint configurations to Snowflake.
Note: Snowflake accepts PRIMARY KEY and FOREIGN KEY constraints but does not
enforce them (they are informational only). They can aid query optimization and
must be applied at table creation time.
Args:
primary_key_config: Primary key constraint configuration
foreign_key_config: Foreign key constraint configuration
connection: Snowflake connection
"""
# Snowflake constraints are applied at table creation time for query optimization
# This method is called after tables are created, so log the configurations
if primary_key_config and primary_key_config.enabled:
self.logger.info(
"Primary key constraints enabled for Snowflake (informational only, applied during table creation)"
)
if foreign_key_config and foreign_key_config.enabled:
self.logger.info(
"Foreign key constraints enabled for Snowflake (informational only, applied during table creation)"
)
# Snowflake constraints are informational and used for query optimization
# No additional work to do here as they're applied during CREATE TABLE
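# For reference, Snowflake accepts but does not enforce statements such as (illustrative
# table and column names):
#     ALTER TABLE ORDERS ADD CONSTRAINT PK_ORDERS PRIMARY KEY (O_ORDERKEY);
# NOT NULL is the only constraint type Snowflake enforces on standard tables.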
def _build_snowflake_config(
platform: str,
options: dict[str, Any],
overrides: dict[str, Any],
info: Any,
) -> Any:
"""Build Snowflake database configuration with credential loading.
This function loads saved credentials from the CredentialManager and
merges them with CLI options and runtime overrides.
Args:
platform: Platform name (should be 'snowflake')
options: CLI platform options from --platform-option flags
overrides: Runtime overrides from orchestrator
info: Platform info from registry
Returns:
DatabaseConfig with credentials loaded and platform-specific fields at top-level
"""
from benchbox.core.config import DatabaseConfig
from benchbox.security.credentials import CredentialManager
# Load saved credentials
cred_manager = CredentialManager()
saved_creds = cred_manager.get_platform_credentials("snowflake") or {}
# Build merged options: saved_creds < options < overrides
merged_options = {}
merged_options.update(saved_creds)
merged_options.update(options)
merged_options.update(overrides)
# Extract credential fields for DatabaseConfig
name = info.display_name if info else "Snowflake"
driver_package = info.driver_package if info else "snowflake-connector-python"
# Build config dict with platform-specific fields at top-level
# This allows SnowflakeAdapter.__init__() to access them via config.get()
config_dict = {
"type": "snowflake",
"name": name,
"options": merged_options or {}, # Ensure options is never None (Pydantic v2 uses None if explicitly passed)
"driver_package": driver_package,
"driver_version": overrides.get("driver_version") or options.get("driver_version"),
"driver_auto_install": bool(overrides.get("driver_auto_install", options.get("driver_auto_install", False))),
# Platform-specific fields at top-level (adapters expect these here)
"account": merged_options.get("account"),
"warehouse": merged_options.get("warehouse"),
# NOTE: database is NOT included here - from_config() generates it from benchmark context
# Only explicit overrides (via --platform-option database=...) should bypass generation
"schema": merged_options.get("schema"),
"username": merged_options.get("username"),
"password": merged_options.get("password"),
"role": merged_options.get("role"),
"authenticator": merged_options.get("authenticator"),
"private_key_path": merged_options.get("private_key_path"),
"private_key_passphrase": merged_options.get("private_key_passphrase"),
# Benchmark context for config-aware database naming (from overrides)
"benchmark": overrides.get("benchmark"),
"scale_factor": overrides.get("scale_factor"),
"tuning_config": overrides.get("tuning_config"),
}
# Only include explicit database override if provided via CLI or overrides
# Saved credentials should NOT override generated database names
if "database" in overrides and overrides["database"]:
config_dict["database"] = overrides["database"]
return DatabaseConfig(**config_dict)
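# Merge precedence example (illustrative values): if saved credentials contain
# {"warehouse": "WH_SAVED"} and --platform-option supplies warehouse=WH_CLI, the later
# update() wins and merged_options["warehouse"] == "WH_CLI"; an orchestrator override
# would win over both.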
# Register the config builder with the platform hook registry
# This must happen when the module is imported
try:
from benchbox.cli.platform_hooks import PlatformHookRegistry
PlatformHookRegistry.register_config_builder("snowflake", _build_snowflake_config)
except ImportError:
# Platform hooks may not be available in all contexts (e.g., core-only usage)
pass