Skip to content

Commit 5766461

Browse files
feat: Implement connection resolution and materialization for SQL inputs; enhance concurrency handling in stats checks
1 parent d6d3d93 commit 5766461

File tree

6 files changed

+247
-72
lines changed

6 files changed

+247
-72
lines changed

sqlcompare/helpers.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import os
56
import re
67
from dataclasses import dataclass
78
from pathlib import Path
@@ -45,6 +46,66 @@ def detect_input(value: str) -> InputSpec:
4546

4647
return InputSpec(kind="table", value=value.strip(), source="none")
4748

49+
50+
def resolve_connection(
    connection: str | None,
    *,
    error_cls: type[Exception] = typer.BadParameter,
) -> str:
    """Return the connection to use, preferring the explicit argument.

    A truthy ``connection`` is returned unchanged. Otherwise the environment
    variables ``SQLCOMPARE_CONN_DEFAULT`` and then ``DTK_CONN_DEFAULT`` are
    consulted, and the first non-empty value is returned.

    Raises:
        error_cls: when neither the argument nor either environment variable
            supplies a value. Defaults to ``typer.BadParameter``; callers may
            pass another exception type.
    """
    if connection:
        return connection

    # Checked in priority order; an unset or empty variable falls through.
    for env_name in ("SQLCOMPARE_CONN_DEFAULT", "DTK_CONN_DEFAULT"):
        env_value = os.getenv(env_name)
        if env_value:
            return env_value

    raise error_cls(
        "No connection specified. Use --connection or set SQLCOMPARE_CONN_DEFAULT."
    )
65+
66+
67+
def resolve_materialized_tables(
    previous_spec: InputSpec,
    current_spec: InputSpec,
    *,
    schema: str | None,
    prefix: str,
    suffix: str,
) -> tuple[str, str]:
    """Choose the table name for each side of a previous/current pair.

    A spec whose ``kind`` is ``"table"`` keeps its own ``value``. Any other
    kind receives a generated name built from ``prefix`` and ``suffix``,
    ending in ``_previous`` or ``_new`` and qualified with ``schema`` when
    one is given.

    Returns:
        ``(previous_name, current_name)``.
    """
    # An empty or missing schema yields unqualified generated names.
    schema_prefix = f"{schema}." if schema else ""

    if previous_spec.kind == "table":
        previous_name = previous_spec.value
    else:
        previous_name = f"{schema_prefix}{prefix}_{suffix}_previous"

    if current_spec.kind == "table":
        current_name = current_spec.value
    else:
        current_name = f"{schema_prefix}{prefix}_{suffix}_new"

    return previous_name, current_name
88+
89+
90+
def materialize_sql_inputs(
    db: DBConnection,
    *,
    previous_spec: InputSpec,
    current_spec: InputSpec,
    previous_table: str,
    current_table: str,
    schema: str | None,
) -> None:
    """Create a table for each side whose spec has ``kind == "sql"``.

    Sides of any other kind are ignored. When neither side is SQL the
    function returns without touching ``db``. Otherwise ``ensure_schema`` is
    called once (with ``""`` standing in for a missing schema), followed by
    ``create_table_from_select`` for the previous side and then the current
    side, as applicable.
    """
    sides = (
        (previous_spec, previous_table),
        (current_spec, current_table),
    )
    to_create = [
        (table_name, spec.value) for spec, table_name in sides if spec.kind == "sql"
    ]
    if not to_create:
        return

    ensure_schema(db, schema or "")
    for table_name, select_sql in to_create:
        create_table_from_select(db, table_name, select_sql)
