Skip to content

Commit 34a231a

Browse files
2 parents 68f4cd2 + 29a586a commit 34a231a

File tree

12 files changed

+124
-58
lines changed

12 files changed

+124
-58
lines changed

sqlcompare/compare/comparator.py

Lines changed: 7 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -260,13 +260,10 @@ def compare(
260260
except Exception:
261261
pass
262262

263-
# Get column names from pre-existing tables
264-
_, cols_prev = db.query(
265-
f"SELECT * FROM {tables['previous']} WHERE 1=0", include_columns=True
266-
)
267-
_, cols_new = db.query(
268-
f"SELECT * FROM {tables['new']} WHERE 1=0", include_columns=True
269-
)
263+
# Get column names from pre-existing tables using database-specific metadata
264+
# This ensures we get the actual column case (important for Snowflake)
265+
cols_prev = db.get_table_columns(tables['previous'])
266+
cols_new = db.get_table_columns(tables['new'])
270267
self.cols_prev = cols_prev
271268
self.cols_new = cols_new
272269

@@ -420,17 +417,15 @@ def compare(
420417
runs = load_test_runs()
421418

422419
# Save metadata for later analysis
423-
# For DuckDB connections (file-based tests), save "duckdb"
424-
# For remote databases (Snowflake, etc.), save the connection name
425-
conn_name = self.connection if isinstance(self.connection, str) else "duckdb"
426-
420+
# Save the connection as-is (can be None, connection ID, or URL)
421+
# When None, it will resolve to the default connection on inspection
427422
run_data = {
428423
"tables": tables,
429424
"index_cols": list(self.index_cols),
430425
"cols_prev": self.cols_prev,
431426
"cols_new": self.cols_new,
432427
"common_cols": self.common_cols,
433-
"conn": conn_name,
428+
"conn": self.connection,
434429
}
435430

436431
runs[diff_id] = run_data

sqlcompare/config.py

Lines changed: 21 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,12 @@
33
import os
44
from pathlib import Path
55
from typing import Any
6+
import yaml
67

78

89
def _config_dir() -> Path:
910
return Path(
10-
os.getenv("SQLCOMPARE_CONFIG_DIR", Path.home() / ".config" / "data-toolkit")
11+
os.getenv("SQLCOMPARE_CONFIG_DIR", Path.home() / ".config" / "sqlcompare")
1112
)
1213

1314

@@ -23,8 +24,8 @@ def load_config() -> dict[str, Any]:
2324
}
2425

2526

26-
def _runs_file() -> Path:
27-
return _config_dir() / "db_test_runs.yaml"
27+
def _runs_dir() -> Path:
28+
return _config_dir() / "runs"
2829

2930

3031
def get_tests_folder() -> Path:
@@ -35,39 +36,32 @@ def get_tests_folder() -> Path:
3536

3637

3738
def load_test_runs() -> dict[str, Any]:
38-
path = _runs_file()
39-
if not path.exists():
39+
runs_dir = _runs_dir()
40+
if not runs_dir.exists():
4041
return {}
41-
payload = _read_yaml(path)
42-
if not payload:
43-
return {}
44-
if not isinstance(payload, dict):
45-
raise ValueError("Expected mapping in db_test_runs.yaml")
46-
return payload
42+
43+
runs = {}
44+
for yaml_file in runs_dir.glob("*.yaml"):
45+
run_id = yaml_file.stem # filename without extension
46+
payload = _read_yaml(yaml_file)
47+
if payload:
48+
runs[run_id] = payload
49+
50+
return runs
4751

4852

4953
def save_test_runs(runs: dict[str, Any]) -> None:
50-
path = _runs_file()
51-
path.parent.mkdir(parents=True, exist_ok=True)
52-
_write_yaml(path, runs)
54+
runs_dir = _runs_dir()
55+
runs_dir.mkdir(parents=True, exist_ok=True)
5356

57+
for run_id, run_data in runs.items():
58+
file_path = runs_dir / f"{run_id}.yaml"
59+
_write_yaml(file_path, run_data)
5460

55-
def _read_yaml(path: Path) -> dict[str, Any] | None:
56-
try:
57-
import yaml
58-
except ImportError:
59-
import json
6061

61-
return json.loads(path.read_text(encoding="utf-8"))
62+
def _read_yaml(path: Path) -> dict[str, Any] | None:
6263
return yaml.safe_load(path.read_text(encoding="utf-8"))
6364

6465

6566
def _write_yaml(path: Path, payload: dict[str, Any]) -> None:
66-
try:
67-
import yaml
68-
except ImportError:
69-
import json
70-
71-
path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
72-
return
7367
path.write_text(yaml.safe_dump(payload, sort_keys=True), encoding="utf-8")

sqlcompare/db/connection.py

Lines changed: 44 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ def __enter__(self) -> "DBConnection":
7878

7979
return self
8080

81-
def __exit__(self, exc_type, _exc, _tb) -> None:
81+
def __exit__(self, exc_type, _exc_val, _exc_tb) -> None:
8282
# Commit / rollback transaction
8383
try:
8484
if self._tx is not None:
@@ -196,12 +196,22 @@ def _run(self, sql: str, params: dict[str, Any] | None = None) -> Result:
196196
res = self.conn.execute(text(sql), params or {})
197197
return res
198198
except SQLAlchemyError as e:
199+
# Extract just the database error message without the full SQL statement
200+
error_msg = str(e)
201+
# SQLAlchemy often includes the SQL in square brackets at the end
202+
# Format: "error message [SQL: long query here]"
203+
if "[SQL:" in error_msg:
204+
# Extract just the part before [SQL:
205+
db_error = error_msg.split("[SQL:")[0].strip()
206+
else:
207+
db_error = error_msg
208+
199209
raise DBConnectionError(
200-
"SQL execution failed.",
210+
db_error,
201211
conn_id=self.conn_id,
202212
sql=sql,
203213
original=e,
204-
) from e
214+
) from None # Suppress the original exception chain to avoid showing SQL twice
205215
finally:
206216
self.last_elapsed_ms = int((time.perf_counter() - t0) * 1000)
207217

