Skip to content
180 changes: 168 additions & 12 deletions ingestion/src/metadata/ingestion/source/database/oracle/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@
"""
Source connection handler
"""
import base64
import io
import os
import shutil
import sys
import tempfile
import weakref
import zipfile
from copy import deepcopy
from typing import Optional
from typing import Any, Optional
from urllib.parse import quote_plus

import oracledb
Expand All @@ -30,6 +36,7 @@
OracleConnection as OracleConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleAutonomousConnection,
OracleDatabaseSchema,
OracleServiceName,
OracleTNSConnection,
Expand All @@ -41,6 +48,7 @@
create_generic_db_connection,
get_connection_args_common,
get_connection_options_dict,
init_empty_connection_arguments,
)
from metadata.ingestion.connections.connection import BaseConnection
from metadata.ingestion.connections.secrets import connection_with_options_secrets
Expand All @@ -67,22 +75,156 @@
class OracleConnection(BaseConnection[OracleConnectionConfig, Engine]):
def __init__(self, connection: OracleConnectionConfig):
super().__init__(connection)
self._wallet_temp_dir: Optional[str] = None
self._wallet_cleanup_finalizer: Optional[weakref.finalize] = None

def _set_wallet_temp_dir(self, wallet_temp_dir: str) -> None:
self._cleanup_wallet_temp_dir()
self._wallet_temp_dir = wallet_temp_dir
self._wallet_cleanup_finalizer = weakref.finalize(
self,
shutil.rmtree,
wallet_temp_dir,
ignore_errors=True,
)

def _cleanup_wallet_temp_dir(self) -> None:
wallet_temp_dir = self._wallet_temp_dir
if self._wallet_cleanup_finalizer and self._wallet_cleanup_finalizer.alive:
self._wallet_cleanup_finalizer()
elif wallet_temp_dir:
shutil.rmtree(wallet_temp_dir, ignore_errors=True)

self._wallet_cleanup_finalizer = None
self._wallet_temp_dir = None

def _is_autonomous_connection(self) -> bool:
return isinstance(
self.service_connection.oracleConnectionType, OracleAutonomousConnection
)

@staticmethod
def _get_autonomous_connection_config(
connection_type: OracleAutonomousConnection,
) -> Any:
return connection_type
Comment thread
hassaansaleem28 marked this conversation as resolved.

@staticmethod
def _safe_extract_wallet_archive(zip_ref: zipfile.ZipFile, target_dir: str) -> None:
target_dir_real = os.path.realpath(target_dir)
safe_prefix = f"{target_dir_real}{os.sep}"

for member in zip_ref.infolist():
member_path = os.path.realpath(
os.path.join(target_dir_real, member.filename)
)

if (
not member_path.startswith(safe_prefix)
and member_path != target_dir_real
):
raise ValueError(
"Invalid walletContent. Wallet zip contains unsafe file paths."
)

if member.is_dir():
os.makedirs(member_path, exist_ok=True)
continue

os.makedirs(os.path.dirname(member_path), exist_ok=True)
target_fd = os.open(
member_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600
)
with zip_ref.open(member, "r") as source_file, os.fdopen(
target_fd, "wb"
) as target_file:
shutil.copyfileobj(source_file, target_file)
Comment thread
hassaansaleem28 marked this conversation as resolved.
Outdated

def _extract_wallet_content(self, wallet_content: SecretStr) -> str:
try:
decoded_wallet = base64.b64decode(wallet_content.get_secret_value())
except (ValueError, TypeError) as exc:
raise ValueError(
"Invalid walletContent. Expected a base64-encoded wallet zip."
) from exc
Comment thread
hassaansaleem28 marked this conversation as resolved.
Comment thread
hassaansaleem28 marked this conversation as resolved.

wallet_temp_dir = tempfile.mkdtemp(prefix="oracle_wallet_")
self._set_wallet_temp_dir(wallet_temp_dir)

try:
with zipfile.ZipFile(io.BytesIO(decoded_wallet)) as zip_ref:
self._safe_extract_wallet_archive(zip_ref, wallet_temp_dir)
except (ValueError, zipfile.BadZipFile) as exc:
self._cleanup_wallet_temp_dir()
if isinstance(exc, zipfile.BadZipFile):
raise ValueError(
"Invalid walletContent. Expected a valid zip archive."
) from exc
raise

return wallet_temp_dir

def _configure_autonomous_connection_arguments(self) -> None:
connection_type = self.service_connection.oracleConnectionType
if not isinstance(connection_type, OracleAutonomousConnection):
return

