@@ -78,7 +78,7 @@ def __enter__(self) -> "DBConnection":
7878
7979 return self
8080
81- def __exit__ (self , exc_type , _exc , _tb ) -> None :
81+ def __exit__ (self , exc_type , _exc_val , _exc_tb ) -> None :
8282 # Commit / rollback transaction
8383 try :
8484 if self ._tx is not None :
@@ -196,12 +196,22 @@ def _run(self, sql: str, params: dict[str, Any] | None = None) -> Result:
196196 res = self .conn .execute (text (sql ), params or {})
197197 return res
198198 except SQLAlchemyError as e :
199+ # Extract just the database error message without the full SQL statement
200+ error_msg = str (e )
201+ # SQLAlchemy often includes the SQL in square brackets at the end
202+ # Format: "error message [SQL: long query here]"
203+ if "[SQL:" in error_msg :
204+ # Extract just the part before [SQL:
205+ db_error = error_msg .split ("[SQL:" )[0 ].strip ()
206+ else :
207+ db_error = error_msg
208+
199209 raise DBConnectionError (
200- "SQL execution failed." ,
210+ db_error ,
201211 conn_id = self .conn_id ,
202212 sql = sql ,
203213 original = e ,
204- ) from e
214+ ) from None # Suppress the original exception chain to avoid showing SQL twice
205215 finally :
206216 self .last_elapsed_ms = int ((time .perf_counter () - t0 ) * 1000 )
207217
@@ -220,6 +230,37 @@ def execute(self, sql: str, params: dict[str, Any] | None = None) -> ExecMeta:
220230 elapsed_ms = self .last_elapsed_ms or 0 , rowcount = rowcount , columns = cols
221231 )
222232
233+ def get_table_columns (self , table_name : str ) -> list [str ]:
234+ """
235+ Get the actual column names from a table, using database-specific metadata queries
236+ to ensure correct case sensitivity (especially important for Snowflake).
237+
238+ Args:
239+ table_name: Fully qualified table name (e.g., "schema.table" or "db.schema.table")
240+
241+ Returns:
242+ List of column names with their actual case as stored in the database
243+ """
244+ # Detect if we're using Snowflake by checking the dialect
245+ try :
246+ dialect_name = self ._engine .dialect .name .lower () if self ._engine else None
247+ except Exception :
248+ dialect_name = None
249+
250+ # For Snowflake, use DESCRIBE TABLE to get actual column names
251+ if dialect_name == "snowflake" :
252+ try :
253+ # DESCRIBE TABLE returns columns with their actual case
254+ result = self .query (f"DESCRIBE TABLE { table_name } " )
255+ # First column is the column name
256+ return [row [0 ] for row in result ]
257+ except Exception :
258+ pass
259+
260+ # Fallback: use SELECT * WHERE 1=0 and get column names from result
261+ _ , columns = self .query (f"SELECT * FROM { table_name } WHERE 1=0" , include_columns = True )
262+ return columns
263+
223264 def query (
224265 self ,
225266 sql : str ,
0 commit comments