"""SQL dialect utilities for cross-database compatibility.
Copyright 2026 Joe Harris / BenchBox Project
Licensed under the MIT License. See LICENSE file in the project root for details.
"""

import logging
import re
from collections.abc import Callable

logger = logging.getLogger(__name__)

def normalize_dialect_for_sqlglot(dialect: str) -> str:
"""Normalize dialect names for SQLGlot compatibility.
Some database dialects are not directly supported by SQLGlot but are
based on dialects that are supported. This function maps unsupported
dialects to their closest supported equivalent.
SQLGlot supports these dialects: 'athena', 'bigquery', 'clickhouse',
'databricks', 'doris', 'drill', 'druid', 'duckdb', 'dune', 'hive',
'materialize', 'mysql', 'oracle', 'postgres', 'presto', 'prql',
'redshift', 'risingwave', 'snowflake', 'spark', 'spark2', 'sqlite',
'starrocks', 'tableau', 'teradata', 'trino', 'tsql'.
Args:
dialect: The dialect name to normalize (case-insensitive)
Returns:
The normalized dialect name that SQLGlot can process
Examples:
>>> normalize_dialect_for_sqlglot("netezza")
'postgres'
>>> normalize_dialect_for_sqlglot("duckdb")
'duckdb'
>>> normalize_dialect_for_sqlglot("NETEZZA")
'postgres'
"""
# Convert to lowercase for case-insensitive matching
dialect_lower = dialect.lower() if dialect else ""
# Map unsupported dialects to their SQLGlot-compatible equivalents
dialect_mapping = {
"netezza": "postgres", # Netezza is based on PostgreSQL
"greenplum": "postgres", # Greenplum is PostgreSQL-based
"vertica": "postgres", # Vertica uses PostgreSQL-compatible SQL
"ansi": "postgres", # ANSI SQL → PostgreSQL (defensive mapping, should not be used)
"standard": "postgres", # Standard SQL → PostgreSQL (defensive mapping, should not be used)
# DuckDB, ClickHouse, BigQuery, Snowflake, Redshift already supported directly
}
return dialect_mapping.get(dialect_lower, dialect_lower)
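
# A minimal usage sketch (illustrative, not part of the module's API): names
# absent from the mapping above pass through unchanged after case-folding, so
# an unknown dialect surfaces as a SQLGlot error downstream rather than being
# silently rewritten.
#
#     normalize_dialect_for_sqlglot("vertica")  # -> 'postgres'
#     normalize_dialect_for_sqlglot("Trino")    # -> 'trino' (case-folded, passed through)
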
def translate_sql_query(
query: str,
target_dialect: str,
source_dialect: str = "netezza",
identify: bool = True,
pre_processors: list[Callable[[str], str]] | None = None,
post_processors: list[Callable[[str], str]] | None = None,
) -> str:
"""Translate SQL query from source dialect to target dialect using SQLGlot.
This is the centralized SQL translation function used by all benchmarks.
It follows the TPC-DS gold standard pattern with comprehensive error handling.
The default source dialect is "netezza" (Postgres-compatible), which provides
the best compatibility with modern SQL features across platforms. This can be
overridden for benchmarks that use platform-specific source queries (e.g.,
ClickBench uses "clickhouse" as the source dialect).
Args:
query: SQL query text to translate
target_dialect: Target SQL dialect (e.g., 'duckdb', 'bigquery', 'snowflake')
source_dialect: Source SQL dialect (default: 'netezza' for modern SQL compatibility)
identify: Whether to quote identifiers to prevent reserved keyword conflicts (default: True)
pre_processors: Optional list of functions to pre-process query before translation
post_processors: Optional list of functions to post-process query after translation
Returns:
Translated SQL query text. Returns original query if translation fails.
Examples:
>>> translate_sql_query("SELECT * FROM orders", "duckdb")
'SELECT * FROM "orders"'
>>> translate_sql_query("SELECT * FROM orders", "bigquery", source_dialect="postgres")
'SELECT * FROM `orders`'
>>> # With custom pre-processor
>>> def fix_syntax(q): return q.replace("LIMIT", "FETCH FIRST")
>>> translate_sql_query("SELECT * FROM t LIMIT 10", "oracle", pre_processors=[fix_syntax])
'SELECT * FROM "t" FETCH FIRST 10 ROWS ONLY'
"""
try:
import sqlglot
except ImportError:
logger.warning("SQLGlot not available, returning original query")
return query
try:
# Normalize both source and target dialects for SQLGlot compatibility
        src = normalize_dialect_for_sqlglot(source_dialect)
        tgt = normalize_dialect_for_sqlglot(target_dialect)
# Apply pre-processors (e.g., TPC-DS interval syntax normalization)
processed_query = query
if pre_processors:
for pre_proc in pre_processors:
processed_query = pre_proc(processed_query)
# Translate using SQLGlot
# Some databases don't need identifier quoting:
# - ClickHouse: case-sensitive without case-folding, lowercase schema matches unquoted lowercase
# - PostgreSQL/DataFusion: unquoted identifiers are folded to lowercase by the engine
        # Quoting preserves case, which causes mismatches with lowercase schemas.
should_identify = identify and (tgt not in ("clickhouse", "postgres"))
translated = sqlglot.transpile(processed_query, read=src, write=tgt, identify=should_identify)[0]
# Apply post-processors (e.g., TPC-DS Query 58 ambiguity fix)
if post_processors:
for post_proc in post_processors:
translated = post_proc(translated)
return translated
except Exception as e:
logger.warning(
f"SQLGlot translation failed from {source_dialect} to {target_dialect}: {e}. Returning original query."
)
return query
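
# A minimal sketch of wiring in processors (the helper below is hypothetical,
# not something this module ships); assumes SQLGlot is installed:
#
#     def strip_trailing_semicolon(q: str) -> str:  # hypothetical pre-processor
#         return q.rstrip().rstrip(";")
#
#     translate_sql_query(
#         "SELECT * FROM orders;",
#         target_dialect="duckdb",
#         pre_processors=[strip_trailing_semicolon],
#     )
#     # -> 'SELECT * FROM "orders"'
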
def fix_postgres_date_arithmetic(query: str) -> str:
"""Convert integer date arithmetic to INTERVAL syntax for PostgreSQL/DataFusion.
PostgreSQL and DataFusion don't support `date + integer` directly.
This converts patterns like `d_date + 5` to `d_date + INTERVAL '5' DAY`.
Args:
query: SQL query with potential date arithmetic
Returns:
Query with date arithmetic converted to INTERVAL syntax
"""
    # Pattern: column_name + integer or column_name - integer,
    # where the column name contains 'd_date' (case-insensitive).
    # Matches: d_date + 5, d1.d_date + 30, d_date - 7
    pattern = r"(\b\w*\.?\w*d_date\w*)\s*([+-])\s*(\d+)"
def replace_with_interval(match: re.Match) -> str:
col = match.group(1)
op = match.group(2)
num = match.group(3)
return f"{col} {op} INTERVAL '{num}' DAY"
return re.sub(pattern, replace_with_interval, query, flags=re.IGNORECASE)
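
# Illustrative smoke test (a sketch, not part of the original module):
if __name__ == "__main__":
    # translate_sql_query assumes SQLGlot is installed; without it the original
    # query is returned unchanged (see the ImportError fallback above).
    print(normalize_dialect_for_sqlglot("NETEZZA"))  # -> postgres
    print(translate_sql_query("SELECT * FROM orders", "duckdb"))  # -> SELECT * FROM "orders"
    print(fix_postgres_date_arithmetic("SELECT d_date + 30 FROM date_dim"))
    # -> SELECT d_date + INTERVAL '30' DAY FROM date_dim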