@@ -220,6 +230,37 @@ def execute(self, sql: str, params: dict[str, Any] | None = None) -> ExecMeta:
220230
elapsed_ms=self.last_elapsed_ms or 0, rowcount=rowcount, columns=cols
221231
)
222232

233+
def get_table_columns(self, table_name: str) -> list[str]:
234+
"""
235+
Get the actual column names from a table, using database-specific metadata queries
236+
to ensure correct case sensitivity (especially important for Snowflake).
237+
238+
Args:
239+
table_name: Fully qualified table name (e.g., "schema.table" or "db.schema.table")
240+
241+
Returns:
242+
List of column names with their actual case as stored in the database
243+
"""
244+
# Detect if we're using Snowflake by checking the dialect
245+
try:
246+
dialect_name = self._engine.dialect.name.lower() if self._engine else None
247+
except Exception:
248+
dialect_name = None
249+
250+
# For Snowflake, use DESCRIBE TABLE to get actual column names
251+
if dialect_name == "snowflake":
252+
try:
253+
# DESCRIBE TABLE returns columns with their actual case
254+
result = self.query(f"DESCRIBE TABLE {table_name}")
255+
# First column is the column name
256+
return [row[0] for row in result]
257+
except Exception:
258+
pass
259+
260+
# Fallback: use SELECT * WHERE 1=0 and get column names from result
261+
_, columns = self.query(f"SELECT * FROM {table_name} WHERE 1=0", include_columns=True)
262+
return columns
263+
223264
def query(
224265
self,
225266
sql: str,

sqlcompare/db/constants.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,11 @@
44
from pathlib import Path
55

66
lib_name = "sqlcompare"
7-
lib_config_folder = Path.home() / f".{lib_name}"
7+
lib_config_folder = Path.home() / ".config" / f"{lib_name}"
88

99
LIBRARY_CONNECTIONS = [
1010
os.getenv(
11-
"SQLCOMPARE_CONNECTIONS_FILE", str(lib_config_folder / "connections.yml")
11+
"SQLCOMPARE_CONNECTIONS_FILE", str(lib_config_folder / "connections.yaml")
1212
),
1313
os.getenv("DTK_CONNECTIONS_FILE", str(Path.home() / ".dtk" / "connections.yml")),
1414
]

sqlcompare/db/exceptions.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,17 @@ def __init__(
1111
self.conn_id = conn_id
1212
self.sql = sql
1313
self.original = original
14+
15+
def __str__(self) -> str:
16+
# The main error message is already set in __init__, just add SQL context if helpful
17+
msg = super().__str__()
18+
19+
# Optionally show a brief SQL snippet for context
20+
if self.sql and len(self.sql) > 200:
21+
# Only show snippet if SQL is long (short SQL is fine to show in full)
22+
sql_lines = self.sql.strip().split('\n')
23+
first_line = sql_lines[0] if sql_lines else self.sql
24+
sql_snippet = first_line[:100] + "..." if len(first_line) > 100 else first_line
25+
msg += f"\n\nSQL operation: {sql_snippet}"
26+
27+
return msg

sqlcompare/db/resolver.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,27 @@ def normalize_connection_name(name: str) -> str:
2323
return s
2424

2525

26+
def _expand_env_vars(value: Any) -> Any:
27+
"""
28+
Recursively expand environment variables in strings, dicts, and lists.
29+
30+
Expands ${VAR_NAME} and $VAR_NAME patterns in strings.
31+
"""
32+
if isinstance(value, str):
33+
return os.path.expandvars(value)
34+
elif isinstance(value, dict):
35+
return {k: _expand_env_vars(v) for k, v in value.items()}
36+
elif isinstance(value, list):
37+
return [_expand_env_vars(item) for item in value]
38+
else:
39+
return value
40+
41+
2642
def _load_connections_from_yaml(yaml_path: str) -> dict[str, dict[str, Any]]:
2743
"""
2844
Load connection definitions from YAML file.
2945
Returns a dict mapping conn_id -> URL.create() parameters.
46+
Expands environment variables in all string values.
3047
"""
3148
expanded_path = Path(yaml_path).expanduser()
3249
if not expanded_path.exists():
@@ -35,7 +52,10 @@ def _load_connections_from_yaml(yaml_path: str) -> dict[str, dict[str, Any]]:
3552
try:
3653
with open(expanded_path, "r") as f:
3754
data = yaml.safe_load(f)
38-
return data if isinstance(data, dict) else {}
55+
if isinstance(data, dict):
56+
# Expand environment variables in all values
57+
return _expand_env_vars(data)
58+
return {}
3959
except Exception:
4060
return {}
4161

sqlcompare/query.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def query(q: str, connection: str | None, output: str) -> None:
3737
if rows:
3838
from tabulate import tabulate
3939

40-
print(tabulate(rows, headers=columns))
40+
print(tabulate(rows, headers=columns, tablefmt="pretty"))
4141
print(f"\nTotal rows: {len(rows)}")
4242
else:
4343
print("Query executed successfully. No rows returned.")

sqlcompare/utils/format.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@ def format_table(
3333
"""
3434
n_cols = len(columns)
3535
if n_cols == 0:
36-
return tabulate([], headers=[], tablefmt=tablefmt, **tabulate_kwargs)
36+
return tabulate([], headers=[], tablefmt="pretty", **tabulate_kwargs)
3737

3838
# ---- column trimming ----
3939
use_col_ellipsis = n_cols > max_cols and max_cols >= 2

sqlcompare/utils/test_types/stats.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66

77

88
def _quote_ident(name: str) -> str:
9-
return f'"{name.replace('"', '""')}"'
9+
return '"' + name.replace('"', '""') + '"'
1010

1111

1212
def _collect_table_stats(
@@ -69,12 +69,8 @@ def compare_table_stats(table1: str, table2: str, connection: str | None) -> Non
6969
db.create_table_from_file(table1_name, table1)
7070
db.create_table_from_file(table2_name, table2)
7171

72-
_, cols_prev = db.query(
73-
f"SELECT * FROM {table1_name} WHERE 1=0", include_columns=True
74-
)
75-
_, cols_new = db.query(
76-
f"SELECT * FROM {table2_name} WHERE 1=0", include_columns=True
77-
)
72+
cols_prev = db.get_table_columns(table1_name)
73+
cols_new = db.get_table_columns(table2_name)
7874

7975
prev_map = {col.upper(): col for col in cols_prev}
8076
new_map = {col.upper(): col for col in cols_new}

tests/cli_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,3 +29,9 @@ def set_cli_env(
2929
monkeypatch.setenv("SQLCOMPARE_CONN_DEFAULT", connection_name)
3030
monkeypatch.setenv(f"SQLCOMPARE_CONN_{connection_name.upper()}", connection_url)
3131
monkeypatch.setenv("SQLCOMPARE_COMPARISON_SCHEMA", schema)
32+
33+
34+
def setup_duckdb_env(monkeypatch, config_dir: Path, schema: str = "sqlcompare") -> None:
35+
"""Setup environment for DuckDB tests without a specific connection."""
36+
monkeypatch.setenv("SQLCOMPARE_CONFIG_DIR", str(config_dir))
37+
monkeypatch.setenv("SQLCOMPARE_COMPARISON_SCHEMA", schema)

0 commit comments

Comments
 (0)