Source code for benchbox.platforms.datafusion

"""DataFusion platform adapter with data loading and query execution.

Provides Apache DataFusion-specific optimizations for in-memory OLAP workloads,
supporting both CSV and Parquet formats with automatic conversion options.

Copyright 2026 Joe Harris / BenchBox Project

Licensed under the MIT License. See LICENSE file in the project root for details.
"""

from __future__ import annotations

import logging
import os
import threading
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any

try:
    from datafusion import SessionConfig, SessionContext

    try:
        from datafusion import RuntimeEnv
    except ImportError:
        # Newer versions use RuntimeEnvBuilder
        from datafusion import RuntimeEnvBuilder as RuntimeEnv
except ImportError:
    SessionContext = None  # type: ignore[assignment, misc]
    SessionConfig = None  # type: ignore[assignment, misc]
    RuntimeEnv = None  # type: ignore[assignment, misc]

from benchbox.platforms.base import DriverIsolationCapability, PlatformAdapter
from benchbox.utils.clock import elapsed_seconds, mono_time
from benchbox.utils.file_format import get_delimiter_for_file

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from benchbox.core.tuning.interface import TuningColumn


class DataFusionCursorCompat:
    """DB-API-like cursor wrapper for DataFusion SQL results.

    The wrapped DataFrame is only collected on the first fetch; the resulting
    rows are cached, so every later fetch is served from memory.

    Attributes:
        rowcount: ``-1`` until the result has been materialized, afterwards the
            number of rows that were collected.
    """

    def __init__(self, dataframe: Any):
        """Wrap ``dataframe`` without executing it."""
        self._dataframe = dataframe
        self._rows: list[tuple[Any, ...]] | None = None
        self.rowcount = -1

    def _materialize(self) -> list[tuple[Any, ...]]:
        """Collect the DataFrame once and cache its rows as positional tuples.

        Returns:
            The cached list of row tuples, one entry per result row.
        """
        if self._rows is not None:
            return self._rows

        rows: list[tuple[Any, ...]] = []
        for batch in self._dataframe.collect():
            # Build rows by position from the column arrays. Going through
            # name-keyed dicts (batch.to_pylist()) collapses columns that share
            # a name, e.g. ``SELECT a.id, b.id``, so both positions would end
            # up reporting the same value.
            columns = [column.to_pylist() for column in batch.columns]
            if columns:
                rows.extend(zip(*columns))
            else:
                # A batch without columns still carries a row count.
                rows.extend(() for _ in range(batch.num_rows))

        self.rowcount = len(rows)
        self._rows = rows
        return rows

    def fetchone(self) -> tuple[Any, ...] | None:
        """Return the first row of the result, or ``None`` when it is empty.

        The cursor keeps no read position: repeated calls return the same row.
        """
        rows = self._materialize()
        return rows[0] if rows else None

    def fetchall(self) -> list[tuple[Any, ...]]:
        """Return every row of the result; safe to call more than once."""
        return self._materialize()


