"""Databricks platform adapter with Unity Catalog and Delta Lake optimization.
Provides Databricks-specific optimizations for large-scale analytics,
including Delta Lake table creation and cluster management.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""
from __future__ import annotations
import contextlib
import json
import logging
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from benchbox.core.tuning.interface import (
ForeignKeyConfiguration,
PlatformOptimizationConfiguration,
PrimaryKeyConfiguration,
UnifiedTuningConfiguration,
)
from benchbox.core.upload_validation import UploadValidationEngine
from benchbox.platforms.base import PlatformAdapter
from benchbox.utils.datagen_manifest import MANIFEST_FILENAME
from benchbox.utils.dependencies import check_platform_dependencies, get_dependency_error_message
try:
from databricks import sql as databricks_sql
from databricks.sql.client import Connection as DatabricksConnection
except ImportError:
databricks_sql = None
DatabricksConnection = None
class DatabricksAdapter(PlatformAdapter):
"""Databricks platform adapter with Delta Lake and Unity Catalog support."""
def __init__(self, **config):
super().__init__(**config)
# Check dependencies with improved error message
available, missing = check_platform_dependencies("databricks")
if not available:
error_msg = get_dependency_error_message("databricks", missing)
raise ImportError(error_msg)
self._dialect = "databricks"
# Databricks configuration
self.server_hostname = config.get("server_hostname") or config.get("host")
self.http_path = config.get("http_path")
self.access_token = config.get("access_token") or config.get("token")
self.catalog = config.get("catalog") or "main"
self.schema = config.get("schema") or "benchbox"
# Unity Catalog Volume and staging support
self.uc_catalog = config.get("uc_catalog")
self.uc_schema = config.get("uc_schema")
self.uc_volume = config.get("uc_volume")
# Explicit staging root (e.g., dbfs:/Volumes/<cat>/<schema>/<volume>/... or s3://...)
self.staging_root = config.get("staging_root")
# Delta Lake settings
self.enable_delta_optimization = (
config.get("enable_delta_optimization") if config.get("enable_delta_optimization") is not None else True
)
self.delta_auto_optimize = (
config.get("delta_auto_optimize") if config.get("delta_auto_optimize") is not None else True
)
self.delta_auto_compact = (
config.get("delta_auto_compact") if config.get("delta_auto_compact") is not None else True
)
# Cluster settings
self.cluster_size = config.get("cluster_size") or "Medium"
self.auto_terminate_minutes = (
config.get("auto_terminate_minutes") if config.get("auto_terminate_minutes") is not None else 30
)
# Schema creation settings
self.create_catalog = config.get("create_catalog") if config.get("create_catalog") is not None else False
# Upload/validation controls
force_upload_val = config.get("force_upload")
self.force_upload = bool(force_upload_val if force_upload_val is not None else False)
# Result cache control - disable by default for accurate benchmarking
self.disable_result_cache = config.get("disable_result_cache", True)
if not self.server_hostname or not self.http_path or not self.access_token:
missing = []
if not self.server_hostname:
missing.append("server_hostname (or DATABRICKS_HOST)")
if not self.http_path:
missing.append("http_path (or DATABRICKS_HTTP_PATH)")
if not self.access_token:
missing.append("access_token (or DATABRICKS_TOKEN)")
from benchbox.core.exceptions import ConfigurationError
raise ConfigurationError(
f"Databricks configuration is incomplete. Missing: {', '.join(missing)}\n"
"Configure with one of:\n"
" 1. CLI: benchbox platforms setup --platform databricks\n"
" 2. Environment variables: DATABRICKS_HOST, DATABRICKS_HTTP_PATH, DATABRICKS_TOKEN\n"
" 3. CLI options: --platform-option server_hostname=<host> --platform-option http_path=<path>"
)
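# A minimal construction sketch (illustrative placeholder values, not real credentials):
#
#   adapter = DatabricksAdapter(
#       server_hostname="dbc-example.cloud.databricks.com",
#       http_path="/sql/1.0/warehouses/<warehouse-id>",
#       access_token="<personal-access-token>",
#       catalog="main",
#       schema="benchbox",
#   )
#
# The three required values may also be supplied via DATABRICKS_HOST, DATABRICKS_HTTP_PATH,
# and DATABRICKS_TOKEN, which from_config() picks up through Databricks SDK auto-detection.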
@property
def platform_name(self) -> str:
return "Databricks"
@staticmethod
def add_cli_arguments(parser) -> None:
"""Add Databricks-specific CLI arguments."""
db_group = parser.add_argument_group("Databricks Arguments")
db_group.add_argument("--server-hostname", type=str, help="Databricks server hostname")
db_group.add_argument("--http-path", type=str, help="Databricks SQL Warehouse HTTP path")
db_group.add_argument("--access-token", type=str, help="Databricks access token")
db_group.add_argument("--catalog", type=str, default="workspace", help="Databricks catalog name")
db_group.add_argument(
"--schema", type=str, default=None, help="Databricks schema name (auto-generated if not specified)"
)
@classmethod
def from_config(cls, config: dict[str, Any]):
"""Create Databricks adapter from unified configuration."""
from benchbox.utils.database_naming import generate_database_name
# Try auto-detection if credentials not provided
adapter_config = {}
very_verbose = config.get("very_verbose", False)
# Check if we have valid (non-placeholder) credentials
def is_placeholder(value):
if not value:
return True
str_val = str(value)
# Common placeholder patterns
return (
"your-workspace" in str_val
or "your-warehouse-id" in str_val
or "${" in str_val # Environment variable placeholder
or "example" in str_val.lower()
)
if not all(
[
config.get("server_hostname") and not is_placeholder(config.get("server_hostname")),
config.get("http_path") and not is_placeholder(config.get("http_path")),
config.get("access_token") and not is_placeholder(config.get("access_token")),
]
):
auto_config = cls._auto_detect_databricks_config(very_verbose=very_verbose)
if auto_config:
adapter_config.update(auto_config)
# Override with explicit config values (but skip placeholders)
for key in ["server_hostname", "http_path", "access_token"]:
if config.get(key) and not is_placeholder(config.get(key)):
adapter_config[key] = config[key]
# Handle catalog
adapter_config["catalog"] = config.get("catalog", "workspace")
# Handle schema - prioritize auto-generation when benchmark context is available
# This ensures schema names reflect benchmark/scale/tuning configuration,
# rather than using static values from credentials files
provided_schema = config.get("schema")
has_benchmark_context = "benchmark" in config and "scale_factor" in config
if has_benchmark_context:
# When running a benchmark, always auto-generate schema name unless
# user provided an explicit non-default override
is_default_schema = provided_schema in (None, "", "benchbox")
if is_default_schema:
# Generate proper schema name using benchmark configuration
schema_name = generate_database_name(
benchmark_name=config["benchmark"],
scale_factor=config["scale_factor"],
platform="databricks",
tuning_config=config.get("tuning_config"),
)
adapter_config["schema"] = schema_name
else:
# User provided explicit non-default schema - honor it
adapter_config["schema"] = provided_schema
else:
# No benchmark context - fall back to provided schema or default
adapter_config["schema"] = provided_schema or "benchbox"
# Pass through other relevant config
for key in [
"tuning_config",
"verbose_enabled",
"very_verbose",
"uc_catalog",
"uc_schema",
"uc_volume",
"staging_root",
]:
if key in config:
adapter_config[key] = config[key]
return cls(**adapter_config)
@staticmethod
def _auto_detect_databricks_config(very_verbose: bool = False):
"""Auto-detect Databricks configuration from SDK."""
logger = logging.getLogger("DatabricksAdapter")
try:
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import WarehousesAPI
if very_verbose:
logger.info("Attempting to auto-detect Databricks configuration from SDK...")
workspace = WorkspaceClient()
server_hostname = workspace.config.host.replace("https://", "")
access_token = workspace.config.token
if very_verbose:
logger.info(f"Found Databricks host: {server_hostname}")
warehouses = list(WarehousesAPI(workspace.api_client).list())
if very_verbose:
logger.info(f"Found {len(warehouses)} Databricks SQL Warehouses.")
for wh in warehouses:
logger.info(f" - Warehouse: {wh.name}, State: {wh.state}, ID: {wh.id}")
http_path = None
selected_warehouse = None
if warehouses:
# 1. Prefer a running warehouse
running_wh = next((wh for wh in warehouses if str(wh.state) == "RUNNING"), None)
if running_wh:
selected_warehouse = running_wh
if very_verbose:
logger.info(f"Selected running warehouse: {selected_warehouse.name}")
else:
if very_verbose:
logger.info("No running warehouses found. Looking for an available one to auto-start.")
# 2. Otherwise, take the first available one that is not in a terminal state
available_wh = next(
(wh for wh in warehouses if str(wh.state) not in ["DELETING", "DELETED"]),
None,
)
if available_wh:
selected_warehouse = available_wh
if very_verbose:
logger.info(
f"Selected available warehouse to auto-start: {selected_warehouse.name} (State: {selected_warehouse.state})"
)
if selected_warehouse:
http_path = f"/sql/1.0/warehouses/{selected_warehouse.id}"
if very_verbose:
logger.info(f"Using HTTP path: {http_path}")
elif very_verbose:
logger.warning("No suitable warehouse found for auto-detection.")
return {
"server_hostname": server_hostname,
"http_path": http_path,
"access_token": access_token,
}
except Exception as e:
if very_verbose:
logger.error(f"Databricks auto-detection failed: {e}")
return None
def get_target_dialect(self) -> str:
"""Return the target SQL dialect for Databricks."""
return "databricks"
def _get_connection_params(self, **connection_config) -> dict[str, Any]:
"""Get standardized connection parameters."""
return {
"server_hostname": connection_config.get("server_hostname", self.server_hostname),
"http_path": connection_config.get("http_path", self.http_path),
"access_token": connection_config.get("access_token", self.access_token),
}
def _create_admin_connection(self, **connection_config) -> Any:
"""Create Databricks connection for admin operations."""
params = self._get_connection_params(**connection_config)
# Basic connection without session configuration to work with all warehouse types
return databricks_sql.connect(**params, user_agent_entry="BenchBox/1.0")
def check_server_database_exists(self, **connection_config) -> bool:
"""Check if schema exists in Databricks catalog."""
try:
connection = self._create_admin_connection(**connection_config)
cursor = connection.cursor()
catalog = connection_config.get("catalog", self.catalog)
schema = connection_config.get("schema", self.schema)
# Check if catalog exists
cursor.execute("SHOW CATALOGS")
catalogs = [row[0] for row in cursor.fetchall()]
if catalog not in catalogs:
return False
# Check if schema exists in catalog
cursor.execute(f"SHOW SCHEMAS IN {catalog}")
schemas = [row[0] for row in cursor.fetchall()]
return schema in schemas
except Exception:
# If we can't connect or check, assume schema doesn't exist
return False
finally:
if "connection" in locals():
connection.close()
def drop_database(self, **connection_config) -> None:
"""Drop schema in Databricks catalog."""
# Resolve catalog/schema up front so the error message below can reference them
catalog = connection_config.get("catalog", self.catalog)
schema = connection_config.get("schema", self.schema)
try:
connection = self._create_admin_connection(**connection_config)
cursor = connection.cursor()
# Drop schema and all its tables
cursor.execute(f"DROP SCHEMA IF EXISTS {catalog}.{schema} CASCADE")
except Exception as e:
raise RuntimeError(f"Failed to drop Databricks schema {catalog}.{schema}: {e}")
finally:
if "connection" in locals():
connection.close()
def create_connection(self, **connection_config) -> Any:
"""Create optimized Databricks SQL connection."""
self.log_operation_start("Databricks connection")
# Handle existing database using base class method
self.handle_existing_database(**connection_config)
try:
params = self._get_connection_params(**connection_config)
self.log_very_verbose(
f"Databricks connection params: host={params.get('server_hostname')}, catalog={self.catalog}"
)
connection = self._create_admin_connection(**connection_config)
# Test connection and set catalog
cursor = connection.cursor()
cursor.execute("SELECT 1")
cursor.fetchall()
self.log_very_verbose("Databricks connection test successful")
# Set catalog and schema context
# If database is being reused, schema already exists - set it now
# If database is new, schema will be created in create_schema() which will also set it
cursor.execute(f"USE CATALOG {self.catalog}")
if self.database_was_reused:
cursor.execute(f"USE SCHEMA {self.schema}")
self.log_very_verbose(f"Set schema context to {self.catalog}.{self.schema} (database reused)")
else:
self.log_very_verbose(f"Set catalog to {self.catalog}, schema will be set during schema creation")
self.log_operation_complete(
"Databricks connection",
details=f"Connected to {params['server_hostname']}, catalog: {self.catalog}",
)
return connection
except Exception as e:
self.logger.error(f"Failed to connect to Databricks: {e}")
raise
def create_schema(self, benchmark, connection: Any) -> float:
"""Create schema using Databricks Delta Lake tables."""
start_time = time.time()
self.log_operation_start("Schema creation", f"benchmark: {benchmark.__class__.__name__}")
# Get constraint settings from tuning configuration
enable_primary_keys, enable_foreign_keys = self._get_constraint_configuration()
self._log_constraint_configuration(enable_primary_keys, enable_foreign_keys)
self.log_verbose(
f"Schema constraints - Primary keys: {enable_primary_keys}, Foreign keys: {enable_foreign_keys}"
)
try:
cursor = connection.cursor()
# Step 1: Ensure catalog exists (if create_catalog is enabled)
# Step 2: Create schema BEFORE attempting to USE it (correct order)
if self.create_catalog:
cursor.execute(f"CREATE CATALOG IF NOT EXISTS {self.catalog}")
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}")
self.log_verbose(f"Created catalog and schema: {self.catalog}.{self.schema}")
else:
# Just create schema if catalog already exists
cursor.execute(f"CREATE SCHEMA IF NOT EXISTS {self.catalog}.{self.schema}")
self.log_verbose(f"Created schema: {self.catalog}.{self.schema}")
# Step 3: Set catalog and schema context (now that schema exists)
cursor.execute(f"USE CATALOG {self.catalog}")
cursor.execute(f"USE SCHEMA {self.schema}")
self.log_very_verbose(f"Set schema context to: {self.catalog}.{self.schema}")
# Use common schema creation helper
schema_sql = self._create_schema_with_tuning(benchmark, source_dialect="duckdb")
# Debug: Log schema SQL generation results
self.log_verbose(f"Received schema SQL from _create_schema_with_tuning: {len(schema_sql)} characters")
self.log_very_verbose(f"Schema SQL (first 300 chars): {schema_sql[:300]}")
if not schema_sql or not schema_sql.strip():
self.logger.error(f"Schema SQL is empty! Benchmark class: {benchmark.__class__.__name__}")
self.logger.error(f"Benchmark has get_schema_sql: {hasattr(benchmark, 'get_schema_sql')}")
raise RuntimeError(f"No schema SQL generated for {benchmark.__class__.__name__}")
# Transform SQL syntax for Databricks compatibility
original_len = len(schema_sql)
schema_sql = self._fix_databricks_sql_syntax(schema_sql)
self.log_very_verbose(
f"After _fix_databricks_sql_syntax: {len(schema_sql)} characters (was {original_len})"
)
if len(schema_sql) != original_len:
self.log_verbose(f"SQL length changed after Databricks syntax fix: {original_len} -> {len(schema_sql)}")
# Split schema into individual statements and execute
statements = [stmt.strip() for stmt in schema_sql.split(";") if stmt.strip()]
# Debug: Log statement count
self.log_verbose(f"Parsed {len(statements)} CREATE TABLE statements from schema SQL")
if not statements:
self.logger.error("No CREATE TABLE statements found after parsing schema SQL")
self.logger.error(f"Raw schema SQL (first 500 chars): {schema_sql[:500]}")
raise RuntimeError("Schema SQL produced no executable statements")
# Execute statements with error handling from base adapter
tables_created, failed_tables = self._execute_schema_statements(
statements, cursor, platform_transform_fn=self._convert_to_delta_table
)
duration = time.time() - start_time
self.log_operation_complete("Schema creation", duration, f"{tables_created} Delta Lake tables created")
return duration
except Exception as e:
self.logger.error(f"Schema creation failed: {e}")
raise
finally:
if "cursor" in locals():
cursor.close()
def _ensure_uc_volume_exists(self, uc_volume_path: str, connection: Any) -> None:
"""Ensure UC Volume exists, creating it if necessary.
This method also creates the schema if it doesn't exist, providing
a complete zero-setup experience for UC Volume workflows.
Args:
uc_volume_path: UC Volume path (e.g., dbfs:/Volumes/catalog/schema/volume)
connection: Databricks SQL connection
Raises:
ValueError: If volume path is invalid or creation fails
"""
# Parse volume path: dbfs:/Volumes/catalog/schema/volume
volume_path = uc_volume_path.replace("dbfs:", "").rstrip("/")
# Extract catalog, schema, volume from /Volumes/catalog/schema/volume
if not volume_path.startswith("/Volumes/"):
raise ValueError(f"Invalid UC Volume path: {uc_volume_path}. Must start with dbfs:/Volumes/")
path_parts = volume_path.split("/")
# path_parts = ['', 'Volumes', 'catalog', 'schema', 'volume', ...]
if len(path_parts) < 5:
# Be lenient for unit tests or minimal paths: skip the ensure step when the volume path is incomplete
self.logger.warning(
f"UC Volume path '{uc_volume_path}' missing components (expected dbfs:/Volumes/catalog/schema/volume). Skipping ensure."
)
return
catalog = path_parts[2]
schema = path_parts[3]
volume = path_parts[4]
self.log_verbose(f"Ensuring UC Volume exists: {catalog}.{schema}.{volume}")
try:
cursor = connection.cursor()
# First, ensure the schema exists (required for volume creation)
try:
create_schema_sql = f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}"
cursor.execute(create_schema_sql)
self.log_very_verbose(f"Schema ready: {catalog}.{schema}")
except Exception as schema_error:
# If schema creation fails due to permissions, provide clear guidance
error_msg = str(schema_error).lower()
if "permission" in error_msg or "access denied" in error_msg or "unauthorized" in error_msg:
raise ValueError(
f"Permission denied creating schema: {catalog}.{schema}. "
f"Ensure you have CREATE SCHEMA permission on catalog {catalog}. "
f"Or create it manually: CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}"
)
raise
# Now create the volume (IF NOT EXISTS is safe)
create_volume_sql = f"CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume}"
cursor.execute(create_volume_sql)
self.log_verbose(f"✅ UC Volume ready: {catalog}.{schema}.{volume}")
cursor.close()
except ValueError:
# Re-raise ValueError exceptions (our custom error messages)
raise
except Exception as e:
error_msg = str(e).lower()
# Check for permission errors
if "permission" in error_msg or "access denied" in error_msg or "unauthorized" in error_msg:
raise ValueError(
f"Permission denied creating UC Volume: {catalog}.{schema}.{volume}. "
f"Ensure you have CREATE VOLUME permission on schema {catalog}.{schema}. "
f"Or create it manually: CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume}"
)
# Generic error
raise ValueError(
f"Failed to create UC Volume {catalog}.{schema}.{volume}: {e}. "
f"Try creating manually: CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume}"
)
def _upload_to_uc_volume(
self,
data_files: dict[str, Any],
uc_volume_path: str,
data_dir: Path,
force_upload: bool = False,
) -> dict[str, str]:
"""Upload local data files to Unity Catalog Volume using Databricks Files API.
For sharded files (e.g., customer.tbl.1.zst, customer.tbl.2.zst, ...),
this method will find and upload ALL chunk files, returning a wildcard pattern
for COPY INTO to use.
Args:
data_files: Dictionary of table_name -> local file path (may be first chunk only)
uc_volume_path: UC Volume path (e.g., dbfs:/Volumes/catalog/schema/volume)
data_dir: Base data directory (for resolving relative paths)
Returns:
Dictionary mapping table names to UC Volume file URIs (with wildcards for sharded tables)
Raises:
ImportError: If databricks-sdk not available
Exception: If upload fails
"""
try:
from databricks.sdk import WorkspaceClient
except ImportError:
raise ImportError("databricks-sdk required for UC Volume uploads. Install with: uv add databricks-sdk")
# Initialize Databricks SDK client (uses same credentials as SQL connector)
workspace = WorkspaceClient(
host=f"https://{self.server_hostname}",
token=self.access_token,
)
# Convert dbfs:/Volumes/catalog/schema/volume to /Volumes/catalog/schema/volume
volume_path = uc_volume_path.replace("dbfs:", "")
# Extract local directory from DatabricksPath if applicable
from benchbox.utils.cloud_storage import DatabricksPath
if isinstance(data_dir, DatabricksPath):
local_base_dir = data_dir._path # Get actual local directory component
self.log_very_verbose(f"Using DatabricksPath local component: {local_base_dir}")
else:
local_base_dir = Path(data_dir)
# Pre-upload validation: if a local manifest exists and remote data matches it, reuse
# Determine local manifest path (DatabricksPath is already imported above)
if isinstance(data_dir, DatabricksPath):
manifest_path = data_dir._path / MANIFEST_FILENAME
else:
manifest_path = Path(data_dir) / MANIFEST_FILENAME
if not force_upload and manifest_path.exists():
# Use validation engine with Databricks adapter underneath (auto-detected)
validation_engine = UploadValidationEngine()
# Pass verbose flag from adapter settings for detailed validation reporting
verbose = getattr(self, "very_verbose", False)
should_upload, validation_result = validation_engine.should_upload_data(
remote_path=uc_volume_path,
local_manifest_path=manifest_path,
force_upload=force_upload,
verbose=verbose,
)
if not should_upload:
# Data exists and is valid - rebuild mapping from remote manifest
# Core module already logged validation messages
remote_manifest = validation_result.remote_manifest
if remote_manifest:
# Keep adapter-specific verbose logging for debugging
self.log_verbose("Reusing existing data from UC Volume (validation passed)")
return self._get_remote_file_uris_from_manifest(uc_volume_path, remote_manifest)
else:
self.log_verbose(
"Pre-upload validation passed but remote manifest unavailable, proceeding with upload"
)
# Upload manifest FIRST for atomic consistency
if manifest_path.exists():
try:
self._upload_manifest_to_uc_volume(manifest_path, uc_volume_path, workspace)
except Exception as e:
self.logger.warning(f"Failed to upload manifest to UC Volume: {e}")
uploaded_files = {}
for table_name, file_path in data_files.items():
local_path = Path(file_path) if not isinstance(file_path, Path) else file_path
# If path is not absolute, it's relative to CWD - make it absolute for verification
if not local_path.is_absolute():
local_path = local_path.resolve()
# Verify file exists and log details
if not local_path.exists():
self.logger.error(f"File not found for table {table_name}: {local_path}")
self.logger.error(f" Checked path: {local_path.absolute()}")
self.logger.error(f" CWD: {Path.cwd()}")
continue
file_size = local_path.stat().st_size
self.log_very_verbose(f"Found {local_path.name} ({file_size:,} bytes) at {local_path}")
# Check if this is a sharded file (e.g., customer.tbl.1.zst)
# Pattern: base_name.ext.N.compression OR base_name.ext.N
filename = local_path.name
parts = filename.split(".")
is_sharded = False
chunk_files = []
# Detect sharding patterns:
# 1. customer.tbl.1.zst -> parts = ['customer', 'tbl', '1', 'zst']
# 2. customer.tbl.1 -> parts = ['customer', 'tbl', '1']
compression_exts = {".zst", ".gz", ".bz2", ".xz", ".lz4"}
if len(parts) >= 3:
# Check if second-to-last part is a digit (for compressed files)
# OR if last part is a digit (for uncompressed files)
if (
len(parts) >= 4
and parts[-1] in [ext.lstrip(".") for ext in compression_exts]
and parts[-2].isdigit()
):
# Pattern: customer.tbl.1.zst
is_sharded = True
base_parts = parts[:-2] # ['customer', 'tbl']
compression = parts[-1] # 'zst'
pattern = f"{'.'.join(base_parts)}.*.{compression}"
elif parts[-1].isdigit():
# Pattern: customer.tbl.1
is_sharded = True
base_parts = parts[:-1] # ['customer', 'tbl']
pattern = f"{'.'.join(base_parts)}.*"
if is_sharded:
# Find all chunk files matching the pattern in the same directory
parent_dir = local_path.parent
chunk_files = sorted([f for f in parent_dir.glob(pattern) if f.is_file()])
if chunk_files:
self.log_verbose(f"Found {len(chunk_files)} chunk files for {table_name}: {pattern}")
# Upload files
if is_sharded and chunk_files:
# Upload ALL chunk files
for chunk_file in chunk_files:
# Validate file before upload
if not chunk_file.exists():
self.logger.error(f"Chunk file disappeared: {chunk_file}")
continue
chunk_size = chunk_file.stat().st_size
if chunk_size == 0:
self.logger.warning(f"Skipping empty chunk file: {chunk_file.name}")
continue
target_path = f"{volume_path}/{chunk_file.name}"
self.log_very_verbose(f"Uploading {chunk_file.name} ({chunk_size:,} bytes) to {target_path}")
try:
# Read file contents and verify before upload
with open(chunk_file, "rb") as f:
content = f.read()
if len(content) == 0:
self.logger.error(f"Read 0 bytes from {chunk_file} (expected {chunk_size})")
raise RuntimeError(f"Failed to read content from {chunk_file}")
if len(content) != chunk_size:
self.logger.warning(
f"Size mismatch for {chunk_file.name}: stat={chunk_size}, read={len(content)}"
)
# Upload using BytesIO to ensure we send what we read
from io import BytesIO
workspace.files.upload(target_path, BytesIO(content), overwrite=True)
self.log_very_verbose(f"Successfully uploaded {chunk_file.name} ({len(content):,} bytes)")
except Exception as e:
self.logger.error(f"Failed to upload {chunk_file.name} to UC Volume: {e}")
raise RuntimeError(f"Failed to upload {chunk_file.name} to {uc_volume_path}: {e}")
# Return wildcard pattern for COPY INTO
wildcard_uri = f"dbfs:{volume_path}/{pattern}"
uploaded_files[table_name] = wildcard_uri
self.log_verbose(f"Uploaded {len(chunk_files)} chunks for {table_name}, using wildcard: {wildcard_uri}")
else:
# Single file (not sharded) - upload with validation
single_file_size = local_path.stat().st_size
if single_file_size == 0:
self.logger.warning(f"Skipping empty file: {local_path.name}")
continue
target_path = f"{volume_path}/{local_path.name}"
self.log_verbose(f"Uploading {local_path.name} ({single_file_size:,} bytes) to {target_path}")
try:
# Read file contents and verify before upload
with open(local_path, "rb") as f:
content = f.read()
if len(content) == 0:
self.logger.error(f"Read 0 bytes from {local_path} (expected {single_file_size})")
raise RuntimeError(f"Failed to read content from {local_path}")
if len(content) != single_file_size:
self.logger.warning(
f"Size mismatch for {local_path.name}: stat={single_file_size}, read={len(content)}"
)
# Upload using BytesIO to ensure we send what we read
from io import BytesIO
workspace.files.upload(target_path, BytesIO(content), overwrite=True)
# Store the dbfs:// URI for COPY INTO
uploaded_files[table_name] = f"dbfs:{target_path}"
self.log_verbose(f"Successfully uploaded {local_path.name} ({len(content):,} bytes)")
except Exception as e:
self.logger.error(f"Failed to upload {local_path.name} to UC Volume: {e}")
raise RuntimeError(f"Failed to upload {local_path.name} to {uc_volume_path}: {e}")
# Re-upload manifest after the data files (it was also uploaded before the data loop above)
if manifest_path.exists():
try:
self._upload_manifest_to_uc_volume(manifest_path, uc_volume_path, workspace)
except Exception as e:
self.logger.warning(f"Failed to upload manifest to UC Volume: {e}")
return uploaded_files
def _upload_manifest_to_uc_volume(self, manifest_path: Path, uc_volume_path: str, workspace: Any) -> None:
"""Upload the manifest JSON to the UC Volume root."""
try:
target_path = uc_volume_path.replace("dbfs:", "")
if not target_path.endswith("/" + MANIFEST_FILENAME):
target_path = target_path.rstrip("/") + "/" + MANIFEST_FILENAME
with open(manifest_path, "rb") as fh:
content = fh.read()
from io import BytesIO
workspace.files.upload(target_path, BytesIO(content), overwrite=True)
# Small log for visibility
try:
manifest = json.loads(Path(manifest_path).read_text(encoding="utf-8"))
tables = manifest.get("tables") or {}
self.logger.info(f"Uploaded manifest to {uc_volume_path} ({len(content)} bytes, {len(tables)} tables)")
except Exception:
self.logger.info(f"Uploaded manifest to {uc_volume_path}")
except Exception as e:
raise RuntimeError(f"Manifest upload failed: {e}")
def _get_remote_file_uris_from_manifest(self, uc_volume_path: str, remote_manifest: dict) -> dict[str, str]:
"""Build UC Volume file URI map per table from manifest entries.
For sharded tables, return a wildcard pattern like customer.tbl.*.zst
"""
mapping: dict[str, str] = {}
tables = remote_manifest.get("tables") or {}
for table, entries in tables.items():
if not entries:
continue
if len(entries) == 1:
rel = entries[0].get("path")
if rel:
mapping[table] = f"{uc_volume_path.rstrip('/')}/{rel}"
continue
# Try to detect a common sharded pattern: base.N[.ext]
names = [str(e.get("path")) for e in entries if e.get("path")]
if not names:
continue
# Derive wildcard: if all names share same prefix/suffix around a numeric segment
# e.g., customer.tbl.1.zst -> base='customer.tbl', ext='.zst' => customer.tbl.*.zst
def pattern_for(name: str) -> tuple[str, str, str]:
parts = name.split(".")
if len(parts) >= 3 and parts[-2].isdigit():
base = ".".join(parts[:-2])
ext = "." + parts[-1]
return base, ".*", ext
if len(parts) >= 2 and parts[-1].isdigit():
base = ".".join(parts[:-1])
return base, ".*", ""
# Fallback: wildcard whole name
stem = Path(name).stem
return stem, ".*", Path(name).suffix
base0, star, ext0 = pattern_for(names[0])
# Verify others align
ok = True
for n in names[1:]:
b, s, e = pattern_for(n)
if b != base0 or e != ext0:
ok = False
break
if ok:
wildcard = f"{base0}{star}{ext0}"
mapping[table] = f"{uc_volume_path.rstrip('/')}/{wildcard}"
else:
# Fallback to first file
mapping[table] = f"{uc_volume_path.rstrip('/')}/{names[0]}"
return mapping
def load_data(
self, benchmark, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
"""Load data using Databricks COPY INTO from UC Volumes or cloud storage.
This implementation avoids temporary views and uses COPY INTO for robust ingestion.
"""
start_time = time.time()
self.log_operation_start("Data loading", f"benchmark: {benchmark.__class__.__name__}")
self.log_very_verbose(f"Data directory: {data_dir}")
table_stats = {}
per_table_timings = {} # Track detailed timings per table
cursor = connection.cursor()
try:
# Get data files from benchmark
data_files = None
if hasattr(benchmark, "tables") and benchmark.tables:
# Use generated data files from benchmark directly
data_files = benchmark.tables
elif hasattr(benchmark, "_impl") and hasattr(benchmark._impl, "tables") and benchmark._impl.tables:
# Use generated data files from benchmark implementation
data_files = benchmark._impl.tables
if not data_files:
# Manifest fallback
try:
import json
manifest_path = Path(data_dir) / "_datagen_manifest.json"
if manifest_path.exists():
with open(manifest_path) as f:
manifest = json.load(f)
tables = manifest.get("tables") or {}
mapping = {}
for table, entries in tables.items():
if entries:
rel = entries[0].get("path")
if rel:
mapping[table] = Path(data_dir) / rel
if mapping:
data_files = mapping
self.logger.debug("Using data files from _datagen_manifest.json")
except Exception as e:
self.logger.debug(f"Manifest fallback failed: {e}")
if not data_files:
# No data files available - benchmark should have generated data first
raise ValueError("No data files found. Ensure benchmark.generate_data() was called first.")
# Determine staging root for COPY INTO
def _is_cloud_uri(s: str) -> bool:
return s.startswith(("s3://", "gs://", "abfss://", "dbfs:/"))
stage_root = None
# Check if data_dir is a DatabricksPath with dbfs_target
from benchbox.utils.cloud_storage import DatabricksPath
if isinstance(data_dir, DatabricksPath) and hasattr(data_dir, "dbfs_target") and data_dir.dbfs_target:
# Use the dbfs target from DatabricksPath
stage_root = data_dir.dbfs_target.rstrip("/")
self.log_verbose(f"Using DatabricksPath dbfs_target: {stage_root}")
# Prefer explicit staging_root
elif isinstance(self.staging_root, str) and _is_cloud_uri(self.staging_root):
stage_root = self.staging_root.rstrip("/")
else:
# Try to use UC Volume config
if self.uc_catalog and self.uc_schema and self.uc_volume:
stage_root = f"dbfs:/Volumes/{self.uc_catalog}/{self.uc_schema}/{self.uc_volume}".rstrip("/")
else:
# If data_dir looks like a cloud/DBFS URI, use it directly
data_dir_str = str(data_dir)
if _is_cloud_uri(data_dir_str):
stage_root = data_dir_str.rstrip("/")
if not stage_root:
raise ValueError(
"Databricks data loading requires a cloud/UC Volume staging location. "
"Provide --output with a cloud path (dbfs:/, s3://, gs://, or abfss://), "
"set uc_catalog/uc_schema/uc_volume, or set staging_root."
)
# If data is local and stage_root is a UC Volume, upload files first
# For DatabricksPath, data is always local (temp dir) with remote target
data_is_local = isinstance(data_dir, DatabricksPath) or not _is_cloud_uri(str(data_dir))
def _is_complete_uc_volume_path(p: str) -> bool:
v = p.replace("dbfs:", "").rstrip("/")
if not v.startswith("/Volumes/"):
return False
parts = v.split("/")
return len(parts) >= 5 # ['', 'Volumes', 'catalog', 'schema', 'volume', ...]
if data_is_local and stage_root.startswith("dbfs:/Volumes/") and _is_complete_uc_volume_path(stage_root):
self.log_verbose(f"Uploading local data to UC Volume: {stage_root}")
# Ensure UC Volume exists (create if necessary)
self._ensure_uc_volume_exists(stage_root, connection)
force_upload = getattr(self, "force_upload", False)
original_files = dict(data_files)
uploaded_files = self._upload_to_uc_volume(
data_files,
stage_root,
data_dir,
force_upload=force_upload,
)
# If upload returned a mapping, use it; otherwise, fall back to original mapping
if uploaded_files:
data_files = uploaded_files
else:
data_files = original_files
self.log_verbose("Upload to UC Volume completed")
# Ensure we're in the correct schema context for table operations
cursor.execute(f"USE CATALOG {self.catalog}")
cursor.execute(f"USE SCHEMA {self.schema}")
self.log_verbose(f"Set schema context for data loading: {self.catalog}.{self.schema}")
# Verify tables exist before attempting to load data
cursor.execute(f"SHOW TABLES IN {self.catalog}.{self.schema}")
existing_tables = {row[1].lower() for row in cursor.fetchall()}
self.log_very_verbose(f"Found {len(existing_tables)} existing tables in {self.catalog}.{self.schema}")
# Load data for each table using COPY INTO
for table_name, file_path in data_files.items():
try:
load_start = time.time()
table_name_upper = table_name.upper()
# Verify table exists before COPY INTO
if table_name.lower() not in existing_tables:
self.logger.error(f"Table {table_name_upper} not found in schema {self.catalog}.{self.schema}")
self.logger.error(f"Available tables: {sorted(existing_tables)}")
raise RuntimeError(
f"Table {table_name_upper} does not exist in {self.catalog}.{self.schema}. "
f"Ensure schema creation completed successfully before loading data."
)
# Determine path and delimiter
# After UC Volume upload, file_path is already a full URI (dbfs:/Volumes/...)
# Otherwise, construct from stage_root + filename
if isinstance(file_path, str) and file_path.startswith("dbfs:/Volumes/"):
# Already uploaded to UC Volume - use as-is (may contain wildcards for sharded tables)
file_uri = file_path
# Extract filename pattern (handle wildcards like customer.tbl.*.zst)
uri_path = file_path.replace("dbfs:", "")
filename = uri_path.split("/")[-1] # Get last part of path (may have wildcards)
else:
# Construct URI from stage_root
rel = None
if hasattr(file_path, "name"):
rel = getattr(file_path, "name", None)
else:
# If this is already a string/path-like, just get filename
rel = Path(str(file_path)).name
filename = rel
file_uri = f"{stage_root}/{rel}"
# Detect file format - handle compressed files (.zst, .gz, .bz2)
# For wildcard patterns (e.g., customer.tbl.*.zst), remove wildcard for format detection
compression_exts = {".zst", ".gz", ".bz2", ".xz", ".lz4"}
# Strip wildcard component for format detection (customer.tbl.*.zst -> customer.tbl.zst)
filename_for_format = filename.replace(".*", "")
file_path_obj = Path(filename_for_format)
# Strip compression extension if present
base_name = filename_for_format
if file_path_obj.suffix in compression_exts:
base_name = file_path_obj.stem # nation.tbl.zst -> nation.tbl
# Now get the actual data format suffix
format_suffix = Path(base_name).suffix or ".tbl"
# TPC benchmarks use pipe delimiter, CSV uses comma
# TPC-H uses .tbl files, TPC-DS uses .dat files - both are pipe-delimited
delimiter = "|" if format_suffix in [".tbl", ".dat"] else ","
# Get column names from benchmark schema for explicit column mapping
# This fixes the "Incoming schema has additional field(s): _c0, _c1, _c2" error
# Delta Lake requires explicit column mapping when header='false'
column_list = ""
if hasattr(benchmark, "get_schema"):
try:
schema = benchmark.get_schema()
# All benchmarks now return dict[str, dict] format
# Try case-insensitive lookup
table_schema = schema.get(table_name.lower())
if not table_schema:
# Fallback to uppercase lookup
table_schema = schema.get(table_name_upper)
if not table_schema:
# Try original case
table_schema = schema.get(table_name)
if table_schema and "columns" in table_schema:
columns = [col["name"] for col in table_schema["columns"]]
if columns:
column_list = f" ({', '.join(columns)})"
self.log_very_verbose(
f"Using explicit column mapping for {table_name_upper}: {len(columns)} columns"
)
except Exception as e:
self.log_very_verbose(f"Could not get column list for {table_name}: {e}")
copy_sql = (
f"COPY INTO {table_name_upper}{column_list} FROM '{file_uri}' "
f"FILEFORMAT = CSV FORMAT_OPTIONS('delimiter'='{delimiter}', 'header'='false')"
)
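# For example, a pipe-delimited TPC-H table staged as sharded .zst files would yield
# a statement roughly like (illustrative path and columns):
#   COPY INTO CUSTOMER (c_custkey, c_name, ...)
#   FROM 'dbfs:/Volumes/main/benchbox/raw/customer.tbl.*.zst'
#   FILEFORMAT = CSV FORMAT_OPTIONS('delimiter'='|', 'header'='false')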
# Log wildcard pattern for visibility
if "*" in file_uri:
self.log_verbose(f"Loading {table_name_upper} from wildcard pattern: {file_uri}")
# Time COPY INTO
copy_start = time.time()
cursor.execute(copy_sql)
copy_time = time.time() - copy_start
# Row count
cursor.execute(f"SELECT COUNT(*) FROM {table_name_upper}")
row_count = cursor.fetchone()[0]
table_stats[table_name_upper] = row_count
# Optional optimize - track timing separately
optimize_time = 0.0
if self.enable_delta_optimization:
optimize_start = time.time()
with contextlib.suppress(Exception):
cursor.execute(f"OPTIMIZE {table_name_upper}")
optimize_time = time.time() - optimize_start
load_time = time.time() - load_start
# Store detailed timings
per_table_timings[table_name_upper] = {
"copy_into_ms": copy_time * 1000,
"optimize_ms": optimize_time * 1000,
"total_ms": load_time * 1000,
"rows": row_count,
}
self.logger.info(f"✅ Loaded {row_count:,} rows into {table_name_upper} in {load_time:.2f}s")
except Exception as e:
self.logger.error(f"Failed to load {table_name}: {str(e)[:200]}")
table_stats[table_name.upper()] = 0
# Record failed table with zero timings
per_table_timings[table_name.upper()] = {
"copy_into_ms": 0,
"optimize_ms": 0,
"total_ms": 0,
"rows": 0,
}
total_time = time.time() - start_time
total_rows = sum(table_stats.values())
self.log_operation_complete(
"Data loading", total_time, f"{total_rows:,} total rows, {len(table_stats)} tables"
)
finally:
cursor.close()
return table_stats, total_time, per_table_timings
def execute_query(
self,
connection: Any,
query: str,
query_id: str,
benchmark_type: str | None = None,
scale_factor: float | None = None,
validate_row_count: bool = True,
stream_id: int | None = None,
) -> dict[str, Any]:
"""Execute query with detailed timing and profiling."""
start_time = time.time()
self.log_verbose(f"Executing query {query_id}")
self.log_very_verbose(f"Query SQL (first 200 chars): {query[:200]}{'...' if len(query) > 200 else ''}")
cursor = connection.cursor()
try:
# Schema context is already set in create_connection() and persists for the session
# No need to set USE <catalog>.<schema> before every query - it adds unnecessary overhead
# (Each USE statement = 1 extra round-trip to Databricks)
# Execute the query
# Note: Query dialect translation is now handled automatically by the base adapter
cursor.execute(query)
result = cursor.fetchall()
execution_time = time.time() - start_time
actual_row_count = len(result) if result else 0
# Validate row count if enabled and benchmark type is provided
validation_result = None
if validate_row_count and benchmark_type:
from benchbox.core.validation.query_validation import QueryValidator
validator = QueryValidator()
validation_result = validator.validate_query_result(
benchmark_type=benchmark_type,
query_id=query_id,
actual_row_count=actual_row_count,
scale_factor=scale_factor,
stream_id=stream_id,
)
# Log validation result
if validation_result.warning_message:
self.log_verbose(f"Row count validation: {validation_result.warning_message}")
elif not validation_result.is_valid:
self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
else:
self.log_very_verbose(
f"Row count validation PASSED: {actual_row_count} rows "
f"(expected: {validation_result.expected_row_count})"
)
# Use base helper to build result with consistent validation field mapping
result_dict = self._build_query_result_with_validation(
query_id=query_id,
execution_time=execution_time,
actual_row_count=actual_row_count,
first_row=result[0] if result else None,
validation_result=validation_result,
)
# Include Databricks-specific fields
result_dict["translated_query"] = None # Translation handled by base adapter
# Add resource usage for cost calculation (execution time for DBU estimation)
result_dict["resource_usage"] = {
"execution_time_seconds": execution_time,
}
return result_dict
except Exception as e:
execution_time = time.time() - start_time
return {
"query_id": query_id,
"status": "FAILED",
"execution_time": execution_time,
"rows_returned": 0,
"error": str(e),
"error_type": type(e).__name__,
}
finally:
cursor.close()
def _fix_databricks_sql_syntax(self, sql: str) -> str:
"""Transform SQL syntax for Databricks compatibility.
This method removes SQL syntax that is not supported by Databricks/Spark SQL,
particularly NULLS FIRST/LAST clauses in PRIMARY KEY constraints.
Args:
sql: SQL statement(s) to fix
Returns:
Fixed SQL with Databricks-compatible syntax
"""
import re
original_sql = sql
# Pattern 1: Remove NULLS LAST/FIRST from PRIMARY KEY constraints
# Databricks doesn't support NULLS ordering in PRIMARY KEY definitions
# Match: PRIMARY KEY (col1, col2 NULLS LAST)
# Also match: PRIMARY KEY (col1 NULLS FIRST, col2)
nulls_in_pk_pattern = r"\b(PRIMARY\s+KEY\s*\([^)]*?)\s+NULLS\s+(LAST|FIRST)\s*([^)]*?\))"
def remove_nulls_from_pk(match):
# Reconstruct without the NULLS clause
before = match.group(1) # PRIMARY KEY (col1, col2
after = match.group(3) # remaining part + closing paren
return f"{before} {after}".strip()
fixed_sql = re.sub(nulls_in_pk_pattern, remove_nulls_from_pk, sql, flags=re.IGNORECASE)
# Pattern 2: Remove standalone NULLS clauses in column definitions within PRIMARY KEY
# This catches cases like: PRIMARY KEY (col1 NULLS LAST, col2 NULLS FIRST)
# Apply multiple times to catch all occurrences
max_iterations = 10 # Safety limit
for _ in range(max_iterations):
prev = fixed_sql
fixed_sql = re.sub(
r"\b(PRIMARY\s+KEY\s*\([^)]*?)\s+NULLS\s+(LAST|FIRST)\b",
r"\1",
fixed_sql,
flags=re.IGNORECASE,
)
if fixed_sql == prev:
break # No more replacements
# Log if any changes were made
if fixed_sql != original_sql:
self.log_very_verbose("Fixed Databricks SQL syntax (removed NULLS FIRST/LAST from PRIMARY KEY)")
self.log_very_verbose(f"Before: {original_sql[:200]}...")
self.log_very_verbose(f"After: {fixed_sql[:200]}...")
return fixed_sql
def _convert_to_delta_table(self, statement: str) -> str:
"""Convert CREATE TABLE statement to Delta Lake format."""
if not statement.upper().startswith("CREATE TABLE"):
return statement
# Ensure idempotency with OR REPLACE
if "OR REPLACE" not in statement.upper():
statement = statement.replace("CREATE TABLE", "CREATE OR REPLACE TABLE", 1)
# Default to DELTA format when unspecified
if "USING" not in statement.upper():
# Find the closing parenthesis of column definitions
paren_count = 0
using_pos = len(statement)
for i, char in enumerate(statement):
if char == "(":
paren_count += 1
elif char == ")":
paren_count -= 1
if paren_count == 0:
using_pos = i + 1
break
# Insert USING DELTA clause
statement = statement[:using_pos] + " USING DELTA" + statement[using_pos:]
# Include Delta Lake optimization properties; skip the clause entirely when neither
# auto-optimize setting is enabled so we never emit an empty TBLPROPERTIES ()
if "TBLPROPERTIES" not in statement.upper():
properties = []
if self.delta_auto_optimize:
properties.append("'delta.autoOptimize.optimizeWrite' = 'true'")
if self.delta_auto_compact:
properties.append("'delta.autoOptimize.autoCompact' = 'true'")
if properties:
statement += " TBLPROPERTIES (" + ", ".join(properties) + ")"
return statement
def _get_platform_metadata(self, connection: Any) -> dict[str, Any]:
"""Get Databricks-specific metadata and system information."""
metadata = {
"platform": self.platform_name,
"server_hostname": self.server_hostname,
"catalog": self.catalog,
"schema": self.schema,
"result_cache_enabled": not self.disable_result_cache,
}
cursor = connection.cursor()
try:
# Get Spark version
cursor.execute("SELECT version()")
result = cursor.fetchone()
metadata["spark_version"] = result[0] if result else "unknown"
# Get current catalog and schema
cursor.execute("SELECT current_catalog(), current_schema()")
result = cursor.fetchone()
if result:
metadata["current_catalog"] = result[0]
metadata["current_schema"] = result[1]
# List the current_* context functions available on this warehouse
cursor.execute("SHOW FUNCTIONS LIKE 'current_*'")
functions = cursor.fetchall()
metadata["available_functions"] = [f[0] for f in functions]
# Get Spark configurations
cursor.execute("SET")
configs = cursor.fetchall()
spark_configs = {k: v for k, v in configs if k.startswith("spark.")}
metadata["spark_configurations"] = spark_configs
except Exception as e:
metadata["metadata_error"] = str(e)
finally:
cursor.close()
return metadata
def analyze_table(self, connection: Any, table_name: str) -> None:
"""Run ANALYZE TABLE for better query optimization."""
cursor = connection.cursor()
try:
cursor.execute(f"ANALYZE TABLE {table_name.upper()} COMPUTE STATISTICS")
self.logger.info(f"Analyzed table {table_name.upper()}")
except Exception as e:
self.logger.warning(f"Failed to analyze table {table_name}: {e}")
finally:
cursor.close()
def optimize_table(self, connection: Any, table_name: str) -> None:
"""Optimize Delta Lake table."""
if not self.enable_delta_optimization:
return
cursor = connection.cursor()
try:
cursor.execute(f"OPTIMIZE {table_name.upper()}")
self.logger.info(f"Optimized Delta table {table_name.upper()}")
except Exception as e:
self.logger.warning(f"Failed to optimize table {table_name}: {e}")
finally:
cursor.close()
def vacuum_table(self, connection: Any, table_name: str, hours: int = 168) -> None:
"""Vacuum Delta Lake table to remove old files."""
if not self.enable_delta_optimization:
return
cursor = connection.cursor()
try:
cursor.execute(f"VACUUM {table_name.upper()} RETAIN {hours} HOURS")
self.logger.info(f"Vacuumed Delta table {table_name.upper()}")
except Exception as e:
self.logger.warning(f"Failed to vacuum table {table_name}: {e}")
finally:
cursor.close()
def _get_existing_tables(self, connection: Any) -> list[str]:
"""Get list of existing tables in the Databricks schema."""
try:
cursor = connection.cursor()
# Use Databricks-specific query to get tables in current schema
cursor.execute(f"SHOW TABLES IN {self.catalog}.{self.schema}")
result = cursor.fetchall()
cursor.close()
# Result format is (database, tableName, isTemporary)
return [row[1] for row in result if not row[2]] # Exclude temporary tables
except Exception as e:
self.logger.debug(f"Failed to get existing tables: {e}")
return []
def close_connection(self, connection: Any) -> None:
"""Close Databricks connection."""
try:
if connection and hasattr(connection, "close"):
connection.close()
except Exception as e:
self.logger.warning(f"Error closing connection: {e}")
def supports_tuning_type(self, tuning_type) -> bool:
"""Check if Databricks supports a specific tuning type.
Databricks supports:
- PARTITIONING: Via PARTITIONED BY clause in Delta Lake
- CLUSTERING: Via CLUSTER BY clause (liquid clustering, on supported Databricks runtimes)
- DISTRIBUTION: Via Spark optimization hints and Z-ORDER clustering
Args:
tuning_type: The type of tuning to check support for
Returns:
True if the tuning type is supported by Databricks
"""
# Import here to avoid circular imports
try:
from benchbox.core.tuning.interface import TuningType
return tuning_type in {
TuningType.PARTITIONING,
TuningType.CLUSTERING,
TuningType.DISTRIBUTION,
}
except ImportError:
return False
def generate_tuning_clause(self, table_tuning) -> str:
"""Generate Databricks-specific tuning clauses for CREATE TABLE statements.
Databricks supports:
- USING DELTA (Delta Lake format)
- PARTITIONED BY (column1, column2, ...)
- CLUSTER BY (column1, column2, ...) for liquid clustering on supported runtimes
- Z-ORDER optimization
Args:
table_tuning: The tuning configuration for the table
Returns:
SQL clause string to be appended to CREATE TABLE statement
"""
if not table_tuning or not table_tuning.has_any_tuning():
return ""
clauses = []
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Always use Delta Lake format for better performance
clauses.append("USING DELTA")
# Handle partitioning
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
if partition_columns:
# Sort by order and create partition clause
sorted_cols = sorted(partition_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
partition_clause = f"PARTITIONED BY ({', '.join(column_names)})"
clauses.append(partition_clause)
# Handle clustering (liquid clustering via CLUSTER BY)
cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
if cluster_columns:
# Sort by order and create cluster clause
sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
cluster_clause = f"CLUSTER BY ({', '.join(column_names)})"
clauses.append(cluster_clause)
# Distribution handled through Z-ORDER optimization (applied post-creation)
except ImportError:
# If tuning interface not available, at least use Delta format
clauses.append("USING DELTA")
return " ".join(clauses)
def apply_table_tunings(self, table_tuning, connection: Any) -> None:
"""Apply tuning configurations to a Databricks Delta Lake table.
Databricks tuning approach:
- PARTITIONING: Handled via PARTITIONED BY in CREATE TABLE
- CLUSTERING: Handled via CLUSTER BY in CREATE TABLE or ALTER TABLE
- DISTRIBUTION: Achieved through Z-ORDER clustering and OPTIMIZE
- Delta Lake optimization and maintenance
Args:
table_tuning: The tuning configuration to apply
connection: Databricks connection
Raises:
ValueError: If the tuning configuration is invalid for Databricks
"""
if not table_tuning or not table_tuning.has_any_tuning():
return
table_name = table_tuning.table_name.upper()
self.logger.info(f"Applying Databricks tunings for table: {table_name}")
cursor = connection.cursor()
try:
# Import here to avoid circular imports
from benchbox.core.tuning.interface import TuningType
# Check if table exists and is Delta format
cursor.execute(f"DESCRIBE EXTENDED {table_name}")
table_info = cursor.fetchall()
is_delta_table = any("DELTA" in str(row).upper() for row in table_info)
if not is_delta_table:
self.logger.warning(
f"Table {table_name} is not a Delta table - some optimizations may not be available"
)
# Handle clustering via Z-ORDER optimization
cluster_columns = table_tuning.get_columns_by_type(TuningType.CLUSTERING)
distribution_columns = table_tuning.get_columns_by_type(TuningType.DISTRIBUTION)
# Combine clustering and distribution columns for Z-ORDER
zorder_columns = []
if cluster_columns:
sorted_cols = sorted(cluster_columns, key=lambda col: col.order)
zorder_columns.extend([col.name for col in sorted_cols])
if distribution_columns:
sorted_cols = sorted(distribution_columns, key=lambda col: col.order)
# Include distribution columns if not already in clustering
for col in sorted_cols:
if col.name not in zorder_columns:
zorder_columns.append(col.name)
if zorder_columns and is_delta_table:
# Apply Z-ORDER optimization
zorder_clause = f"OPTIMIZE {table_name} ZORDER BY ({', '.join(zorder_columns)})"
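# e.g. "OPTIMIZE ORDERS ZORDER BY (o_custkey, o_orderdate)" (illustrative table and columns)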
try:
cursor.execute(zorder_clause)
self.logger.info(f"Applied Z-ORDER optimization to {table_name}: {', '.join(zorder_columns)}")
except Exception as e:
self.logger.warning(f"Failed to apply Z-ORDER optimization to {table_name}: {e}")
# Handle partitioning information (logging only, as it's defined at CREATE TABLE time)
partition_columns = table_tuning.get_columns_by_type(TuningType.PARTITIONING)
if partition_columns:
sorted_cols = sorted(partition_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
self.logger.info(
f"Partitioning strategy for {table_name}: {', '.join(column_names)} (defined at CREATE TABLE time)"
)
# Handle sorting through clustering/Z-ORDER
sort_columns = table_tuning.get_columns_by_type(TuningType.SORTING)
if sort_columns:
sorted_cols = sorted(sort_columns, key=lambda col: col.order)
column_names = [col.name for col in sorted_cols]
self.logger.info(
f"Sorting in Databricks achieved via Z-ORDER clustering for table {table_name}: {', '.join(column_names)}"
)
# Perform general Delta Lake optimizations
if is_delta_table and self.enable_delta_optimization:
try:
# Run OPTIMIZE to compact small files
cursor.execute(f"OPTIMIZE {table_name}")
self.logger.info(f"Optimized Delta table {table_name}")
# Refresh table statistics
cursor.execute(f"ANALYZE TABLE {table_name} COMPUTE STATISTICS")
self.logger.info(f"Updated statistics for {table_name}")
except Exception as e:
self.logger.warning(f"Failed to optimize Delta table {table_name}: {e}")
except ImportError:
self.logger.warning("Tuning interface not available - skipping tuning application")
except Exception as e:
raise ValueError(f"Failed to apply tunings to Databricks table {table_name}: {e}")
finally:
cursor.close()
def apply_unified_tuning(self, unified_config: UnifiedTuningConfiguration, connection: Any) -> None:
"""Apply unified tuning configuration to Databricks.
Args:
unified_config: Unified tuning configuration to apply
connection: Databricks connection
"""
if not unified_config:
return
# Apply constraint configurations
self.apply_constraint_configuration(unified_config.primary_keys, unified_config.foreign_keys, connection)
# Apply platform optimizations
if unified_config.platform_optimizations:
self.apply_platform_optimizations(unified_config.platform_optimizations, connection)
# Apply table-level tunings
for _table_name, table_tuning in unified_config.table_tunings.items():
self.apply_table_tunings(table_tuning, connection)
def apply_constraint_configuration(
self,
primary_key_config: PrimaryKeyConfiguration,
foreign_key_config: ForeignKeyConfiguration,
connection: Any,
) -> None:
"""Apply constraint configurations to Databricks.
Note: Databricks (Spark SQL) supports PRIMARY KEY and FOREIGN KEY constraints,
but they are informational only (not enforced). They are used by the Catalyst
optimizer for query optimization and must be applied at table creation time.
Args:
primary_key_config: Primary key constraint configuration
foreign_key_config: Foreign key constraint configuration
connection: Databricks connection
"""
# Databricks constraints are applied at table creation time for Catalyst optimization
# This method is called after tables are created, so log the configurations
if primary_key_config and primary_key_config.enabled:
self.logger.info(
"Primary key constraints enabled for Databricks (informational only, applied during table creation)"
)
if foreign_key_config and foreign_key_config.enabled:
self.logger.info(
"Foreign key constraints enabled for Databricks (informational only, applied during table creation)"
)
# Databricks constraints are informational and used by Catalyst optimizer
# No additional work to do here as they're applied during CREATE TABLE