autonomous_connection = self._get_autonomous_connection_config(connection_type)
if not self.service_connection.connectionArguments:
self.service_connection.connectionArguments = (
init_empty_connection_arguments()
)
elif self.service_connection.connectionArguments.root is None:
self.service_connection.connectionArguments.root = {}

connection_arguments = self.service_connection.connectionArguments.root

wallet_path = autonomous_connection.walletPath
if autonomous_connection.walletContent:
if self._wallet_temp_dir and os.path.isdir(self._wallet_temp_dir):
wallet_path = self._wallet_temp_dir
else:
wallet_path = self._extract_wallet_content(
autonomous_connection.walletContent
)
else:
self._cleanup_wallet_temp_dir()

if not wallet_path:
raise ValueError(
"Oracle Autonomous connections require either walletPath or walletContent."
)

connection_arguments["config_dir"] = wallet_path
connection_arguments["wallet_location"] = wallet_path

if autonomous_connection.walletPassword:
connection_arguments[
"wallet_password"
] = autonomous_connection.walletPassword.get_secret_value()
else:
connection_arguments.pop("wallet_password", None)

def _get_client(self) -> Engine:
"""
Create connection
"""
try:
if self.service_connection.instantClientDirectory:
logger.info(
f"Initializing Oracle thick client at {self.service_connection.instantClientDirectory}"
)
os.environ[LD_LIB_ENV] = self.service_connection.instantClientDirectory
oracledb.init_oracle_client(
lib_dir=self.service_connection.instantClientDirectory
)
except DatabaseError as err:
logger.info(f"Could not initialize Oracle thick client: {err}")
self._configure_autonomous_connection_arguments()
Comment thread
gitar-bot[bot] marked this conversation as resolved.

if not self._is_autonomous_connection():
try:
if self.service_connection.instantClientDirectory:
logger.info(
f"Initializing Oracle thick client at {self.service_connection.instantClientDirectory}"
)
os.environ[
LD_LIB_ENV
] = self.service_connection.instantClientDirectory
oracledb.init_oracle_client(
lib_dir=self.service_connection.instantClientDirectory
)
except DatabaseError as err:
logger.info(f"Could not initialize Oracle thick client: {err}")

return create_generic_db_connection(
connection=self.service_connection,
Expand Down Expand Up @@ -150,6 +292,13 @@ def get_connection_dict(self) -> dict:
connection_dict[
"host"
] = connection_copy.oracleConnectionType.oracleTNSConnection
elif isinstance(
connection_copy.oracleConnectionType, OracleAutonomousConnection
):
autonomous_connection = self._get_autonomous_connection_config(
connection_copy.oracleConnectionType
)
connection_dict["host"] = autonomous_connection.tnsAlias

# Add connection options if present
if connection_copy.connectionOptions and connection_copy.connectionOptions.root:
Expand Down Expand Up @@ -209,6 +358,13 @@ def _handle_connection_type(url: str, connection: OracleConnectionConfig) -> str
url += connection.oracleConnectionType.oracleTNSConnection
return url

if isinstance(connection.oracleConnectionType, OracleAutonomousConnection):
autonomous_connection = OracleConnection._get_autonomous_connection_config(
connection.oracleConnectionType
)
url += autonomous_connection.tnsAlias
return url

# If not TNS, we add the hostPort
url += connection.hostPort

Expand Down
146 changes: 146 additions & 0 deletions ingestion/tests/unit/test_source_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import io
import os
import zipfile
from unittest import TestCase
from unittest.mock import patch

from trino.auth import BasicAuthentication, JWTAuthentication, OAuth2Authentication

Expand Down Expand Up @@ -78,6 +83,7 @@
OracleConnection as OracleConnectionConfig,
)
from metadata.generated.schema.entity.services.connections.database.oracleConnection import (
OracleAutonomousConnection,
OracleDatabaseSchema,
OracleScheme,
OracleServiceName,
Expand Down Expand Up @@ -1227,6 +1233,146 @@ def test_oracle_url(self):
)
assert OracleConnection.get_connection_url(oracle_conn_obj) == expected_url

expected_url = "oracle+cx_oracle://admin:password@myadb_high"
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
oracleConnectionType=OracleAutonomousConnection(
tnsAlias="myadb_high",
walletPath="/tmp/my_wallet",
),
)
assert OracleConnection.get_connection_url(oracle_conn_obj) == expected_url