class DataFusionConnectionCompat:
    """SessionContext wrapper exposing a DB-API-like execute() method.

    Any attribute that is not defined on the wrapper itself is looked up on the
    wrapped context.
    """

    # Leading keywords of statements that are run for their side effects.
    _EAGER_PREFIXES = (
        "INSERT",
        "UPDATE",
        "DELETE",
        "MERGE",
        "CREATE",
        "DROP",
        "ALTER",
        "TRUNCATE",
        "COPY",
    )

    def __init__(self, context: Any):
        """Wrap ``context``, the object whose ``sql()`` method runs queries."""
        self._context = context

    @staticmethod
    def _strip_leading_comments(query: str) -> str:
        """Return ``query`` without leading whitespace and SQL comments.

        Handles ``--`` line comments and ``/* ... */`` block comments. An empty
        string is returned when the text is nothing but comments or when a
        comment is left unterminated.
        """
        statement = query.lstrip()
        while True:
            if statement.startswith("--"):
                newline_pos = statement.find("\n")
                if newline_pos == -1:
                    return ""
                statement = statement[newline_pos + 1 :].lstrip()
            elif statement.startswith("/*"):
                end_pos = statement.find("*/", 2)
                if end_pos == -1:
                    return ""
                statement = statement[end_pos + 2 :].lstrip()
            else:
                return statement

    @staticmethod
    def _requires_eager_execution(query: str) -> bool:
        """Return True for SQL statements that must execute immediately for side effects.

        The decision is made on the first keyword after any leading comments,
        compared case-insensitively against ``_EAGER_PREFIXES``.
        """
        statement = DataFusionConnectionCompat._strip_leading_comments(query)
        return statement.upper().startswith(DataFusionConnectionCompat._EAGER_PREFIXES)

    def execute(self, query: str, parameters: Any = None) -> DataFusionCursorCompat:
        """Run ``query`` and return a cursor over its result.

        Args:
            query: SQL text handed to the wrapped context.
            parameters: Must be ``None`` or empty; binding is not supported.

        Returns:
            A cursor over the result. For side-effecting statements the result
            has already been collected when this returns.

        Raises:
            ValueError: If any parameters are supplied.
        """
        # None and empty containers both mean "nothing to bind".
        if parameters:
            raise ValueError("DataFusion SQL execute() does not support bound parameters in this adapter path")
        cursor = DataFusionCursorCompat(self._context.sql(query))
        if self._requires_eager_execution(query):
            # The result is lazy; collecting it is what makes the statement run.
            cursor.fetchall()
        return cursor

    def sql(self, query: str) -> Any:
        """Forward ``query`` to the wrapped context's ``sql()`` unchanged."""
        return self._context.sql(query)

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped context."""
        # __getattr__ only runs when normal lookup fails. If "_context" itself
        # is missing (instance created without __init__, as copy and pickle
        # do), delegating would re-enter this method without end.
        if name == "_context":
            raise AttributeError(name)
        return getattr(self._context, name)


class DataFusionAdapter(PlatformAdapter):
    """Apache DataFusion platform adapter with optimized bulk loading and execution."""

    # Capability flags declared for this platform.
    driver_isolation_capability = DriverIsolationCapability.SUPPORTED
    supports_external_tables = True

    # Process-wide lock bookkeeping keyed by working-dir lock file path.
    # This ensures ownership/reentrancy is shared across adapter instances.
    # Both objects are class attributes, so every instance sees the same dict
    # and the same guard lock.
    _process_working_dir_lock_depth: dict[str, int] = {}
    _process_working_dir_lock_guard = threading.Lock()

    @property
    def platform_name(self) -> str:
        """Display name of this platform ("DataFusion")."""
        return "DataFusion"
def get_target_dialect(self) -> str:
    """Return the SQL dialect identifier used when targeting DataFusion.

    The platform's own dialect name is returned so that catalog variants can
    target DataFusion specifically. SQL translation normalizes it to
    PostgreSQL semantics where needed.
    """
    dialect = "datafusion"
    return dialect
def preprocess_operation_sql(self, operation_id: str, operation: Any) -> str | None:
    """Rewrite bulk-load SQL into a form DataFusion accepts.

    Only operations whose category is ``bulk_load`` (compared
    case-insensitively) are touched: their COPY-based SQL is handed to
    ``transform_write_sql`` for rewriting to the CREATE EXTERNAL TABLE
    pattern. Every other category needs no preprocessing.

    Args:
        operation_id: Operation identifier
        operation: WriteOperation object with category, write_sql, file_dependencies

    Returns:
        The transformed SQL string, or None when no preprocessing is needed
    """
    if operation.category.lower() != "bulk_load":
        return None

    # Imported lazily so the transformer is only loaded for bulk loads.
    from benchbox.platforms.datafusion_write_transformer import transform_write_sql

    rewritten = transform_write_sql(
        operation_id,
        operation.category,
        operation.write_sql,
        operation.file_dependencies,
    )
    return rewritten
@staticmethod
def add_cli_arguments(parser) -> None:
    """Register the "DataFusion Arguments" option group on ``parser``.

    Args:
        parser: An argparse-style parser providing ``add_argument_group``.
    """
    group = parser.add_argument_group("DataFusion Arguments")

    # One (flag, keyword-arguments) pair per option, in registration order.
    option_specs = (
        (
            "--datafusion-memory-limit",
            {
                "type": str,
                "default": "16G",
                "help": "DataFusion memory limit (e.g., '16G', '8GB', '4096MB')",
            },
        ),
        (
            "--datafusion-partitions",
            {
                "type": int,
                "default": None,
                "help": "Number of parallel partitions (default: CPU count)",
            },
        ),
        (
            "--datafusion-format",
            {
                "type": str,
                "choices": ["csv", "parquet"],
                "default": "parquet",
                "help": "Data format to use (parquet recommended for performance)",
            },
        ),
        (
            "--datafusion-temp-dir",
            {
                "type": str,
                "default": None,
                "help": "Temporary directory for disk spilling",
            },
        ),
        (
            "--datafusion-batch-size",
            {
                "type": int,
                "default": 8192,
                "help": "RecordBatch size for query execution",
            },
        ),
        (
            "--datafusion-working-dir",
            {
                "type": str,
                "help": "Working directory for DataFusion tables and data",
            },
        ),
    )
    for flag, spec in option_specs:
        group.add_argument(flag, **spec)
@classmethod
def from_config(cls, config: dict[str, Any]):
    """Build an adapter from the unified benchmark configuration.

    When ``config`` carries a non-empty ``working_dir`` it is used as-is.
    Otherwise a directory is derived from the benchmark name, scale factor
    and tuning configuration (below ``<output_dir>/databases`` when
    ``output_dir`` is set) and created on disk.

    Args:
        config: Unified configuration mapping. ``benchmark`` and
            ``scale_factor`` are required when no ``working_dir`` is given.

    Returns:
        A new adapter instance.
    """
    from pathlib import Path

    from benchbox.utils.database_naming import generate_database_filename

    explicit_dir = config.get("working_dir")
    if explicit_dir:
        resolved_working_dir = config["working_dir"]
    else:
        from benchbox.utils.path_utils import get_benchmark_runs_databases_path

        # Only pass base_dir when the caller asked for a custom output root.
        location_kwargs = {}
        if config.get("output_dir"):
            location_kwargs["base_dir"] = Path(config["output_dir"]) / "databases"
        data_dir = get_benchmark_runs_databases_path(
            config["benchmark"],
            config["scale_factor"],
            **location_kwargs,
        )

        db_filename = generate_database_filename(
            benchmark_name=config["benchmark"],
            scale_factor=config["scale_factor"],
            platform="datafusion",
            tuning_config=config.get("tuning_config"),
        )

        derived_dir = data_dir / db_filename
        resolved_working_dir = str(derived_dir)
        derived_dir.mkdir(parents=True, exist_ok=True)

    adapter_config = {
        "working_dir": resolved_working_dir,
        "memory_limit": config.get("memory_limit", "16G"),
        # A missing, None or 0 partition count falls back to the CPU count.
        "target_partitions": config.get("partitions") or os.cpu_count(),
        "data_format": config.get("format", "parquet"),
        "temp_dir": config.get("temp_dir"),
        "batch_size": config.get("batch_size", 8192),
        "force_recreate": config.get("force", False),
    }

    # Optional settings are forwarded only when the caller supplied them.
    passthrough_keys = ("tuning_config", "verbose_enabled", "very_verbose")
    adapter_config.update({key: config[key] for key in passthrough_keys if key in config})

    return cls(**adapter_config)
def __init__(self, **config):
    """Store the DataFusion settings and make sure the working directory exists.

    Args:
        **config: Adapter options. Recognized keys are ``working_dir``,
            ``memory_limit``, ``target_partitions``, ``data_format``,
            ``temp_dir`` and ``batch_size``; the full mapping is also passed
            to the base class initializer.

    Raises:
        ImportError: If the ``datafusion`` package could not be imported.
    """
    super().__init__(**config)

    if SessionContext is None:
        raise ImportError("DataFusion not installed. Install with: pip install datafusion")

    option = config.get

    # Where tables and data live, plus engine tuning knobs.
    self.working_dir = Path(option("working_dir", "./datafusion_working"))
    self.memory_limit = option("memory_limit", "16G")
    self.target_partitions = option("target_partitions", os.cpu_count())
    self.data_format = option("data_format", "parquet")
    self.temp_dir = option("temp_dir")
    self.batch_size = option("batch_size", 8192)

    # Filled in by create_schema().
    self._table_schemas = {}

    self.working_dir.mkdir(parents=True, exist_ok=True)
def get_platform_info(self, connection: Any = None) -> dict[str, Any]:
    """Describe the platform, its configuration and the library versions in use.

    Args:
        connection: Accepted for interface compatibility; not used here.

    Returns:
        A dict with platform identity, the adapter configuration, the
        DataFusion version (``None`` when it cannot be determined) and any
        driver runtime metadata that is set on the adapter.
    """
    configuration = {
        "working_dir": str(self.working_dir),
        "memory_limit": self.memory_limit,
        "target_partitions": self.target_partitions,
        "data_format": self.data_format,
        "temp_dir": self.temp_dir,
        "batch_size": self.batch_size,
        "result_cache_enabled": False,
    }
    platform_info = {
        "platform_type": "datafusion",
        "platform_name": "DataFusion",
        "connection_mode": "in-memory",
        "configuration": configuration,
    }

    # Driver and engine ship as one package, so a single version covers both.
    try:
        import datafusion as df_module

        live_version = df_module.__version__
        for version_key in ("client_library_version", "platform_version", "driver_version_actual"):
            platform_info[version_key] = live_version
        self.driver_version_actual = live_version
    except (ImportError, AttributeError):
        platform_info["client_library_version"] = None
        platform_info["platform_version"] = None

    # Copy over driver runtime metadata, skipping anything unset.
    runtime_fields = (
        "driver_runtime_strategy",
        "driver_version_requested",
        "driver_version_resolved",
        "driver_version_actual",
    )
    for field_name in runtime_fields:
        field_value = getattr(self, field_name)
        if field_value:
            platform_info[field_name] = field_value

    return platform_info
def create_connection(self, **connection_config) -> Any:
    """Create a DataFusion SessionContext with optimized configuration.

    Steps, in order:

    1. Take the working-directory lock, let ``handle_existing_database`` deal
       with any existing database, then release the lock.
    2. Build a runtime environment (memory pool and disk manager) when the
       installed DataFusion exposes the builder API.
    3. Build the session configuration and create the context.

    Args:
        **connection_config: Forwarded to the working-directory lock helpers
            and to ``handle_existing_database``.

    Returns:
        A ``DataFusionConnectionCompat`` wrapping the new context.

    Raises:
        RuntimeError: If the working-directory lock cannot be acquired within
            10 seconds.
    """
    self.log_operation_start("DataFusion connection")

    # Serialize database management on shared working dirs to avoid cleanup races.
    lock_acquired = self._acquire_working_dir_lock(timeout_seconds=10, **connection_config)
    if not lock_acquired:
        raise RuntimeError("Could not acquire DataFusion working directory lock after 10 seconds")

    try:
        # Handle existing database using base class method
        self.handle_existing_database(**connection_config)
    finally:
        self._release_working_dir_lock(**connection_config)

    # Configure runtime environment for disk spilling and memory management.
    # Note: RuntimeEnv/RuntimeEnvBuilder API varies by version.
    runtime = None
    if RuntimeEnv is not None:
        try:
            if hasattr(RuntimeEnv, "build"):
                # RuntimeEnvBuilder API (newer releases)
                builder = RuntimeEnv()

                # Configure the memory pool as a fair spill pool.
                if self.memory_limit:
                    memory_bytes = int(self._parse_memory_limit(self.memory_limit))
                    builder = builder.with_fair_spill_pool(memory_bytes)
                    self.log_very_verbose(
                        f"Configured fair spill pool: {self.memory_limit} ({memory_bytes:,} bytes)"
                    )

                # Configure the disk manager for spilling. A configured
                # temp_dir has to be handed to the builder explicitly; the
                # OS-managed disk manager always uses the system temp dir.
                use_custom_spill_dir = bool(self.temp_dir) and hasattr(builder, "with_disk_manager_specified")
                if use_custom_spill_dir:
                    try:
                        Path(self.temp_dir).mkdir(parents=True, exist_ok=True)
                    except OSError as e:
                        self.log_very_verbose(
                            f"Could not prepare temp dir {self.temp_dir}: {e}, using system temp dir"
                        )
                        use_custom_spill_dir = False

                if use_custom_spill_dir:
                    builder = builder.with_disk_manager_specified(str(self.temp_dir))
                    self.log_very_verbose(f"Enabled disk spilling (temp dir: {self.temp_dir})")
                else:
                    builder = builder.with_disk_manager_os()
                    self.log_very_verbose("Enabled disk spilling (using system temp dir)")

                runtime = builder.build()
            else:
                # Old RuntimeEnv API (fallback)
                runtime = RuntimeEnv()
                self.log_very_verbose("Using default RuntimeEnv (memory configuration not available in old API)")
        except Exception as e:
            # Best effort: a runtime that cannot be configured must not
            # prevent the connection from being created.
            self.log_very_verbose(f"Could not configure RuntimeEnv: {e}, using defaults")
            runtime = None

    # Create session configuration
    config = SessionConfig()

    # Set parallelism
    config = config.with_target_partitions(self.target_partitions)

    # Enable optimizations
    config = config.with_parquet_pruning(True)
    config = config.with_repartition_joins(True)
    config = config.with_repartition_aggregations(True)
    config = config.with_repartition_windows(True)
    config = config.with_information_schema(True)

    # Set batch size
    config = config.with_batch_size(self.batch_size)

    # Track applied configuration for logging
    config_applied = [
        f"target_partitions={self.target_partitions}",
        f"batch_size={self.batch_size}",
        "parquet_pruning=enabled",
        "repartitioning=enabled",
    ]

    # Add runtime configuration to tracking
    if runtime is not None:
        if self.memory_limit:
            config_applied.append(f"memory_pool={self.memory_limit}")
        config_applied.append("disk_spilling=enabled")

    # Try to normalize identifiers to lowercase (for TPC benchmark compatibility).
    # This is optional: a failure here is logged and otherwise ignored.
    try:
        config = config.set("datafusion.sql_parser.enable_ident_normalization", "true")
        self.log_very_verbose("Enabled SQL identifier normalization (if supported)")
        config_applied.append("ident_normalization=enabled")
    except (KeyboardInterrupt, SystemExit):
        # Never swallow an interrupt or interpreter shutdown.
        raise
    except BaseException as e:
        # NOTE(review): BaseException is kept from the original code,
        # presumably so that errors not derived from Exception raised by the
        # native extension are tolerated too - confirm.
        self.log_very_verbose(f"Identifier normalization not available (using PostgreSQL defaults): {e}")

    # Create SessionContext with runtime environment if available
    if runtime is not None:
        try:
            ctx = SessionContext(config, runtime)
            self.log_very_verbose("SessionContext created with RuntimeEnv")
        except TypeError:
            # Older versions may not accept runtime parameter
            ctx = SessionContext(config)
            self.log_very_verbose("SessionContext created without RuntimeEnv (not supported in this version)")
    else:
        ctx = SessionContext(config)

    self.log_operation_complete("DataFusion connection", details=f"Applied: {', '.join(config_applied)}")

    return DataFusionConnectionCompat(ctx)
def _get_working_dir_lock_file(self, **connection_config) -> Path:
    """Get lock file path for working-dir lifecycle operations.

    The lock file sits next to the working directory (in its parent) and is
    named ``.<working dir name>.db_manage.lock``. A ``working_dir`` entry in
    ``connection_config`` overrides the adapter's own working directory.
    """
    working_dir = Path(connection_config.get("working_dir", self.working_dir))
    return working_dir.parent / f".{working_dir.name}.db_manage.lock"


def _is_pid_running(self, pid: int) -> bool:
    """Return True when process exists, False when definitely absent."""
    if pid <= 0:
        return False
    try:
        # Signal 0 performs the existence/permission check without
        # delivering a signal.
        os.kill(pid, 0)
        return True
    except ProcessLookupError:
        return False
    except PermissionError:
        # Process exists but may be owned by another user.
        return True
    except Exception:
        return False


def _read_lock_pid(self, lock_file: Path) -> int | None:
    """Read PID from lock file content, returning None if unavailable.

    The first line starting with ``pid:`` is used. Any read or parse error
    yields None.
    """
    try:
        content = lock_file.read_text(encoding="utf-8")
        for line in content.splitlines():
            if line.startswith("pid:"):
                return int(line.split(":", 1)[1].strip())
    except Exception:
        return None
    return None


def _acquire_working_dir_lock(self, timeout_seconds: int = 300, **connection_config) -> bool:
    """Acquire lock for DataFusion working-dir lifecycle operations.

    The lock is a file created with ``O_CREAT | O_EXCL`` that records the
    owning PID. Within this process, ownership is counted per lock file, so a
    nested acquire succeeds immediately and must be paired with a release.

    Args:
        timeout_seconds: How long to keep retrying before giving up.
        **connection_config: May carry ``working_dir`` to select the lock.

    Returns:
        True when the lock is held on return, False on timeout or on an
        unexpected error while creating the lock file.
    """
    lock_file = self._get_working_dir_lock_file(**connection_config)
    lock_key = str(lock_file.resolve())
    lock_file.parent.mkdir(parents=True, exist_ok=True)
    start_time = mono_time()
    lock_detected_logged = False
    wait_warning_emitted = False

    while elapsed_seconds(start_time) < timeout_seconds:
        try:
            with self._process_working_dir_lock_guard:
                # Reentrant fast-path shared across all adapter instances.
                existing_depth = self._process_working_dir_lock_depth.get(lock_key, 0)
                if existing_depth > 0:
                    self._process_working_dir_lock_depth[lock_key] = existing_depth + 1
                    return True

                # O_EXCL makes creation fail when the file already exists.
                fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
                with os.fdopen(fd, "w", encoding="utf-8") as f:
                    f.write(f"pid:{os.getpid()}\n")
                    f.write(f"time:{time.time()}\n")
                self._process_working_dir_lock_depth[lock_key] = 1
                return True
        except FileExistsError:
            try:
                age_seconds = time.time() - lock_file.stat().st_mtime
                lock_pid = self._read_lock_pid(lock_file)

                # Recover interrupted self-owned lock files when no active owner is tracked.
                if lock_pid == os.getpid():
                    with self._process_working_dir_lock_guard:
                        existing_depth = self._process_working_dir_lock_depth.get(lock_key, 0)
                        if existing_depth == 0:
                            self.log_verbose(
                                f"Removing stale self-owned DataFusion working-dir lock: {lock_file} "
                                f"(pid={lock_pid}, age={age_seconds:.1f}s)"
                            )
                            lock_file.unlink(missing_ok=True)
                            continue

                # Verbose signal for immediate lock detection visibility.
                if not lock_detected_logged:
                    self.log_verbose(
                        f"DataFusion working-dir lock detected: {lock_file} "
                        f"(pid={lock_pid}, age={age_seconds:.1f}s)"
                    )
                    lock_detected_logged = True

                # Standard warning signal when we're blocked and waiting.
                if not wait_warning_emitted:
                    self.logger.warning(
                        f"DataFusion working directory lock is held by another process "
                        f"(pid={lock_pid}). Waiting for release..."
                    )
                    wait_warning_emitted = True

                owner_dead = lock_pid is not None and not self._is_pid_running(lock_pid)
                # Treat lock as stale when owner is gone, or metadata is absent and lock is old enough.
                if owner_dead or (lock_pid is None and age_seconds > 10):
                    self.log_verbose(
                        f"Removing stale DataFusion working-dir lock: {lock_file} "
                        f"(pid={lock_pid}, age={age_seconds:.1f}s)"
                    )
                    lock_file.unlink(missing_ok=True)
                    continue
            except Exception:
                # Best effort: the lock file can disappear between the failed
                # create and this inspection. Fall through and retry.
                pass
            time.sleep(0.2)
        except Exception as e:
            self.log_verbose(f"Failed to acquire DataFusion working-dir lock: {e}")
            return False

    return False


def _release_working_dir_lock(self, **connection_config) -> None:
    """Release lock for DataFusion working-dir lifecycle operations.

    Decrements the per-process ownership count and removes the lock file when
    the count reaches zero. Does nothing when this process does not track
    ownership of the lock.
    """
    lock_file = self._get_working_dir_lock_file(**connection_config)
    lock_key = str(lock_file.resolve())
    should_unlink = False
    with self._process_working_dir_lock_guard:
        depth = self._process_working_dir_lock_depth.get(lock_key, 0)
        if depth > 1:
            self._process_working_dir_lock_depth[lock_key] = depth - 1
            return
        if depth == 1:
            self._process_working_dir_lock_depth.pop(lock_key, None)
            should_unlink = True
        else:
            # No tracked ownership in this process: do not unlink another owner's lock.
            return

    if should_unlink:
        try:
            lock_file.unlink(missing_ok=True)
        except Exception as e:
            self.log_verbose(f"Failed to release DataFusion working-dir lock: {e}")


def _parse_memory_limit(self, memory_limit: str) -> str:
    """Parse memory limit string to bytes.

    Units are case-insensitive and binary (1K = 1024 bytes). Both the short
    and the ``B``/``iB`` spellings are accepted: ``16G``, ``16GB`` and
    ``16GiB`` are the same amount.

    Args:
        memory_limit: Memory limit string (e.g., "16G", "8GB", "4096MB", "2GiB", "1T")

    Returns:
        Memory limit in bytes as string. Input without a recognized unit is
        returned as-is (upper-cased, stripped, and without a trailing "B").
    """
    memory_str = memory_limit.upper().strip()

    # Drop the byte suffix: "IB" for binary spellings (GiB), otherwise "B".
    if memory_str.endswith("IB"):
        memory_str = memory_str[:-2]
    elif memory_str.endswith("B"):
        memory_str = memory_str[:-1]

    multipliers = {
        "K": 1024,
        "M": 1024**2,
        "G": 1024**3,
        "T": 1024**4,
    }
    unit = memory_str[-1:]
    if unit in multipliers:
        return str(int(float(memory_str[:-1]) * multipliers[unit]))

    # Assume already in bytes
    return memory_str
def create_schema(self, benchmark, connection: Any) -> float:
    """Record the benchmark's table schemas for later data loading.

    No tables are created here: for DataFusion, table creation happens during
    load_data() via CREATE EXTERNAL TABLE. This step reads the constraint
    settings (for logging only) and stores the benchmark schema on the
    adapter.

    Args:
        benchmark: Benchmark whose schema is read.
        connection: Not used by this step.

    Returns:
        Elapsed time of the step in seconds.
    """
    started_at = mono_time()
    benchmark_label = benchmark.__class__.__name__
    self.log_operation_start("Schema creation", f"benchmark: {benchmark_label}")

    # Constraint settings come from the tuning configuration.
    wants_primary_keys, wants_foreign_keys = self._get_constraint_configuration()
    self._log_constraint_configuration(wants_primary_keys, wants_foreign_keys)

    # Constraints are not enforced by DataFusion: report it, apply nothing.
    if wants_primary_keys or wants_foreign_keys:
        self.log_verbose(
            "DataFusion does not enforce PRIMARY KEY or FOREIGN KEY constraints - schema will be created without constraints"
        )

    # Take the structured schema straight from the benchmark.
    self._table_schemas = self._get_benchmark_schema(benchmark)

    duration = elapsed_seconds(started_at)
    table_count = len(self._table_schemas)
    self.log_operation_complete("Schema creation", duration, f"Schema validated for {table_count} tables")
    return duration
def _get_benchmark_schema(self, benchmark) -> dict[str, dict]:
    """Convert the benchmark's schema into the adapter's internal layout.

    The benchmark is expected to return
    ``{table_name: {'columns': [{'name': str, 'type': str, ...}]}}``.
    Table names are lower-cased; a missing or non-string column type becomes
    ``VARCHAR``; malformed columns and tables without valid columns are
    skipped with a log message.

    Returns:
        Dict mapping table_name -> {'columns': [...]} where each column is
        {'name': str, 'type': str}. Empty when the benchmark has no usable
        ``get_schema()`` or returns an empty schema.
    """
    try:
        benchmark_schema = benchmark.get_schema()
    except (AttributeError, TypeError):
        # Not every benchmark offers get_schema(); fall back to inference.
        self.log_verbose(
            f"Benchmark {benchmark.__class__.__name__} does not provide get_schema() method, "
            "will rely on schema inference during data loading"
        )
        return {}

    if not benchmark_schema:
        self.log_verbose("Benchmark returned empty schema, will rely on schema inference")
        return {}

    schemas = {}
    for raw_table_name, table_def in benchmark_schema.items():
        table_name = raw_table_name.lower()

        has_column_list = isinstance(table_def, dict) and "columns" in table_def
        column_defs = table_def["columns"] if has_column_list else ()

        columns = []
        for col in column_defs:
            if not (isinstance(col, dict) and "name" in col):
                self.log_very_verbose(f"Skipping invalid column definition in {table_name}: {col}")
                continue
            # Only plain string types are usable; anything else is replaced.
            declared_type = col.get("type", "VARCHAR")
            col_type = declared_type if isinstance(declared_type, str) else "VARCHAR"
            columns.append({"name": col["name"], "type": col_type})

        if not columns:
            self.log_verbose(f"Warning: No valid columns found for table {table_name}")
            continue

        schemas[table_name] = {"columns": columns}
        self.log_very_verbose(f"Extracted schema for {table_name}: {len(columns)} columns")

    return schemas
def load_data(
    self, benchmark, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
    """Load data into DataFusion.

    Supports CSV, Parquet, Delta Lake, and Iceberg formats.
    Directory-based formats (delta/iceberg) are auto-detected from the file path.

    Args:
        benchmark: Benchmark passed to the data source resolver.
        connection: Connection the tables are registered on.
        data_dir: Directory the data files are resolved from.

    Returns:
        A tuple ``(table_stats, total_duration, per_table_timings)``:
        row count per lower-cased table name, total elapsed seconds, and a
        ``{"total_ms": ...}`` timing entry per table.

    Raises:
        ValueError: If the resolver finds no data source or no tables.
    """
    from benchbox.platforms.base.data_loading import DataSourceResolver

    start_time = mono_time()
    self.log_operation_start("Data loading", f"format: {self.data_format}")

    # Resolve data source (pass platform name for correct format preference)
    resolver = DataSourceResolver(platform_name=self.platform_name.lower())
    data_source = resolver.resolve(benchmark, data_dir)

    if not data_source or not data_source.tables:
        raise ValueError(f"No data files found in {data_dir}")

    table_stats = {}
    per_table_timings = {}
    # Tuning is only consulted when it is enabled on the adapter.
    effective_tuning = self.unified_tuning_configuration if self.tuning_enabled else None

    # Load each table
    for table_name, file_paths in data_source.tables.items():
        table_start = mono_time()

        # Ensure file_paths is a list
        if not isinstance(file_paths, list):
            file_paths = [file_paths]

        # Normalize table name to lowercase
        table_name_lower = table_name.lower()

        # Detect directory-based table formats (delta/iceberg); these take
        # precedence over the configured data_format.
        dir_format = self._detect_directory_format(file_paths)

        if dir_format == "delta":
            row_count = self._load_table_delta(connection, table_name_lower, file_paths[0])
        elif dir_format == "iceberg":
            row_count = self._load_table_iceberg(connection, table_name_lower, file_paths[0])
        elif self.data_format == "parquet":
            row_count = self._load_table_parquet(connection, table_name_lower, file_paths, data_dir)
        else:
            # Load CSV directly
            row_count = self._load_table_csv(connection, table_name_lower, file_paths, data_dir)

        if effective_tuning:
            self.apply_ctas_sort(table_name_lower, effective_tuning, connection)

        # The per-table duration includes the optional sort step above.
        table_duration = elapsed_seconds(table_start)
        table_stats[table_name_lower] = row_count
        per_table_timings[table_name_lower] = {"total_ms": table_duration * 1000}
        self.log_verbose(f"Loaded table {table_name_lower}: {row_count:,} rows in {table_duration:.2f}s")

    total_duration = elapsed_seconds(start_time)
    total_rows = sum(table_stats.values())
    self.log_operation_complete(
        "Data loading",
        total_duration,
        f"{total_rows:,} rows across {len(table_stats)} tables",
    )

    return table_stats, total_duration, per_table_timings
def create_external_tables(
    self, benchmark: Any, connection: Any, data_dir: Path
) -> tuple[dict[str, int], float, dict[str, Any] | None]:
    """Run external-table mode through the regular load_data() path.

    Args:
        benchmark: Passed through to load_data().
        connection: Passed through to load_data().
        data_dir: Passed through to load_data().

    Returns:
        Whatever load_data() returns for the same arguments.
    """
    load_result = self.load_data(benchmark, connection, data_dir)
    return load_result
def _build_ctas_sort_sql(self, table_name: str, sort_columns: list[TuningColumn]) -> str | None: """Build DataFusion CTAS SQL used by PlatformAdapter.apply_ctas_sort. ``sort_columns`` is pre-sorted by the caller; no internal sort needed. """ order_by_clause = ", ".join(column.name for column in sort_columns) return f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM {table_name} ORDER BY {order_by_clause};" def _detect_csv_format(self, file_paths: list[Path]) -> str: """Detect CSV delimiter from file extension. Returns: Delimiter string """ if file_paths: return get_delimiter_for_file(file_paths[0]) return "," def _load_table_csv(self, connection: Any, table_name: str, file_paths: list[Path], data_dir: Path) -> int: """Load table from CSV files using CREATE EXTERNAL TABLE. Handles TPC benchmark format with trailing pipe delimiters and uses glob patterns for multiple files. """ # Detect delimiter delimiter = self._detect_csv_format(file_paths) # Get schema information for proper column names schema_info = self._table_schemas.get(table_name, {}) columns = schema_info.get("columns", []) # Build column schema for CREATE EXTERNAL TABLE if columns: # Use actual column names and types from schema schema_clause = ", ".join([f"{col['name']} {self._map_to_arrow_type(col['type'])}" for col in columns]) schema_clause = f"({schema_clause})" else: # No schema available - let DataFusion infer schema_clause = "" self.log_verbose(f"Warning: No schema found for {table_name}, using schema inference") # Use glob pattern for multiple files or single file if len(file_paths) > 1: # Check if files are in same directory and can use glob parent_dir = file_paths[0].parent if all(f.parent == parent_dir for f in file_paths): # Use glob pattern based on actual file extension # E.g., table.tbl.1, table.tbl.2 -> table.tbl.* # This preserves the extension to avoid matching unintended files first_file_name = file_paths[0].name # Find the position of the first numeric extension or just use 
table_name if "." in first_file_name: # Keep everything up to the last numeric part base_pattern = ( first_file_name.rsplit(".", 1)[0] if first_file_name.split(".")[-1].isdigit() else first_file_name ) location = str(parent_dir / f"{base_pattern}*") else: location = str(parent_dir / f"{table_name}*") self.log_very_verbose(f"Using glob pattern for {table_name}: {location}") else: # Files in different directories - fall back to first file with warning location = str(file_paths[0]) self.log_verbose( f"Warning: Multiple files in different directories for {table_name}, using first file only" ) else: location = str(file_paths[0]) # Build CREATE EXTERNAL TABLE statement # Note: DataFusion's CSV reader doesn't have a direct "ignore trailing delimiter" option # We need to handle this via schema definition with exact column count options = [ "'has_header' 'false'", f"'delimiter' '{delimiter}'", ] options_clause = ", ".join(options) create_sql = f""" CREATE EXTERNAL TABLE {table_name} {schema_clause} STORED AS CSV LOCATION '{location}' OPTIONS ({options_clause}) """ try: connection.sql(create_sql) self.log_very_verbose(f"Created external table: {table_name}") except Exception as e: self.log_verbose(f"Error creating external table {table_name}: {e}") raise RuntimeError(f"Failed to create external table {table_name}: {e}") from e # Count rows try: result = connection.sql(f"SELECT COUNT(*) FROM {table_name}").collect() row_count = int(result[0].column(0)[0]) return row_count except Exception as e: self.log_verbose(f"Error counting rows in {table_name}: {e}") raise RuntimeError(f"Failed to count rows in {table_name}: {e}") from e def _map_to_arrow_type(self, sql_type: str) -> str: """Map SQL types to Arrow/DataFusion types. This mapping is used when creating external CSV tables with explicit schemas. For Parquet tables, PyArrow infers types automatically during CSV parsing. 
""" sql_type_upper = sql_type.upper() # Map common SQL types to DataFusion types type_mapping = { "INTEGER": "INT", "BIGINT": "BIGINT", "DECIMAL": "DECIMAL", "DOUBLE": "DOUBLE", "FLOAT": "FLOAT", "VARCHAR": "VARCHAR", "CHAR": "VARCHAR", "TEXT": "VARCHAR", "DATE": "DATE", "TIMESTAMP": "TIMESTAMP", "BOOLEAN": "BOOLEAN", } # Check for parameterized types like DECIMAL(10,2) or VARCHAR(100) base_type = sql_type_upper.split("(")[0] if base_type in type_mapping: # For parameterized types, preserve the parameters if "(" in sql_type_upper: return f"{type_mapping[base_type]}{sql_type_upper[len(base_type) :]}" return type_mapping[base_type] # Return as-is if not in mapping (assume it's already valid) return sql_type def _load_table_parquet(self, connection: Any, table_name: str, file_paths: list[Path], data_dir: Path) -> int: """Load table as Parquet, converting from CSV/TBL if needed. If the input files are already Parquet, registers them directly. Otherwise, converts CSV/TBL to Parquet first, preserving column names from schema. 
""" import pyarrow as pa import pyarrow.parquet as pq # Check if input files are already Parquet input_is_parquet = all(self._is_parquet_file(fp) for fp in file_paths) if input_is_parquet: return self._register_parquet_files(connection, table_name, file_paths) return self._convert_and_register_parquet(connection, table_name, file_paths, pa, pq) def _is_parquet_file(self, file_path: Path) -> bool: """Check if a file is Parquet by extension (stripping compression suffixes).""" name = file_path.name # Strip known compression suffixes for suffix in (".zst", ".gz", ".bz2", ".lz4", ".snappy"): if name.endswith(suffix): name = name[: -len(suffix)] break return name.endswith(".parquet") def _register_parquet_files(self, connection: Any, table_name: str, file_paths: list[Path]) -> int: """Register pre-existing Parquet files directly with DataFusion.""" import pyarrow.parquet as pq if len(file_paths) == 1: parquet_path = str(file_paths[0]) self.log_very_verbose(f"Registering existing Parquet file for {table_name}: {parquet_path}") row_count = pq.read_metadata(file_paths[0]).num_rows connection.register_parquet(table_name, parquet_path) return row_count # Multiple parquet files: concatenate into single file in working directory import pyarrow as pa self.log_very_verbose(f"Concatenating {len(file_paths)} Parquet files for {table_name}") tables = [] for fp in file_paths: tables.append(pq.read_table(fp)) combined = pa.concat_tables(tables) parquet_file = self.working_dir / f"{table_name}.parquet" self.working_dir.mkdir(exist_ok=True) pq.write_table(combined, parquet_file, compression="snappy") connection.register_parquet(table_name, str(parquet_file)) return combined.num_rows @staticmethod def _detect_directory_format(file_paths: list[Path]) -> str | None: """Detect if file paths point to a directory-based table format. Returns 'delta', 'iceberg', or None for file-based formats. 
""" if len(file_paths) != 1: return None path = file_paths[0] if not path.is_dir(): return None if (path / "_delta_log").is_dir(): return "delta" if (path / "metadata").is_dir(): return "iceberg" return None def _load_table_delta(self, connection: Any, table_name: str, table_path: Path) -> int: """Load a Delta Lake table into DataFusion via deltalake + PyArrow.""" try: from deltalake import DeltaTable except ImportError as e: raise RuntimeError( "Delta Lake support requires the 'deltalake' package. " "Install it with: uv add deltalake --optional table-formats" ) from e self.log_very_verbose(f"Loading Delta Lake table for {table_name}: {table_path}") delta_table = DeltaTable(str(table_path)) arrow_table = delta_table.to_pyarrow_table() batches = arrow_table.to_batches() if batches: connection.register_record_batches(table_name, [batches]) else: # Empty table — register with schema but no data import pyarrow as pa empty_batch = pa.RecordBatch.from_pydict( {name: [] for name in arrow_table.schema.names}, schema=arrow_table.schema, ) connection.register_record_batches(table_name, [[empty_batch]]) self.log_very_verbose(f"Registered Delta table {table_name}: {arrow_table.num_rows:,} rows") return arrow_table.num_rows def _load_table_iceberg(self, connection: Any, table_name: str, table_path: Path) -> int: """Load an Iceberg table into DataFusion via pyiceberg + PyArrow.""" try: from pyiceberg.catalog.sql import SqlCatalog except ImportError as e: raise RuntimeError( "Iceberg support requires the 'pyiceberg' package. 
" "Install it with: uv add pyiceberg --optional table-formats" ) from e self.log_very_verbose(f"Loading Iceberg table for {table_name}: {table_path}") # Use a file-based SQLite catalog pointing at the table's warehouse catalog = SqlCatalog( "benchbox", uri=f"sqlite:///{table_path}/catalog.db", warehouse=str(table_path.parent), ) ice_table = catalog.load_table(f"default.{table_name}") arrow_table = ice_table.scan().to_arrow() batches = arrow_table.to_batches() if batches: connection.register_record_batches(table_name, [batches]) else: import pyarrow as pa empty_batch = pa.RecordBatch.from_pydict( {name: [] for name in arrow_table.schema.names}, schema=arrow_table.schema, ) connection.register_record_batches(table_name, [[empty_batch]]) self.log_very_verbose(f"Registered Iceberg table {table_name}: {arrow_table.num_rows:,} rows") return arrow_table.num_rows def _convert_and_register_parquet( self, connection: Any, table_name: str, file_paths: list[Path], pa: Any, pq: Any ) -> int: """Convert CSV/TBL files to Parquet and register with DataFusion.""" import pyarrow.csv as csv # Store parquet files directly in working directory parquet_dir = self.working_dir parquet_dir.mkdir(exist_ok=True) parquet_file = parquet_dir / f"{table_name}.parquet" # Detect delimiter - PyArrow's CSV reader handles trailing delimiters automatically delimiter = self._detect_csv_format(file_paths) # Get schema information for proper column names and types schema_info = self._table_schemas.get(table_name, {}) columns = schema_info.get("columns", []) # Build column names list from schema # PyArrow's CSV reader handles trailing delimiters correctly - it doesn't create extra columns column_names = None column_types = None if columns: column_names = [col["name"] for col in columns] # Build PyArrow column types from schema to prevent incorrect type inference # e.g., ca_zip "89436" should be string, not int64 column_types = {} for col in columns: col_name = col["name"] col_type = col.get("type", 
"VARCHAR").upper() # Map schema types to PyArrow types - string types must be explicit # to prevent PyArrow from inferring numeric types for zip codes etc. if col_type.startswith(("CHAR", "VARCHAR", "TEXT", "STRING")): column_types[col_name] = pa.string() elif col_type.startswith("DATE"): column_types[col_name] = pa.date32() elif col_type.startswith("DECIMAL"): # Use float64 for decimal to avoid precision issues column_types[col_name] = pa.float64() # Other types (INT, BIGINT, etc.) can use auto-inference self.log_very_verbose(f"Using {len(column_names)} columns from schema for {table_name}: {column_names}") else: self.log_verbose(f"Warning: No schema found for {table_name}, using auto-generated column names") # Convert CSV to Parquet self.log_very_verbose(f"Converting {len(file_paths)} CSV file(s) to Parquet for {table_name}") # Read all CSV files and combine # Note: PyArrow automatically detects and handles compressed files (.gz, .bz2, etc.) tables = [] for file_path in file_paths: try: # Configure CSV read options read_options = csv.ReadOptions( column_names=column_names, autogenerate_column_names=(column_names is None), ) parse_options = csv.ParseOptions( delimiter=delimiter, quote_char='"', # Standard quote character escape_char="\\", # Standard escape character ) convert_options = csv.ConvertOptions( null_values=[""], strings_can_be_null=True, column_types=column_types, # Explicit types to prevent incorrect inference ) # Read CSV with PyArrow (handles Path objects and compression automatically) table = csv.read_csv( file_path, read_options=read_options, parse_options=parse_options, convert_options=convert_options, ) tables.append(table) except Exception as e: self.log_verbose(f"Error reading CSV file {file_path}: {e}") raise RuntimeError(f"Failed to read CSV file {file_path}: {e}") from e # Concatenate all tables try: combined_table = pa.concat_tables(tables) except Exception as e: self.log_verbose(f"Error concatenating tables for {table_name}: {e}") raise 
RuntimeError(f"Failed to concatenate CSV data for {table_name}: {e}") from e # Write to Parquet with compression try: pq.write_table( combined_table, parquet_file, compression="snappy", # Fast compression, good balance ) self.log_very_verbose(f"Created Parquet file: {parquet_file} ({combined_table.num_rows:,} rows)") except Exception as e: self.log_verbose(f"Error writing Parquet file for {table_name}: {e}") raise RuntimeError(f"Failed to write Parquet file for {table_name}: {e}") from e # Register Parquet table in DataFusion try: connection.register_parquet(table_name, str(parquet_file)) except Exception as e: # Clean up the Parquet file if registration fails try: if parquet_file.exists(): parquet_file.unlink() self.log_very_verbose(f"Cleaned up orphaned Parquet file: {parquet_file}") except Exception as cleanup_error: self.log_very_verbose(f"Could not clean up Parquet file: {cleanup_error}") self.log_verbose(f"Error registering Parquet table {table_name}: {e}") raise RuntimeError(f"Failed to register Parquet table {table_name}: {e}") from e return combined_table.num_rows
def configure_for_benchmark(self, connection: Any, benchmark_type: str) -> None:
    """Record that DataFusion is configured for ``benchmark_type``.

    No settings are changed on ``connection`` here; the method only emits a
    verbose log line. It is the place to add benchmark-specific settings.
    """
    message = f"DataFusion configured for {benchmark_type} benchmark"
    self.log_verbose(message)
def execute_query(
    self,
    connection: Any,
    query: str,
    query_id: str,
    benchmark_type: str | None = None,
    scale_factor: float | None = None,
    validate_row_count: bool = True,
    stream_id: int | None = None,
) -> dict[str, Any]:
    """Execute query with detailed timing and result collection.

    The query is first rewritten by ``DataFusionQueryTransformer``. In dry-run
    mode the rewritten SQL is captured and a ``DRY_RUN`` result is returned
    without touching ``connection``. Otherwise the query is run and collected;
    the reported time covers ``connection.sql`` plus ``collect`` only.

    Args:
        connection: Object exposing ``sql(query)`` that returns a collectable frame.
        query: SQL text to run.
        query_id: Identifier used in logs and in the result.
        benchmark_type: Enables row-count validation when set.
        scale_factor: Forwarded to the row-count validator.
        validate_row_count: Set to ``False`` to skip validation.
        stream_id: Forwarded to the row-count validator.

    Returns:
        A result dictionary. Execution errors are not raised; they yield a
        dictionary with ``status`` ``"FAILED"`` and the error text and type.
    """
    self.log_verbose(f"Executing query {query_id}")
    self.log_very_verbose(f"Query SQL (first 200 chars): {query[:200]}{'...' if len(query) > 200 else ''}")

    # Apply DataFusion-specific query transformations for SQL compatibility
    from benchbox.platforms.datafusion_query_transformer import DataFusionQueryTransformer

    transformer = DataFusionQueryTransformer(verbose=getattr(self, "very_verbose", False))
    query = transformer.transform(query, query_id=query_id)

    if transformer.get_transformations_applied():
        self.log_verbose(
            f"Query {query_id}: Applied transformations: {', '.join(transformer.get_transformations_applied())}"
        )

    # In dry-run mode we intentionally capture transformed SQL so output
    # reflects what DataFusion would execute after compatibility rewrites.
    if self.dry_run_mode:
        self.capture_sql(query, "query", None)
        self.log_very_verbose(f"Captured query {query_id} for dry-run")
        return {
            "query_id": query_id,
            "status": "DRY_RUN",
            "execution_time_seconds": 0.0,
            "rows_returned": 0,
            "first_row": None,
            "error": None,
            "dry_run": True,
        }

    start_time = mono_time()

    try:
        result_batches = connection.sql(query).collect()

        # Stop the clock before post-processing so it is not counted as query time.
        execution_time = elapsed_seconds(start_time)

        actual_row_count = sum(batch.num_rows for batch in result_batches)

        # The leading batch may be empty, so take the first batch that has rows.
        first_row = None
        first_batch = next((batch for batch in result_batches if batch.num_rows > 0), None)
        if first_batch is not None:
            cells = (first_batch.column(i)[0] for i in range(first_batch.num_columns))
            first_row = tuple(cell.as_py() if hasattr(cell, "as_py") else cell for cell in cells)

        logger.debug(f"Query {query_id} completed in {execution_time:.3f}s, returned {actual_row_count} rows")

        # Validate row count if enabled
        validation_result = None
        if validate_row_count and benchmark_type:
            from benchbox.core.validation.query_validation import QueryValidator

            validator = QueryValidator()
            validation_result = validator.validate_query_result(
                benchmark_type=benchmark_type,
                query_id=query_id,
                actual_row_count=actual_row_count,
                scale_factor=scale_factor,
                stream_id=stream_id,
            )

            if validation_result.warning_message:
                self.log_verbose(f"Row count validation: {validation_result.warning_message}")
            elif not validation_result.is_valid:
                self.log_verbose(f"Row count validation FAILED: {validation_result.error_message}")
            else:
                self.log_very_verbose(
                    f"Row count validation PASSED: {actual_row_count} rows "
                    f"(expected: {validation_result.expected_row_count})"
                )

        # Use centralized helper to build result with validation
        return self._build_query_result_with_validation(
            query_id=query_id,
            execution_time=execution_time,
            actual_row_count=actual_row_count,
            first_row=first_row,
            validation_result=validation_result,
        )

    except Exception as e:
        execution_time = elapsed_seconds(start_time)
        logger.error(
            f"Query {query_id} failed after {execution_time:.3f}s: {e}",
            exc_info=True,
        )

        return {
            "query_id": query_id,
            "status": "FAILED",
            "execution_time_seconds": execution_time,
            "rows_returned": 0,
            "error": str(e),
            "error_type": type(e).__name__,
        }
def apply_platform_optimizations(self, platform_config, connection: Any) -> None:
    """Log how platform optimizations are handled for DataFusion.

    Nothing is applied to ``connection``: a falsy ``platform_config`` is logged
    as "none configured", anything else as "configured at session creation".
    """
    # DataFusion doesn't support most traditional database optimizations
    # like Z-ordering, auto-optimize, bloom filters, etc.
    # These are handled through file organization (Parquet partitioning)
    if platform_config:
        self.log_verbose("DataFusion optimizations configured at session creation")
    else:
        self.log_verbose("No platform optimizations configured")
def apply_constraint_configuration(
    self,
    primary_key_config,
    foreign_key_config,
    connection: Any,
) -> None:
    """Log that requested key constraints are not applied.

    For each of the two configurations that is present and has ``enabled``
    set, one verbose message is emitted (primary key first). ``connection``
    is not used.
    """
    requested = (
        ("PRIMARY KEY", primary_key_config),
        ("FOREIGN KEY", foreign_key_config),
    )
    for label, config in requested:
        if config and config.enabled:
            self.log_verbose(
                f"DataFusion does not enforce {label} constraints - configuration noted but not applied"
            )
def check_database_exists(self, **connection_config) -> bool:
    """Report whether the working directory exists and holds a ``*.parquet`` file.

    The directory comes from ``connection_config["working_dir"]`` when given,
    otherwise from ``self.working_dir``. Only its top level is searched.
    """
    candidate = Path(connection_config.get("working_dir", self.working_dir))
    return candidate.exists() and any(candidate.glob("*.parquet"))
def drop_database(self, **connection_config) -> None:
    """Delete the working directory tree, if it exists.

    The directory comes from ``connection_config["working_dir"]`` when given,
    otherwise from ``self.working_dir``. A missing directory is a no-op.
    """
    from shutil import rmtree

    target = Path(connection_config.get("working_dir", self.working_dir))
    if not target.exists():
        return

    self.log_verbose(f"Removing DataFusion working directory: {target}")
    rmtree(target)
    self.log_verbose("DataFusion working directory removed")
def validate_platform_capabilities(self, benchmark_type: str):
    """Check that DataFusion can run ``benchmark_type`` and describe the setup.

    A missing DataFusion library is the only error. Warnings cover an
    undeterminable library version, TPC-DS SQL limitations, and a memory limit
    that is below 2 GiB or cannot be parsed.

    Returns:
        A ``ValidationResult`` that is valid when no errors were collected and
        whose ``details`` hold the adapter configuration.
    """
    from benchbox.core.validation import ValidationResult

    problems = []
    notices = []

    if SessionContext is not None:
        try:
            import datafusion as df_module

            detected = df_module.__version__
            self.log_very_verbose(f"DataFusion version: {detected}")
        except (ImportError, AttributeError):
            notices.append("Could not determine DataFusion version")
    else:
        problems.append("DataFusion library not available - install with 'pip install datafusion'")

    if benchmark_type.lower() == "tpcds":
        notices.append("Some TPC-DS queries may fail due to DataFusion SQL feature limitations")

    if self.memory_limit:
        try:
            limit_gb = int(self._parse_memory_limit(self.memory_limit)) / (1024**3)
            if limit_gb < 2.0:
                notices.append(f"Memory limit ({self.memory_limit}) may be insufficient for larger scale factors")
        except (ValueError, TypeError):
            notices.append(f"Could not parse memory limit: {self.memory_limit}")

    details = {
        "platform": self.platform_name,
        "benchmark_type": benchmark_type,
        "dry_run_mode": self.dry_run_mode,
        "datafusion_available": SessionContext is not None,
        "working_dir": str(self.working_dir),
        "memory_limit": self.memory_limit,
        "target_partitions": self.target_partitions,
        "data_format": self.data_format,
    }

    if SessionContext:
        try:
            import datafusion as df_module

            details["datafusion_version"] = df_module.__version__
        except (ImportError, AttributeError):
            # Version attribute not available in this DataFusion version
            pass

    return ValidationResult(
        is_valid=not problems,
        errors=problems,
        warnings=notices,
        details=details,
    )
def _get_existing_tables(self, connection) -> list[str]: """Get list of existing tables in DataFusion SessionContext. Override base class method to use DataFusion's catalog API instead of information_schema queries. """ try: # DataFusion SessionContext has a catalog() method to list tables # Get all registered tables tables = [] # DataFusion stores tables in a catalog structure # Use SHOW TABLES to get list of tables result = connection.sql("SHOW TABLES") rows = result.collect() # Extract table names from the result # SHOW TABLES returns a DataFrame with columns: # [table_catalog, table_schema, table_name, table_type] for batch in rows: # Convert to pydict to get table names data = batch.to_pydict() # Get the 'table_name' column specifically if data and "table_name" in data: tables.extend([name.lower() for name in data["table_name"]]) return tables except Exception as e: self.log_verbose(f"Error getting existing tables: {e}") # Fallback - return empty list if query fails return [] def _validate_data_integrity( self, benchmark, connection, table_stats: dict[str, int] ) -> tuple[str, dict[str, Any]]: """Validate basic data integrity using DataFusion SessionContext API. Override base class method to use DataFusion's ctx.sql() instead of DB-API 2.0 cursor interface. 
""" validation_details = {} try: # DataFusion connection is a SessionContext, not a DB-API 2.0 connection # Use ctx.sql() to validate table accessibility accessible_tables = [] inaccessible_tables = [] for table_name in table_stats: try: # Try a simple SELECT to verify table is accessible # Use SessionContext.sql() which returns a DataFrame result = connection.sql(f"SELECT 1 FROM {table_name} LIMIT 1") # Execute the query to verify it works result.collect() accessible_tables.append(table_name) except Exception as e: self.log_verbose(f"Table {table_name} inaccessible: {e}") inaccessible_tables.append(table_name) if inaccessible_tables: validation_details["inaccessible_tables"] = inaccessible_tables validation_details["constraints_enabled"] = False return "FAILED", validation_details else: validation_details["accessible_tables"] = accessible_tables validation_details["constraints_enabled"] = True return "PASSED", validation_details except Exception as e: validation_details["error"] = str(e) return "FAILED", validation_details