Skip to content

Commit c31c619

Browse files
committed
fix: cache fingerprint stability for wrapped methods and FusedFilter
The `Hasher` only walked one level of `__wrapped__` when resolving bound-method owners, but `wrap_func_with_nested_access` adds multiple decorator layers. As a result, every `NestedDataset.map`/`filter` call missed the cache even though the OP hashes matched, because the function actually passed to HF datasets was the deeply wrapped variant.

Additionally, `FusedFilter.fused_filters` (a list of child OPs) was serialized via `dill.dumps`, which included each child's `work_dir` and therefore defeated caching for fused pipelines.

Changes:
- Walk the full `__wrapped__` chain (up to 10 levels) in `Hasher._find_op_owner`.
- Recursively sanitize nested OP instances in `_fingerprint_bytes`.
- Add tests for `FusedFilter`, wrapped methods, and multi-step pipeline cache hits.
1 parent 09ac5db commit c31c619

File tree

3 files changed

+146
-15
lines changed

3 files changed

+146
-15
lines changed

data_juicer/ops/base_op.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,10 +329,27 @@ def _fingerprint_bytes(self):
329329
poison the cache. Callable attributes (bound/wrapped methods like
330330
``process``, ``compute_stats``) are also excluded because they
331331
close over ``self`` and would re-introduce the excluded attrs.
332+
333+
Nested OP instances (e.g. ``FusedFilter.fused_filters``) are
334+
recursively fingerprinted via their own ``_fingerprint_bytes``
335+
so that their ``work_dir`` is also excluded.
332336
"""
333337
import dill
334338

335-
state = {k: v for k, v in self.__dict__.items() if k not in self._NON_FINGERPRINT_ATTRS and not callable(v)}
339+
def _sanitize(v):
340+
"""Recursively replace OP instances with their fingerprint bytes."""
341+
if isinstance(v, OP) and hasattr(v, "_fingerprint_bytes"):
342+
return v._fingerprint_bytes()
343+
if isinstance(v, (list, tuple)):
344+
converted = [_sanitize(item) for item in v]
345+
return type(v)(converted)
346+
return v
347+
348+
state = {}
349+
for k, v in self.__dict__.items():
350+
if k in self._NON_FINGERPRINT_ATTRS or callable(v):
351+
continue
352+
state[k] = _sanitize(v)
336353
return dill.dumps(state)
337354

338355
def __init__(self, *args, **kwargs):

data_juicer/utils/fingerprint_utils.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,31 @@ def hash_bytes(cls, value: Union[bytes, List[bytes]]) -> str:
2929
m.update(x)
3030
return m.hexdigest()
3131

32+
@classmethod
33+
def _find_op_owner(cls, value):
34+
"""Walk the ``__self__`` / ``__wrapped__`` chain to find an object
35+
that exposes ``_fingerprint_bytes``. Returns ``(obj, func_name)``
36+
or ``(None, None)``."""
37+
# Direct bound method
38+
obj = getattr(value, "__self__", None)
39+
if obj is not None:
40+
if callable(getattr(obj, "_fingerprint_bytes", None)):
41+
func_name = getattr(value, "__name__", getattr(value, "__qualname__", ""))
42+
return obj, func_name
43+
# Walk the full __wrapped__ chain (handles multiple decorator
44+
# layers such as wrap_func_with_nested_access → @wraps → bound
45+
# method).
46+
cur = value
47+
for _ in range(10): # guard against infinite loops
48+
cur = getattr(cur, "__wrapped__", None)
49+
if cur is None:
50+
break
51+
obj = getattr(cur, "__self__", None)
52+
if obj is not None and callable(getattr(obj, "_fingerprint_bytes", None)):
53+
func_name = getattr(cur, "__name__", getattr(cur, "__qualname__", ""))
54+
return obj, func_name
55+
return None, None
56+
3257
@classmethod
3358
def hash_default(cls, value: Any) -> str:
3459
"""
@@ -45,21 +70,9 @@ def hash_default(cls, value: Any) -> str:
4570
# _fingerprint_bytes, hash the (fingerprint, method_name) pair
4671
# instead of dill-dumping the bound method (which would
4772
# re-serialize the full object including excluded attrs).
48-
obj = getattr(value, "__self__", None)
73+
obj, func_name = cls._find_op_owner(value)
4974
if obj is not None:
50-
obj_fp = getattr(obj, "_fingerprint_bytes", None)
51-
if callable(obj_fp):
52-
func_name = getattr(value, "__name__", getattr(value, "__qualname__", ""))
53-
return cls.hash_bytes(obj_fp() + dill.dumps(func_name))
54-
# functools.wraps closures: check __wrapped__.__self__
55-
wrapped = getattr(value, "__wrapped__", None)
56-
if wrapped is not None:
57-
obj = getattr(wrapped, "__self__", None)
58-
if obj is not None:
59-
obj_fp = getattr(obj, "_fingerprint_bytes", None)
60-
if callable(obj_fp):
61-
func_name = getattr(wrapped, "__name__", getattr(wrapped, "__qualname__", ""))
62-
return cls.hash_bytes(obj_fp() + dill.dumps(func_name))
75+
return cls.hash_bytes(obj._fingerprint_bytes() + dill.dumps(func_name))
6376
return cls.hash_bytes(dill.dumps(value))
6477

