Skip to content

Commit f242ee8

Browse files
authored
Merge branch 'main' into op/group_diversity
2 parents 216bf50 + 6e6e4ab commit f242ee8

21 files changed

+1323
-48
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Data-Juicer is being actively updated and maintained. We will periodically enhan
3838

3939
[Demo Video] DataJuicer-Agent: Quick start your data processing journey!
4040

41-
https://github.com/user-attachments/assets/58aea900-e51f-4ec2-b1c0-eead97967893
41+
https://github.com/user-attachments/assets/6eb726b7-6054-4b0c-905e-506b2b9c7927
4242

4343
[Demo Video] DataJuicer-Sandbox: Better data-model co-dev at a lower cost!
4444

README_ZH.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Data-Juicer正在积极更新和维护中,我们将定期强化和新增更多
3232

3333
[Demo Video] DataJuicer-Agent:数据处理,即刻启程!
3434

35-
https://github.com/user-attachments/assets/58aea900-e51f-4ec2-b1c0-eead97967893
35+
https://github.com/user-attachments/assets/6eb726b7-6054-4b0c-905e-506b2b9c7927
3636

3737
[Demo Video] DataJuicer-Sandbox: 降本增效,优化数据-模型协同开发!
3838

configs/config_all.yaml

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,24 @@ process:
721721
prob_threshold: 0.8 # the predicted watermark probability threshold for samples, range from 0 to 1. Samples with watermark probability less than this threshold will be kept.
722722
any_or_all: any # keep this sample when any/all images meet the filter condition
723723
mem_required: '500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
724+
- in_context_influence_filter: # filter to keep texts whose in-context influence upon validation set within a specific range.
725+
hf_model: 'Qwen/Qwen2.5-0.5B' # Huggingface model name.
726+
model_params: null # Parameters for initializing the API model.
727+
min_score: 1.0 # Minimum perplexity score.
728+
max_score: 100.0 # Maximum perplexity score.
729+
query_template: null # Template for building the query string.
730+
response_template: null # Template for building the response string.
731+
valid_dataset: null # The dataset to use for validation
732+
task_desc: null # The description of the validation task.
733+
valid_as_demo: True # If true, score = L(A|Q) / L(A|task_desc, Q_v, A_v, Q); If false, score = L(A_v|Q_v) / L(A_v|task_desc, Q, A, Q_v).
734+
n_shot: null # The number of shots in validation.
735+
- instruction_following_difficulty_filter: # filter to keep texts whose instruction-following difficulty (IFD, https://arxiv.org/abs/2308.12032) falls within a specific range.
736+
hf_model: 'Qwen/Qwen2.5-0.5B' # Huggingface model name.
737+
model_params: null # Parameters for initializing the API model.
738+
min_score: 1.0 # Minimum perplexity score.
739+
max_score: 100.0 # Maximum perplexity score.
740+
query_template: null # Template for building the query string.
741+
response_template: null # Template for building the response string.
724742
- language_id_score_filter: # filter text in specific language with language scores larger than a specific max value
725743
lang: en # keep text in what language
726744
min_score: 0.8 # the min language scores to filter text
@@ -739,6 +757,13 @@ process:
739757
enable_vllm: false # If true, use VLLM for loading hugging face or local llm. Otherwise, use API for reference.
740758
model_params: {} # Parameters for initializing the API model.
741759
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
760+
- llm_perplexity_filter: # filter to keep samples with perplexity score, computed using a specified llm, within a specific range.
761+
hf_model: 'Qwen/Qwen2.5-0.5B' # Huggingface model name.
762+
model_params: null # Parameters for initializing the API model.
763+
min_score: 1.0 # Minimum perplexity score.
764+
max_score: 100.0 # Maximum perplexity score.
765+
query_template: null # Template for building the query string.
766+
response_template: null # Template for building the response string.
742767
- llm_quality_score_filter: # filter to keep sample with high quality score estimated by LLM.
743768
api_or_hf_model: 'gpt-4o' # API or huggingface model name.
744769
min_score: 0.5 # The lowest quality score threshold to keep the sample.
@@ -754,6 +779,23 @@ process:
754779
enable_vllm: false # If true, use VLLM for loading hugging face or local llm. Otherwise, use API for reference.
755780
model_params: {} # Parameters for initializing the API model.
756781
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
782+
- llm_task_relevance_filter: # filter to keep sample with high relevance score to validation tasks estimated by LLM.
783+
api_or_hf_model: 'gpt-4o' # API or huggingface model name.
784+
min_score: 0.5 # The lowest quality score threshold to keep the sample.
785+
api_endpoint: null # URL endpoint for the API.
786+
response_path: null # Path to extract content from the API response. Defaults to 'choices.0.message.content'.
787+
input_keys: ['text'] # Subset of keys in the sample. Support data with multi fields such as 'query', 'analysis' and 'answer' in RFT data.
788+
field_names: ['text'] # Corresponding field names for input keys.
789+
system_prompt: null # System prompt for the task.
790+
input_template: null # The input template.
791+
field_template: null # Template for each field in the prompt.
792+
try_num: 3 # The number of retry attempts when there is an API call error or outputs parsing error.
793+
enable_vllm: false # If true, use VLLM for loading hugging face or local llm. Otherwise, use API for reference.
794+
model_params: {} # Parameters for initializing the API model.
795+
sampling_params: {} # Extra parameters passed to the API call. e.g {'temperature': 0.9, 'top_p': 0.95}
796+
valid_dataset: null # The dataset to use for validation
797+
task_desc: null # The description of the validation task.
798+
n_shot: null # The number of shots in validation.
757799
- maximum_line_length_filter: # filter text with the maximum length of lines out of specific range
758800
min_len: 10 # the min length of filter range
759801
max_len: 10000 # the max length of filter range
@@ -795,6 +837,18 @@ process:
795837
- text_action_filter: # filter text according the number of action verb
796838
lang: en # consider the words in what language
797839
min_action_num: 1 # text will be filtered whose verbs less the min action number
840+
- text_embd_similarity_filter: # Filter to keep texts whose average embedding similarity to a set of given validation texts falls within a specific range.
841+
api_or_hf_model: text-embedding-v4 # API or huggingface embedding model name
842+
is_hf_model: false # indicates if the model is from HuggingFace
843+
api_endpoint: embeddings # embedding URL endpoint for the API
844+
response_path: data.0.embedding # path to extract content from the API response
845+
model_params: null # parameters for initializing the API model
846+
min_score: 0.1 # the min average similarity to keep samples
847+
max_score: 1.0 # the max average similarity to keep samples
848+
valid_dataset: null # the dataset to use for validation
849+
ebd_dim: 1024 # the embedding's dimension via API
850+
pooling: null # strategy to extract embedding from the hidden states
851+
input_template: null # template for building the model input.
798852
- text_entity_dependency_filter: # filter text without non-independent entity nouns
799853
lang: en # consider the words in what language
800854
min_dependency_num: 1 # the min number of adjacent edges of a non-independent noun in dependency tree

