Skip to content

Commit de0fb21

Browse files
feat: Add support for Snowflake table column type retrieval using DESCRIBE TABLE
1 parent 5766461 commit de0fb21

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

sqlcompare/db/connection.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,29 @@ def get_table_column_types(self, table_name: str) -> dict[str, str | None]:
265265
"""
266266
Get column types keyed by the actual column name as stored in the database.
267267
"""
268+
try:
269+
dialect_name = self._engine.dialect.name.lower() if self._engine else None
270+
except Exception:
271+
dialect_name = None
272+
273+
# SQLAlchemy's Snowflake inspector mis-parses three-part names by combining
274+
# the current database with the provided database.schema prefix. DESCRIBE
275+
# TABLE accepts the fully qualified name directly and preserves identifier case.
276+
if dialect_name == "snowflake":
277+
try:
278+
result = self.query(f"DESCRIBE TABLE {table_name}")
279+
if result:
280+
return {
281+
str(row[0]): (
282+
None
283+
if len(row) < 2 or row[1] is None
284+
else str(row[1]).upper()
285+
)
286+
for row in result
287+
}
288+
except Exception:
289+
pass
290+
268291
inspector = inspect(self.conn)
269292
schema = None
270293
table = table_name

tests/test_db_connection.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from __future__ import annotations
2+
3+
from types import SimpleNamespace
4+
5+
from sqlcompare.db import DBConnection
6+
import sqlcompare.db.connection as connection_module
7+
8+
9+
def test_get_table_column_types_snowflake_uses_describe_for_fully_qualified_name(
10+
monkeypatch,
11+
) -> None:
12+
db = DBConnection("snowflake://example")
13+
db._engine = SimpleNamespace(dialect=SimpleNamespace(name="snowflake"))
14+
15+
captured: dict[str, str] = {}
16+
17+
def fake_query(sql: str, *args, **kwargs):
18+
captured["sql"] = sql
19+
return [("ID", "NUMBER(38,0)"), ("NAME", "VARCHAR(16777216)")]
20+
21+
monkeypatch.setattr(db, "query", fake_query)
22+
23+
def fail_inspect(_conn):
24+
raise AssertionError("inspect() should not be used for Snowflake")
25+
26+
monkeypatch.setattr(connection_module, "inspect", fail_inspect)
27+
28+
result = db.get_table_column_types(
29+
"LH_PROD_DATA_SNOWFLAKE.salesforce_derived.account_augmented"
30+
)
31+
32+
assert (
33+
captured["sql"]
34+
== "DESCRIBE TABLE LH_PROD_DATA_SNOWFLAKE.salesforce_derived.account_augmented"
35+
)
36+
assert result == {"ID": "NUMBER(38,0)", "NAME": "VARCHAR(16777216)"}

0 commit comments

Comments
 (0)