108+
48109
def expand_dataset_value(value: Any, base_dir: Path) -> Any:
49110
"""
50111
Recursively expand template variables in dataset configuration values.

sqlcompare/run.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
import typer
77

8-
from sqlcompare.helpers import detect_input
8+
from sqlcompare.helpers import detect_input, resolve_connection
99
from sqlcompare.run_table import compare_table, run_table_cmd
1010
from sqlcompare.run_dataset import run_dataset_cmd
1111
from sqlcompare.run_stats import run_stats_cmd
12-
from sqlcompare.run_auto import _resolve_connection, run_auto_cmd
12+
from sqlcompare.run_auto import run_auto_cmd
1313
from sqlcompare.config import get_default_schema
1414
from sqlcompare.db import DBConnection
1515
from sqlcompare.helpers import create_table_from_select, ensure_schema
@@ -91,7 +91,7 @@ def run_query_cmd(
9191
schema = schema or get_default_schema()
9292
previous_sql = _load_sql_input(previous, "previous")
9393
current_sql = _load_sql_input(current, "current")
94-
connection = _resolve_connection(connection)
94+
connection = resolve_connection(connection)
9595

9696
schema_prefix = f"{schema}." if schema else ""
9797
suffix = _suffix_from_paths(previous, current)

sqlcompare/run_auto.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,20 @@
11
from __future__ import annotations
22

3-
import os
43
import uuid
54

65
import typer
76

87
from sqlcompare.config import get_default_schema
98
from sqlcompare.db import DBConnection
10-
from sqlcompare.helpers import create_table_from_select, detect_input, ensure_schema
9+
from sqlcompare.helpers import (
10+
detect_input,
11+
materialize_sql_inputs,
12+
resolve_connection,
13+
resolve_materialized_tables,
14+
)
1115
from sqlcompare.run_table import compare_table
1216

1317

14-
def _resolve_connection(connection: str | None) -> str:
15-
if connection:
16-
return connection
17-
default_conn = os.getenv("SQLCOMPARE_CONN_DEFAULT") or os.getenv("DTK_CONN_DEFAULT")
18-
if not default_conn:
19-
raise typer.BadParameter(
20-
"No connection specified. Use --connection or set SQLCOMPARE_CONN_DEFAULT."
21-
)
22-
return default_conn
23-
24-
2518
def run_auto_cmd(
2619
previous: str = typer.Argument(
2720
..., help="Previous table name, CSV/XLSX file path, or SQL"
@@ -90,26 +83,24 @@ def run_auto_cmd(
9083
return
9184

9285
if prev_spec.kind == "sql" or new_spec.kind == "sql":
93-
connection = _resolve_connection(connection)
94-
schema_prefix = f"{schema}." if schema else ""
95-
suffix = uuid.uuid4().hex[:8]
96-
previous_table = (
97-
prev_spec.value
98-
if prev_spec.kind == "table"
99-
else f"{schema_prefix}sqlcompare_sql_{suffix}_previous"
100-
)
101-
new_table = (
102-
new_spec.value
103-
if new_spec.kind == "table"
104-
else f"{schema_prefix}sqlcompare_sql_{suffix}_new"
86+
connection = resolve_connection(connection)
87+
previous_table, new_table = resolve_materialized_tables(
88+
prev_spec,
89+
new_spec,
90+
schema=schema,
91+
prefix="sqlcompare_sql",
92+
suffix=uuid.uuid4().hex[:8],
10593
)
10694

10795
with DBConnection(connection) as db:
108-
ensure_schema(db, schema)
109-
if prev_spec.kind == "sql":
110-
create_table_from_select(db, previous_table, prev_spec.value)
111-
if new_spec.kind == "sql":
112-
create_table_from_select(db, new_table, new_spec.value)
96+
materialize_sql_inputs(
97+
db,
98+
previous_spec=prev_spec,
99+
current_spec=new_spec,
100+
previous_table=previous_table,
101+
current_table=new_table,
102+
schema=schema,
103+
)
113104

114105
compare_table(
115106
previous_table,

sqlcompare/stats/runner.py

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,24 @@
11
from __future__ import annotations
22

3-
import os
3+
from dataclasses import replace
44
import uuid
55
from pathlib import Path
66

77
import typer
88

99
from sqlcompare.config import get_default_schema
1010
from sqlcompare.db import DBConnection
11-
from sqlcompare.helpers import create_table_from_select, detect_input, ensure_schema
11+
from sqlcompare.helpers import (
12+
detect_input,
13+
materialize_sql_inputs,
14+
resolve_connection,
15+
resolve_materialized_tables,
16+
)
1217
from sqlcompare.log import log
1318
from sqlcompare.stats.checks import get_check_map
1419
from sqlcompare.stats.models import ColumnPair, StatsContext
1520
from sqlcompare.stats.report import render_report
16-
17-
18-
def _resolve_connection(connection: str | None) -> str:
19-
if connection:
20-
return connection
21-
default_conn = os.getenv("SQLCOMPARE_CONN_DEFAULT") or os.getenv("DTK_CONN_DEFAULT")
22-
if not default_conn:
23-
raise ValueError(
24-
"No connection specified. Use --connection or set SQLCOMPARE_CONN_DEFAULT."
25-
)
26-
return default_conn
21+
from sqlcompare.utils.concurrency import run_ordered
2722

2823

2924
def _resolve_checks(checks: str | None) -> list[str]:
@@ -101,6 +96,29 @@ def _build_context(db: DBConnection, previous_name: str, current_name: str) -> S
10196
)
10297

10398

99+
def _run_selected_checks(
    context: StatsContext,
    selected_check_names: list[str],
    check_map: dict[str, object],
    connection_id: str | None,
    parallel_safe: bool,
) -> list:
    """Run each selected check and return the results in selection order.

    When ``parallel_safe`` is false, every check's runner receives the given
    ``context`` as-is. When it is true, each check opens its own
    ``DBConnection(connection_id)`` and runs against a copy of ``context``
    whose ``db`` field is that connection. Dispatch goes through
    ``run_ordered`` with ``enabled=parallel_safe`` and ``max_workers=4``.
    """

    def run_with_shared_context(name: str):
        definition = check_map[name]
        return definition.runner(context, definition)

    def run_with_own_connection(name: str):
        definition = check_map[name]
        # A per-call connection avoids handing the caller's db to the worker.
        with DBConnection(connection_id) as db:
            return definition.runner(replace(context, db=db), definition)

    run_check = run_with_own_connection if parallel_safe else run_with_shared_context
    return run_ordered(
        selected_check_names,
        run_check,
        enabled=parallel_safe,
        max_workers=4,
    )
120+
121+
104122
def compare_table_stats(
105123
table1: str,
106124
table2: str,
@@ -125,42 +143,63 @@ def compare_table_stats(
125143
table1_name = Path(spec_prev.value).stem
126144
table2_name = Path(spec_new.value).stem
127145
elif spec_prev.kind == "sql" or spec_new.kind == "sql":
128-
connection_id = _resolve_connection(connection)
146+
connection_id = resolve_connection(connection, error_cls=ValueError)
129147
schema = get_default_schema()
130-
schema_prefix = f"{schema}." if schema else ""
131148
suffix = uuid.uuid4().hex[:8]
132-
table1_name = (
133-
spec_prev.value
134-
if spec_prev.kind == "table"
135-
else f"{schema_prefix}sqlcompare_stats_{suffix}_previous"
136-
)
137-
table2_name = (
138-
spec_new.value
139-
if spec_new.kind == "table"
140-
else f"{schema_prefix}sqlcompare_stats_{suffix}_new"
149+
table1_name, table2_name = resolve_materialized_tables(
150+
spec_prev,
151+
spec_new,
152+
schema=schema,
153+
prefix="sqlcompare_stats",
154+
suffix=suffix,
141155
)
142156
else:
143157
table1_name = spec_prev.value
144158
table2_name = spec_new.value
145159

146-
with DBConnection(connection_id) as db:
160+
parallel_safe = not (
161+
spec_prev.kind == "file" and spec_new.kind == "file" and connection is None
162+
)
163+
164+
def prepare_inputs(db: DBConnection) -> None:
147165
if spec_prev.kind == "file" and spec_new.kind == "file":
148166
if connection is None and connection_id == "duckdb:///:memory:":
149167
db.create_table_from_file(table1_name, spec_prev.value)
150168
db.create_table_from_file(table2_name, spec_new.value)
151-
if spec_prev.kind == "sql" or spec_new.kind == "sql":
152-
schema = get_default_schema()
153-
ensure_schema(db, schema)
154-
if spec_prev.kind == "sql":
155-
create_table_from_select(db, table1_name, spec_prev.value)
156-
if spec_new.kind == "sql":
157-
create_table_from_select(db, table2_name, spec_new.value)
158-
159-
context = _build_context(db, table1_name, table2_name)
160-
check_map = get_check_map()
161-
results = [
162-
check_map[name].runner(context, check_map[name])
163-
for name in selected_check_names
164-
]
169+
materialize_sql_inputs(
170+
db,
171+
previous_spec=spec_prev,
172+
current_spec=spec_new,
173+
previous_table=table1_name,
174+
current_table=table2_name,
175+
schema=get_default_schema(),
176+
)
177+
178+
if not parallel_safe:
179+
with DBConnection(connection_id) as db:
180+
prepare_inputs(db)
181+
context = _build_context(db, table1_name, table2_name)
182+
check_map = get_check_map()
183+
results = _run_selected_checks(
184+
context=context,
185+
selected_check_names=selected_check_names,
186+
check_map=check_map,
187+
connection_id=connection_id,
188+
parallel_safe=False,
189+
)
190+
else:
191+
with DBConnection(connection_id) as db:
192+
prepare_inputs(db)
193+
194+
with DBConnection(connection_id) as db:
195+
context = _build_context(db, table1_name, table2_name)
196+
check_map = get_check_map()
197+
results = _run_selected_checks(
198+
context=context,
199+
selected_check_names=selected_check_names,
200+
check_map=check_map,
201+
connection_id=connection_id,
202+
parallel_safe=True,
203+
)
165204

166205
log.info(render_report(context, selected_check_names, results))

sqlcompare/utils/concurrency.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from concurrent.futures import ThreadPoolExecutor
4+
from typing import Callable, Sequence, TypeVar
5+
6+
T = TypeVar("T")
7+
R = TypeVar("R")
8+
9+
10+
def run_ordered(
    items: Sequence[T],
    worker: Callable[[T], R],
    *,
    enabled: bool = True,
    max_workers: int | None = None,
) -> list[R]:
    """Apply ``worker`` to every item and return results in input order.

    With ``enabled`` false, or with at most one item, the calls are made
    inline in the calling thread. Otherwise each item is submitted to a
    ``ThreadPoolExecutor`` sized to ``max_workers`` (or the item count when
    ``max_workers`` is falsy), never exceeding the number of items.

    An exception raised by ``worker`` propagates to the caller when that
    item's result is collected.
    """
    item_count = len(items)
    if item_count <= 1 or not enabled:
        return [worker(item) for item in items]

    requested_workers = max_workers if max_workers else item_count
    pool_size = min(item_count, requested_workers)
    with ThreadPoolExecutor(max_workers=pool_size) as pool:
        # Submit everything first, then collect in submission order.
        pending = [pool.submit(worker, item) for item in items]
        return [task.result() for task in pending]

0 commit comments

Comments
 (0)