|
16 | 16 | from textual.containers import Container, Horizontal, Vertical |
17 | 17 | from textual.widgets import DataTable, Static, TextArea, Tree |
18 | 18 |
|
19 | | -from .adapters import DatabaseAdapter, get_adapter |
| 19 | +from .adapters import DatabaseAdapter, create_ssh_tunnel, get_adapter |
20 | 20 | from .config import ( |
21 | 21 | ConnectionConfig, |
22 | 22 | load_connections, |
@@ -197,6 +197,7 @@ def __init__(self): |
197 | 197 | self.current_connection: Any | None = None |
198 | 198 | self.current_config: ConnectionConfig | None = None |
199 | 199 | self.current_adapter: DatabaseAdapter | None = None |
| 200 | + self.current_ssh_tunnel: Any | None = None |
200 | 201 | self.vim_mode: VimMode = VimMode.NORMAL |
201 | 202 | self._expanded_paths: set[str] = set() |
202 | 203 | self._schema_cache: dict = { |
@@ -1228,27 +1229,55 @@ def handle_connection_result(self, result: tuple | None) -> None: |
1228 | 1229 |
|
1229 | 1230 | def connect_to_server(self, config: ConnectionConfig) -> None: |
1230 | 1231 | """Connect to a database.""" |
| 1232 | + from dataclasses import replace |
| 1233 | + |
1231 | 1234 | # Check for pyodbc only if it's a SQL Server connection |
1232 | 1235 | if config.db_type == "mssql" and not PYODBC_AVAILABLE: |
1233 | 1236 | self.notify("pyodbc not installed. Run: pip install pyodbc", severity="error") |
1234 | 1237 | return |
1235 | 1238 |
|
1236 | 1239 | try: |
| 1240 | + # Close any existing SSH tunnel |
| 1241 | + if self.current_ssh_tunnel: |
| 1242 | + try: |
| 1243 | + self.current_ssh_tunnel.stop() |
| 1244 | + except Exception: |
| 1245 | + pass |
| 1246 | + self.current_ssh_tunnel = None |
| 1247 | + |
| 1248 | + # Create SSH tunnel if enabled |
| 1249 | + tunnel, host, port = create_ssh_tunnel(config) |
| 1250 | + self.current_ssh_tunnel = tunnel |
| 1251 | + |
| 1252 | + # If SSH tunnel was created, use the tunnel's local address |
| 1253 | + if tunnel: |
| 1254 | + connect_config = replace(config, server=host, port=str(port)) |
| 1255 | + else: |
| 1256 | + connect_config = config |
| 1257 | + |
1237 | 1258 | adapter = get_adapter(config.db_type) |
1238 | | - self.current_connection = adapter.connect(config) |
1239 | | - self.current_config = config |
| 1259 | + self.current_connection = adapter.connect(connect_config) |
| 1260 | + self.current_config = config # Store original config (not tunneled) |
1240 | 1261 | self.current_adapter = adapter |
1241 | 1262 | self._set_connection_health(config.name, True) |
1242 | 1263 |
|
1243 | 1264 | status = self.query_one("#status-bar", Static) |
1244 | 1265 | display_info = config.get_display_info() |
1245 | | - status.update(f"[#90EE90]Connected to {config.name}[/] ({display_info})") |
| 1266 | + ssh_indicator = " [SSH]" if tunnel else "" |
| 1267 | + status.update(f"[#90EE90]Connected to {config.name}[/] ({display_info}){ssh_indicator}") |
1246 | 1268 |
|
1247 | 1269 | self.refresh_tree() |
1248 | 1270 | self._load_schema_cache() |
1249 | 1271 | self.notify(f"Connected to {config.name}") |
1250 | 1272 |
|
1251 | 1273 | except Exception as e: |
| 1274 | + # Clean up SSH tunnel on failure |
| 1275 | + if self.current_ssh_tunnel: |
| 1276 | + try: |
| 1277 | + self.current_ssh_tunnel.stop() |
| 1278 | + except Exception: |
| 1279 | + pass |
| 1280 | + self.current_ssh_tunnel = None |
1252 | 1281 | self._set_connection_health(config.name, False) |
1253 | 1282 | self.refresh_tree() |
1254 | 1283 | self.notify(f"Connection failed: {e}", severity="error") |
@@ -1336,6 +1365,14 @@ def _disconnect_silent(self) -> None: |
1336 | 1365 | self.current_config = None |
1337 | 1366 | self.current_adapter = None |
1338 | 1367 |
|
| 1368 | + # Close SSH tunnel if active |
| 1369 | + if self.current_ssh_tunnel: |
| 1370 | + try: |
| 1371 | + self.current_ssh_tunnel.stop() |
| 1372 | + except Exception: |
| 1373 | + pass |
| 1374 | + self.current_ssh_tunnel = None |
| 1375 | + |
1339 | 1376 | def action_disconnect(self) -> None: |
1340 | 1377 | """Disconnect from current database.""" |
1341 | 1378 | if self.current_connection: |
|
0 commit comments