Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/guides/utilities/customizing_configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ are available by using
>>> import skrub
>>> config = skrub.get_config()
>>> config.keys()
dict_keys(['use_table_report_data_ops', 'table_report_verbosity', 'max_plot_columns', 'max_association_columns', 'subsampling_seed', 'enable_subsampling', 'float_precision', 'cardinality_threshold', 'data_dir', 'eager_data_ops'])
dict_keys(['use_table_report_data_ops', 'table_report_verbosity', 'max_plot_columns', 'max_association_columns', 'subsampling_seed', 'enable_subsampling', 'float_precision', 'cardinality_threshold', 'data_dir', 'cache_dir', 'memory', 'eager_data_ops'])

These are the parameters currently available in the global configuration:

Expand Down
13 changes: 13 additions & 0 deletions skrub/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from contextlib import contextmanager
from pathlib import Path

import joblib
import numpy as np


Expand Down Expand Up @@ -40,6 +41,12 @@ def _get_default_data_dir():
return str(data_home)


def _get_default_cache_dir():
cache_dir = Path(_get_default_data_dir()) / "_cache"
cache_dir.mkdir(exist_ok=True)
return str(cache_dir)


def _parse_env_bool(env_variable_name, default):
value = os.getenv(env_variable_name, default)
if isinstance(value, bool):
Expand All @@ -64,6 +71,8 @@ def _parse_env_bool(env_variable_name, default):
"float_precision": int(os.environ.get("SKB_FLOAT_PRECISION", 3)),
"cardinality_threshold": int(os.environ.get("SKB_CARDINALITY_THRESHOLD", 40)),
"data_dir": _get_default_data_dir(),
"cache_dir": _get_default_cache_dir(),
"memory": joblib.Memory(_get_default_cache_dir(), verbose=0),
"eager_data_ops": _parse_env_bool("SKB_EAGER_DATA_OPS", True),
}
_threadlocal = threading.local()
Expand Down Expand Up @@ -113,6 +122,8 @@ def set_config(
float_precision=None,
cardinality_threshold=None,
data_dir=None,
cache_dir=None,
memory=None,
eager_data_ops=None,
):
"""Set global skrub configuration.
Expand Down Expand Up @@ -314,6 +325,8 @@ def config_context(
float_precision=None,
cardinality_threshold=None,
data_dir=None,
cache_dir=None,
memory=None,
eager_data_ops=None,
):
"""Context manager for global skrub configuration.
Expand Down
19 changes: 19 additions & 0 deletions skrub/_data_ops/_cached_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
functions meant to be cached with joblib.

They are in their own module so the cache is less likely to be invalidated due
to the line number of the function definition changing.
"""

import joblib


def _call_fitting_method(estimator, method_name, args, kwargs):
# we could also just generate a str(uuid.uuid4()) 🤔
estimator_id = joblib.hash((estimator, method_name, args, kwargs))
result = getattr(estimator, method_name)(*args, **kwargs)
return estimator, result, estimator_id


def _call_non_fitting_method(estimator, method_name, args, kwargs, estimator_id):
return getattr(estimator, method_name)(*args, **kwargs)
58 changes: 50 additions & 8 deletions skrub/_data_ops/_data_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import itertools
import operator
import pathlib
import pickle
import re
import textwrap
import traceback
Expand All @@ -49,7 +50,7 @@
from .._reporting._utils import strip_xml_declaration
from .._utils import PassThrough, set_module, short_repr
from .._wrap_transformer import wrap_transformer
from . import _utils
from . import _cached_helpers, _utils
from ._choosing import get_chosen_or_default
from ._utils import FITTED_PREDICTOR_METHODS, NULL, attribute_error

Expand Down Expand Up @@ -1325,6 +1326,36 @@ def check_subsampled_X_y_shape(X_op, y_op, X_value, y_value, mode, environment,
)


def _call_fitting_method(estimator, method_name, args, kwargs):
memory = _config.get_config()["memory"]
if memory is None:
result = getattr(estimator, method_name)(*args, **kwargs)
return estimator, result, None
try:
return memory.cache(_cached_helpers._call_fitting_method)(
estimator, method_name, args, kwargs
)
except pickle.PicklingError:
pass
# Fall back to non-cached call if arguments cannot be serialized
result = getattr(estimator, method_name)(*args, **kwargs)
return estimator, result, None


def _call_non_fitting_method(estimator, method_name, args, kwargs, estimator_id):
memory = _config.get_config()["memory"]
if memory is None or estimator_id is None:
return getattr(estimator, method_name)(*args, **kwargs)
try:
return memory.cache(
_cached_helpers._call_non_fitting_method, ignore=["estimator"]
)(estimator, method_name, args, kwargs, estimator_id)
except pickle.PicklingError:
pass
# Fall back to non-cached call if arguments cannot be serialized
return getattr(estimator, method_name)(*args, **kwargs)


class Apply(DataOpImpl):
""".skb.apply() nodes."""

Expand Down Expand Up @@ -1389,9 +1420,13 @@ def eval(self, *, mode, environment):
# with `.predict()`
if method_name == "fit_transform":
fit_kwargs = yield from self._eval_kwargs("fit")
self.estimator_.fit(X, y, **fit_kwargs)
self.estimator_, _, self.estimator_id_ = _call_fitting_method(
self.estimator_, "fit", (X, y), fit_kwargs
)
predict_kwargs = yield from self._eval_kwargs("predict")
pred = self.estimator_.predict(X, **predict_kwargs)
pred = _call_non_fitting_method(
self.estimator_, "predict", (X,), predict_kwargs, self.estimator_id_
)
# In `(fit_)transform` mode only, format the predictions as a
# dataframe or column if y was one during `fit()`
return self._format_predictions(X, pred)
Expand All @@ -1402,13 +1437,20 @@ def eval(self, *, mode, environment):
method_name = "fit_transform"

if "fit" in method_name:
y_arg = () if self.unsupervised else (y,)
args = (X,) if self.unsupervised else (X, y)
elif method_name == "score":
y_arg = (y,)
args = (X, y)
else:
y_arg = ()
method_kwargs = yield from self._eval_kwargs(method_name)
return getattr(self.estimator_, method_name)(X, *y_arg, **method_kwargs)
args = (X,)
kwargs = yield from self._eval_kwargs(method_name)
if "fit" in method_name:
self.estimator_, result, self.estimator_id_ = _call_fitting_method(
self.estimator_, method_name, args, kwargs
)
return result
return _call_non_fitting_method(
self.estimator_, method_name, args, kwargs, self.estimator_id_
)

def _store_y_format(self, y):
if sbd.is_dataframe(y):
Expand Down
2 changes: 2 additions & 0 deletions skrub/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def test_default_config():
expected_keys = {
"use_table_report_data_ops",
"data_dir",
"cache_dir",
"memory",
"table_report_verbosity",
"max_plot_columns",
"max_association_columns",
Expand Down
Loading