expected_url = [
"oracle+cx_oracle://admin:password@myadb_high?test_key_2=test_value_2&test_key_1=test_value_1",
"oracle+cx_oracle://admin:password@myadb_high?test_key_1=test_value_1&test_key_2=test_value_2",
]
oracle_conn_obj = OracleConnectionConfig(
username="admin",
password="password",
oracleConnectionType=OracleAutonomousConnection(
tnsAlias="myadb_high",
walletPath="/tmp/my_wallet",
),
connectionOptions=dict(
test_key_1="test_value_1", test_key_2="test_value_2"
),
)
assert OracleConnection.get_connection_url(oracle_conn_obj) in expected_url

@patch(
"metadata.ingestion.source.database.oracle.connection.oracledb.init_oracle_client"
)
@patch(
"metadata.ingestion.source.database.oracle.connection.create_generic_db_connection"
)
def test_oracle_autonomous_wallet_path_args(
self, mock_create_generic_db_connection, mock_init_oracle_client
):
connection = OracleConnectionConfig(
username="admin",
password="password",
instantClientDirectory="/instantclient",
oracleConnectionType=OracleAutonomousConnection(
tnsAlias="myadb_high",
walletPath="/tmp/my_wallet",
walletPassword="wallet_password",
),
)
oracle_connection = OracleConnection(connection)
mock_create_generic_db_connection.return_value = "dummy_engine"

oracle_connection._get_client()

assert mock_init_oracle_client.call_count == 0
assert (
oracle_connection.service_connection.connectionArguments.root["config_dir"]
== "/tmp/my_wallet"
)
assert (
oracle_connection.service_connection.connectionArguments.root[
"wallet_location"
]
== "/tmp/my_wallet"
)
assert (
oracle_connection.service_connection.connectionArguments.root[
"wallet_password"
]
== "wallet_password"
)

@patch(
"metadata.ingestion.source.database.oracle.connection.create_generic_db_connection"
)
def test_oracle_autonomous_wallet_content_args(
self, mock_create_generic_db_connection
):
wallet_bytes = io.BytesIO()
with zipfile.ZipFile(wallet_bytes, "w", zipfile.ZIP_DEFLATED) as zip_file:
zip_file.writestr("tnsnames.ora", "MYADB_HIGH=(DESCRIPTION=...)")

encoded_wallet = base64.b64encode(wallet_bytes.getvalue()).decode("utf-8")

connection = OracleConnectionConfig(
username="admin",
password="password",
oracleConnectionType=OracleAutonomousConnection(
tnsAlias="myadb_high",
walletContent=encoded_wallet,
),
)
oracle_connection = OracleConnection(connection)
mock_create_generic_db_connection.return_value = "dummy_engine"

oracle_connection._get_client()

wallet_dir = oracle_connection.service_connection.connectionArguments.root[
"config_dir"
]
assert os.path.isdir(wallet_dir)
assert os.path.exists(os.path.join(wallet_dir, "tnsnames.ora"))

# Repeated _get_client calls should reuse the same extracted wallet directory.
oracle_connection._get_client()
assert (
oracle_connection.service_connection.connectionArguments.root["config_dir"]
== wallet_dir
)

oracle_connection._cleanup_wallet_temp_dir()
assert not os.path.exists(wallet_dir)

@patch(
"metadata.ingestion.source.database.oracle.connection.create_generic_db_connection"
)
def test_oracle_autonomous_wallet_content_zip_slip_rejected(
self, mock_create_generic_db_connection
):
wallet_bytes = io.BytesIO()
with zipfile.ZipFile(wallet_bytes, "w", zipfile.ZIP_DEFLATED) as zip_file:
zip_file.writestr("../malicious.txt", "malicious")

encoded_wallet = base64.b64encode(wallet_bytes.getvalue()).decode("utf-8")

connection = OracleConnectionConfig(
username="admin",
password="password",
oracleConnectionType=OracleAutonomousConnection(
tnsAlias="myadb_high",
walletContent=encoded_wallet,
),
)
oracle_connection = OracleConnection(connection)
mock_create_generic_db_connection.return_value = "dummy_engine"

with self.assertRaises(ValueError) as error:
oracle_connection._get_client()

assert "unsafe file paths" in str(error.exception)
assert oracle_connection._wallet_temp_dir is None

def test_exasol_url(self):
from metadata.ingestion.source.database.exasol.connection import (
get_connection_url,
Expand Down
Loading
Loading