6578
@classmethod

tests/utils/test_fingerprint_utils.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,5 +85,106 @@ def test_serialization_round_trip_preserves_all_attrs(self):
8585
self.assertTrue(restored.skip_op_error)
8686

8787

88+
class FusedFilterFingerprintTest(DataJuicerTestCaseBase):
    """Checks that a FusedFilter hash ignores the ``work_dir`` of its child
    OPs while still reacting to changes in their parameters."""

    @staticmethod
    def _make_fused(min_len, work_dir):
        """Build a FusedFilter of a TextLengthFilter and a WordsNumFilter
        that share ``work_dir``; only ``min_len`` of the first child varies."""
        from data_juicer.ops.filter.words_num_filter import WordsNumFilter
        from data_juicer.ops.op_fusion import FusedFilter

        text_len_op = TextLengthFilter(min_len=min_len, max_len=10000, work_dir=work_dir)
        words_num_op = WordsNumFilter(min_num=2, max_num=1000, work_dir=work_dir)
        return FusedFilter('fused', [text_len_op, words_num_op])

    def test_fused_filter_stable_across_work_dirs(self):
        # Same child parameters, different work_dir -> same hash.
        left = self._make_fused(5, '/tmp/a')
        right = self._make_fused(5, '/tmp/b')

        self.assertEqual(Hasher.hash(left), Hasher.hash(right))

    def test_fused_filter_differs_when_child_params_change(self):
        # Same work_dir, different min_len on one child -> different hash.
        left = self._make_fused(5, '/tmp/a')
        right = self._make_fused(50, '/tmp/a')

        self.assertNotEqual(Hasher.hash(left), Hasher.hash(right))
118+
119+
120+
class WrappedFunctionFingerprintTest(DataJuicerTestCaseBase):
    """Checks that bound methods wrapped by ``wrap_func_with_nested_access``
    hash identically regardless of the OP's ``work_dir``."""

    @staticmethod
    def _wrapped_compute_stats(min_len, work_dir):
        """Return ``compute_stats`` of a fresh TextLengthFilter, wrapped by
        ``wrap_func_with_nested_access``."""
        from data_juicer.core.data.dj_dataset import wrap_func_with_nested_access

        op = TextLengthFilter(min_len=min_len, max_len=10000, work_dir=work_dir)
        return wrap_func_with_nested_access(op.compute_stats)

    def test_wrapped_compute_stats_stable(self):
        # Only work_dir differs, so the hashes must agree.
        wrapped_a = self._wrapped_compute_stats(5, '/tmp/a')
        wrapped_b = self._wrapped_compute_stats(5, '/tmp/b')
        self.assertEqual(Hasher.hash(wrapped_a), Hasher.hash(wrapped_b))

    def test_wrapped_differs_when_params_change(self):
        # min_len differs, so the hashes must not agree.
        wrapped_a = self._wrapped_compute_stats(5, '/tmp/a')
        wrapped_b = self._wrapped_compute_stats(50, '/tmp/a')
        self.assertNotEqual(Hasher.hash(wrapped_a), Hasher.hash(wrapped_b))

    def test_multistep_pipeline_cache_hit(self):
        """Run a three-OP pipeline twice with different work_dirs; the
        second run must not add any cache file."""
        import glob
        import os

        from datasets import load_dataset, enable_caching

        from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter
        from data_juicer.ops.filter.words_num_filter import WordsNumFilter
        from data_juicer.utils.constant import Fields

        enable_caching()
        raw = load_dataset(
            'json',
            data_files='demos/data/demo-dataset.jsonl',
            split='train',
        )
        ds = NestedDataset(raw)
        if Fields.stats not in ds.features:
            ds = ds.map(lambda x: {Fields.stats: {}})
        cache_dir = os.path.dirname(ds.cache_files[0]['filename'])

        def list_cache_files():
            # Snapshot of the arrow files currently present in the cache dir.
            return set(glob.glob(os.path.join(cache_dir, '*.arrow')))

        def run_pipeline(dataset, work_dir):
            # Apply compute_stats + process for each OP, in order.
            result = dataset
            for op in (
                TextLengthFilter(min_len=5, max_len=10000, work_dir=work_dir),
                WordsNumFilter(min_num=2, max_num=1000, work_dir=work_dir),
                AlphanumericFilter(min_ratio=0.0, max_ratio=1.0,
                                   work_dir=work_dir),
            ):
                result = result.map(op.compute_stats, num_proc=1).filter(op.process, num_proc=1)
            return result

        run_pipeline(ds, '/tmp/pipeline_test_A')
        cache_after_a = list_cache_files()

        run_pipeline(ds, '/tmp/pipeline_test_B')
        cache_after_b = list_cache_files()

        new_files = cache_after_b - cache_after_a
        self.assertEqual(len(new_files), 0,
                         f'Pipeline B created {len(new_files)} new cache '
                         f'files; expected 0 (full cache hit)')
187+
188+
88189
if __name__ == '__main__':
89190
unittest.main()

0 commit comments

Comments
 (0)