data_juicer/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def timing_context(description):
1616
with timing_context('Importing operator modules'):
1717
from . import aggregator, deduplicator, filter, grouper, mapper, selector
1818
from .base_op import (
19+
ATTRIBUTION_FILTERS,
1920
NON_STATS_FILTERS,
2021
OPERATORS,
2122
TAGGING_OPS,

data_juicer/ops/base_op.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
UNFORKABLE = Registry("Unforkable")
1717
NON_STATS_FILTERS = Registry("Non-stats Filters")
1818
TAGGING_OPS = Registry("Tagging Operators")
19+
ATTRIBUTION_FILTERS = Registry("Attribution Filters")
1920

2021

2122
def convert_list_dict_to_dict_list(samples):

data_juicer/ops/filter/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,16 @@
1818
from .image_text_matching_filter import ImageTextMatchingFilter
1919
from .image_text_similarity_filter import ImageTextSimilarityFilter
2020
from .image_watermark_filter import ImageWatermarkFilter
21+
from .in_context_influence_filter import InContextInfluenceFilter
22+
from .instruction_following_difficulty_filter import (
23+
InstructionFollowingDifficultyFilter,
24+
)
2125
from .language_id_score_filter import LanguageIDScoreFilter
2226
from .llm_analysis_filter import LLMAnalysisFilter
2327
from .llm_difficulty_score_filter import LLMDifficultyScoreFilter
28+
from .llm_perplexity_filter import LLMPerplexityFilter
2429
from .llm_quality_score_filter import LLMQualityScoreFilter
30+
from .llm_task_relevance_filter import LLMTaskRelevanceFilter
2531
from .maximum_line_length_filter import MaximumLineLengthFilter
2632
from .perplexity_filter import PerplexityFilter
2733
from .phrase_grounding_recall_filter import PhraseGroundingRecallFilter
@@ -31,6 +37,7 @@
3137
from .stopwords_filter import StopWordsFilter
3238
from .suffix_filter import SuffixFilter
3339
from .text_action_filter import TextActionFilter
40+
from .text_embd_similarity_filter import TextEmbdSimilarityFilter
3441
from .text_entity_dependency_filter import TextEntityDependencyFilter
3542
from .text_length_filter import TextLengthFilter
3643
from .text_pair_similarity_filter import TextPairSimilarityFilter
@@ -71,9 +78,13 @@
7178
"ImageTextSimilarityFilter",
7279
"ImageWatermarkFilter",
7380
"LanguageIDScoreFilter",
81+
"InContextInfluenceFilter",
82+
"InstructionFollowingDifficultyFilter",
7483
"LLMAnalysisFilter",
7584
"LLMQualityScoreFilter",
85+
"LLMPerplexityFilter",
7686
"LLMDifficultyScoreFilter",
87+
"LLMTaskRelevanceFilter",
7788
"MaximumLineLengthFilter",
7889
"PerplexityFilter",
7990
"PhraseGroundingRecallFilter",
@@ -83,6 +94,7 @@
8394
"StopWordsFilter",
8495
"SuffixFilter",
8596
"TextActionFilter",
97+
"TextEmbdSimilarityFilter",
8698
"TextEntityDependencyFilter",
8799
"TextLengthFilter",
88100
"TextPairSimilarityFilter",
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from typing import Dict, List, Optional
2+
3+
from datasets import Dataset
4+
from loguru import logger
5+
6+
from data_juicer.ops.base_op import ATTRIBUTION_FILTERS, OPERATORS
7+
from data_juicer.ops.filter.llm_perplexity_filter import LLMPerplexityFilter
8+
from data_juicer.utils.constant import Fields, StatsKeys
9+
from data_juicer.utils.lazy_loader import LazyLoader
10+
11+
torch = LazyLoader("torch")
12+
transformers = LazyLoader("transformers")
13+
14+
OP_NAME = "in_context_influence_filter"
15+
16+
17+
@OPERATORS.register_module(OP_NAME)
@ATTRIBUTION_FILTERS.register_module(OP_NAME)
class InContextInfluenceFilter(LLMPerplexityFilter):
    """Filter to keep texts whose in-context influence upon a validation set falls
    within a specific range.

    The influence of a training sample is measured as an average, over the cached
    validation samples, of a ratio of LM losses computed with and without an
    in-context demonstration (see :meth:`compute_stats_single` for both directions).
    """

    # This operator is currently under development and evaluation as part of an ongoing research project.
    # The Data-Juicer team retains full copyright over this operator.

    # Model-based loss computation; prefer GPU execution.
    _accelerator = "cuda"

    def __init__(
        self,
        valid_dataset: Optional[List[Dict]] = None,
        task_desc: Optional[str] = None,
        valid_as_demo: bool = False,
        n_shot: Optional[int] = None,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param valid_dataset: The dataset (a list of sample dicts) to use for validation.
            If None, 'self.prepare_valid_feature' should be manually called before applying the filter.
        :param task_desc: The description of the validation task; forwarded as the
            system prompt when building validation messages.
        :param valid_as_demo: If true, score = L(A|Q) / L(A|task_desc, Q_v, A_v, Q);
            If false, score = L(A_v|Q_v) / L(A_v|task_desc, Q, A, Q_v).
        :param n_shot: The number of shots in validation.
        """
        super().__init__(*args, **kwargs)
        self.valid_as_demo = valid_as_demo
        self.task_desc = task_desc
        # Cache of validation samples (with chat messages) and their baseline losses.
        self.valid_feature = {}
        if valid_dataset is not None:
            self.prepare_valid_feature(Dataset.from_list(valid_dataset), task_desc, n_shot)
        else:
            logger.warning(
                f"valid_dataset and task_desc are both None when initializing {OP_NAME}. \
                'prepare_valid_feature' method should be manually called before applying the filter."
            )

    @property
    def valid_feature_ready(self):
        # True once both validation samples and their baseline losses are cached.
        return "valid_samples" in self.valid_feature and "valid_losses" in self.valid_feature

    def prepare_valid_feature(self, dataset=None, task_desc=None, n_shot=None, *args, **kwargs):
        """Cache up to `n_shot` validation samples (with messages built from
        `task_desc`) and their demonstration-free losses.

        NOTE(review): assumes `dataset` is iterable and non-None — `len(dataset)`
        raises otherwise; confirm callers always supply a dataset.
        """
        # Default to using the whole validation set as shots.
        n_shot = n_shot or len(dataset)
        self.valid_feature["valid_samples"] = []
        self.valid_feature["valid_losses"] = []
        for i, sample in enumerate(dataset):
            if i >= n_shot:
                break
            sample_w_msgs = self.sample_with_messages(sample, system_prompt=task_desc)
            self.valid_feature["valid_samples"].append(sample_w_msgs)
            loss = self._loss(sample_w_msgs)
            self.valid_feature["valid_losses"].append(loss)

    def compute_stats_single(self, sample, rank=None):
        """Compute the in-context influence score for one sample and store it
        under `StatsKeys.in_context_influence` in the sample's stats."""
        # check if it's computed already
        if StatsKeys.in_context_influence in sample[Fields.stats]:
            return sample

        assert self.valid_feature_ready, "Validation feature not ready yet. Call prepare_valid_feature first."

        sample_w_msgs = self.sample_with_messages(sample)

        scores = []
        if self.valid_as_demo:
            # L(A|Q) / L(A|Q_v, A_v, Q)
            loss_wo_demo = self._loss(sample_w_msgs, rank=rank)
            for valid_sample in self.valid_feature["valid_samples"]:
                loss_w_demo = self._loss(sample_w_msgs, pre_example=valid_sample, rank=rank)
                scores.append(loss_wo_demo / loss_w_demo)
        else:
            # L(A_v|Q_v) / L(A_v|Q, A, Q_v)
            for valid_sample, loss_wo_demo in zip(
                self.valid_feature["valid_samples"], self.valid_feature["valid_losses"]
            ):
                loss_w_demo = self._loss(valid_sample, pre_example=sample_w_msgs, rank=rank)
                scores.append(loss_wo_demo / loss_w_demo)

        # TODO: aggregation strategies
        in_context_influence = sum(scores) / len(scores)
        sample[Fields.stats][StatsKeys.in_context_influence] = in_context_influence

        return sample

    def process_single(self, sample):
        """Keep the sample iff its score lies in [min_score, max_score];
        samples with no computed score (None) are kept."""
        score = sample[Fields.stats][StatsKeys.in_context_influence]
        if score is None:
            return True

        return self.min_score <= score <= self.max_score
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import logging
2+
3+
from data_juicer.ops.base_op import OPERATORS
4+
from data_juicer.ops.filter.llm_perplexity_filter import LLMPerplexityFilter
5+
from data_juicer.utils.constant import Fields, StatsKeys
6+
from data_juicer.utils.lazy_loader import LazyLoader
7+
8+
torch = LazyLoader("torch")
9+
transformers = LazyLoader("transformers")
10+
11+
logger = logging.getLogger(__name__)
12+
logging.basicConfig(level=logging.INFO)
13+
14+
OP_NAME = "instruction_following_difficulty_filter"
15+
16+
17+
@OPERATORS.register_module(OP_NAME)
class InstructionFollowingDifficultyFilter(LLMPerplexityFilter):
    """Filter to keep texts whose instruction-following difficulty (IFD,
    https://arxiv.org/abs/2308.12032) falls within a specific range.

    IFD is the ratio L(A|Q) / L(A): the loss of the response conditioned on the
    query, divided by the loss of the response alone.
    """

    # Model-based loss computation; prefer GPU execution.
    _accelerator = "cuda"

    def compute_stats_single(self, sample, rank=None):
        """Compute the IFD score for one sample and store it under
        `StatsKeys.ifd_score` in the sample's stats."""

        # check if it's computed already
        if StatsKeys.ifd_score in sample[Fields.stats]:
            return sample

        sample_w_msgs = self.sample_with_messages(sample)
        # Keep only the final (response) message to score the answer without its query.
        msgs_wo_query = sample_w_msgs["messages"][-1:]
        sample_w_msg_wo_query = dict(**sample_w_msgs)
        sample_w_msg_wo_query.update({"messages": msgs_wo_query})

        # Pass `rank` by keyword: sibling filters call `_loss` with an optional
        # `pre_example` argument besides the sample, so a positional `rank` risks
        # binding to the wrong parameter of the inherited `_loss`.
        loss_w_query = self._loss(sample_w_msgs, rank=rank)
        loss_wo_query = self._loss(sample_w_msg_wo_query, rank=rank)
        sample[Fields.stats][StatsKeys.ifd_score] = loss_w_query / loss_wo_query

        return sample

    def process_single(self, sample):
        """Keep the sample iff its IFD score lies in [min_score, max_score]."""
        score = sample[Fields.stats][StatsKeys.ifd_score]

        return self.min_score <= score <= self.max_score

0 commit comments

Comments
 (0)