Drop noisy generate warnings when do_sample=False (or num_beams=1) #45559
Open
ArthurZucker wants to merge 5 commits into huggingface:main from ArthurZucker:drop-sampling-flag-warnings
+145
−40
Commits (changes shown from all 5 commits):
- 18c5d9d Drop noisy generate warnings when do_sample=False (or num_beams=1) (ArthurZucker)
- 5b504f6 Make sampling/beam flag warnings provenance-aware (ArthurZucker)
- 7bf56c2 Require both sides of the conflict to be user-set before warning (ArthurZucker)
- c918a92 Add dedicated test for provenance-aware sampling/beam flag warnings (ArthurZucker)
- 7bb735c Merge branch 'main' into drop-sampling-flag-warnings (ArthurZucker)
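The rule the PR title and commits describe can be sketched as standalone Python. This is a hypothetical illustration, not the transformers implementation (the helper name `check_sampling_conflict` and the dict-based config are assumptions): a conflict is flagged only when `do_sample=False` and the clashing sampling flag were both explicitly passed by the caller.

```python
# Hypothetical sketch of the provenance-aware warning rule, NOT the real
# transformers code: warn only when BOTH `do_sample=False` and the sampling
# flag were explicitly set by the caller in the same call.
SAMPLING_FLAGS = {"temperature", "top_p", "top_k"}

def check_sampling_conflict(config: dict, user_set: set) -> list:
    """Return warnings for sampling flags that conflict with do_sample=False."""
    messages = []
    if config.get("do_sample") is False and "do_sample" in user_set:
        for flag in sorted(SAMPLING_FLAGS & user_set):
            if config.get(flag) is not None:
                messages.append(f"`{flag}` has no effect when `do_sample=False`.")
    return messages

# Both sides user-set in the same call: a warning is produced.
assert check_sampling_conflict(
    {"do_sample": False, "temperature": 0.5}, user_set={"do_sample", "temperature"}
) == ["`temperature` has no effect when `do_sample=False`."]

# `temperature` inherited from a hub config, only `do_sample` user-set: silent.
assert check_sampling_conflict(
    {"do_sample": False, "temperature": 0.6}, user_set={"do_sample"}
) == []
```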
@@ -157,32 +157,49 @@ def test_validate(self):
             GenerationConfig()
         self.assertEqual(len(captured_logs.out), 0)

-        # Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
-        # parameters with `do_sample=False`). May be escalated to an error in the future.
+        # Inconsequent but technically wrong configuration will throw a warning (e.g. requesting an extra output
+        # without `return_dict_in_generate=True`). May be escalated to an error in the future.
         logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             GenerationConfig(return_dict_in_generate=False, output_scores=True)
         self.assertNotEqual(len(captured_logs.out), 0)

+        # Explicitly setting a sampling flag alongside `do_sample=False` still warns: this is a user-level mistake.
+        logger.warning_once.cache_clear()
+        with CaptureLogger(logger) as captured_logs:
+            generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5)  # store for later
+        self.assertNotEqual(len(captured_logs.out), 0)

-        # Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
-        # that is done by unsetting the parameter (i.e. setting it to None)
+        # But a value inherited from a model's default config (i.e. not in this update's kwargs) does NOT warn: in
+        # the real world, `generate(do_sample=False)` on a model whose `generation_config.json` has `temperature=0.6`
+        # would otherwise log a useless warning.
         logger.warning_once.cache_clear()
+        base_config = GenerationConfig(do_sample=True, temperature=0.6)  # mimics a model's default config
         with CaptureLogger(logger) as captured_logs:
-            # BAD - 0.9 means it is still set, we should warn
-            generation_config_bad_temperature.update(temperature=0.9)
-        self.assertNotEqual(len(captured_logs.out), 0)
+            base_config.update(do_sample=False)
+        self.assertEqual(len(captured_logs.out), 0)

+        # Inverse provenance case: `do_sample` inherited from a model's config (so not user-set this call), user only
+        # sets a sampling flag. The conflict shouldn't produce noise because the user never asked for greedy.
+        logger.warning_once.cache_clear()
+        greedy_hub_config = GenerationConfig(do_sample=False)  # mimics a model's default config forcing greedy
+        with CaptureLogger(logger) as captured_logs:
+            greedy_hub_config.update(top_p=0.8)
+        self.assertEqual(len(captured_logs.out), 0)
Review comment on lines +182 to +188 (Member):
> i am not sure about this one. Beginner users might expect this to just work and sample with top-p, while we silently fall back to greedy.
+        # Updating only `temperature` (do_sample was pre-existing, i.e. "from the hub") does NOT warn: we only flag
+        # the conflict when both sides of the pair were set by the caller in the same context.
         logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
-            # CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
-            generation_config_bad_temperature.update(temperature=1.0)
+            generation_config_bad_temperature.update(temperature=0.9)
         self.assertEqual(len(captured_logs.out), 0)

+        # But setting both in the same `update()` call DOES warn.
+        logger.warning_once.cache_clear()
+        with CaptureLogger(logger) as captured_logs:
+            generation_config_bad_temperature.update(do_sample=False, temperature=0.9)
+        self.assertNotEqual(len(captured_logs.out), 0)

         logger.warning_once.cache_clear()
         with CaptureLogger(logger) as captured_logs:
             # OK - None means it is unset, nothing to warn about
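The diff above hinges on `update()` knowing which keys came from the caller versus the hub defaults. As a standalone illustration of that bookkeeping (the `TinyGenerationConfig` class below is hypothetical, not the transformers implementation), each `update()` call could record its own kwargs as the user-set provenance:

```python
# Hypothetical sketch of provenance tracking in a config object, NOT the
# real transformers GenerationConfig: only keys passed to the current call
# count as "user-set" when checking for conflicts.
class TinyGenerationConfig:
    def __init__(self, **kwargs):
        self.values = {"do_sample": True, "temperature": 1.0, "top_p": 1.0}
        self.user_set = set()
        self.update(**kwargs)

    def update(self, **kwargs):
        # Provenance: only the keys passed in THIS call are user-set.
        self.user_set = set(kwargs)
        self.values.update(kwargs)
        return self.validate()

    def validate(self):
        # Warn only when both sides of the conflict were set by the caller.
        conflicts = []
        if (
            self.values["do_sample"] is False
            and "do_sample" in self.user_set
            and "temperature" in self.user_set
        ):
            conflicts.append("`temperature` has no effect with `do_sample=False`")
        return conflicts

hub_config = TinyGenerationConfig(do_sample=True, temperature=0.6)  # hub default
assert hub_config.update(do_sample=False) == []  # temperature inherited: silent
assert hub_config.update(do_sample=False, temperature=0.9) != []  # both user-set: warn
```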
@@ -223,13 +240,72 @@ def test_validate(self):
         self.assertTrue(len(captured_logs.out) > 400)  # long log
         self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)

-        # Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
+        # Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning. Note:
+        # `validate()` called directly (no `user_set_attributes`) treats every set value as user-set, preserving the
+        # "refuse to save a bad config" behavior below.
         generation_config = GenerationConfig()
         generation_config.temperature = 0.5
         generation_config.do_sample = False
         with self.assertRaises(ValueError):
             generation_config.validate(strict=True)
+    def test_validate_sampling_flag_provenance(self):
+        """
+        Dedicated coverage for the provenance-aware warning rule on sampling-only flags:
+        we only warn when BOTH `do_sample=False` AND a conflicting sampling flag (e.g. `top_p`, `temperature`)
+        were explicitly provided by the caller in the same context -- not when one of them was inherited from a
+        model's `generation_config.json`.
+        """
+        logger = transformers_logging.get_logger("transformers.generation.configuration_utils")
+
+        def _warn_count(fn):
+            logger.warning_once.cache_clear()
+            with CaptureLogger(logger) as captured:
+                fn()
+            return len(captured.out)
+
+        # 1. Hub config sets `temperature`, user does only `generate(do_sample=False)` -> NO warning.
+        #    (Emulates: model whose `generation_config.json` carries `do_sample=True, temperature=0.6`, user
+        #    explicitly asks for greedy decoding.)
+        def case_hub_temp_user_do_sample_only():
+            cfg = GenerationConfig(do_sample=True, temperature=0.6)  # stands in for the hub default
+            cfg.update(do_sample=False)
+
+        self.assertEqual(_warn_count(case_hub_temp_user_do_sample_only), 0)
+
+        # 2. User explicitly sets BOTH `do_sample=False` and `top_p=0.8` in the same call -> WARN.
+        self.assertNotEqual(_warn_count(lambda: GenerationConfig(do_sample=False, top_p=0.8)), 0)
+
+        # 3. User explicitly sets only `do_sample=False` (no sampling flag) -> NO warning, even though
+        #    attribute defaults (like `top_k=50`) may be present.
+        self.assertEqual(_warn_count(lambda: GenerationConfig(do_sample=False)), 0)
+
+        # 4. Hub config forces greedy (`do_sample=False`), user sets only `top_p=0.8` -> NO warning:
+        #    `do_sample` was inherited, not user-expressed intent, so flagging their `top_p` would be misleading.
+        def case_hub_greedy_user_top_p():
+            cfg = GenerationConfig(do_sample=False)  # stands in for the hub default
+            cfg.update(top_p=0.8)
+
+        self.assertEqual(_warn_count(case_hub_greedy_user_top_p), 0)
+
+        # 5. User sets `do_sample=False` and `temperature=0.5` via a single `update()` call -> WARN.
+        def case_update_both_sides():
+            cfg = GenerationConfig()
+            cfg.update(do_sample=False, temperature=0.5)
+
+        self.assertNotEqual(_warn_count(case_update_both_sides), 0)
+
+        # 6. Same idea for beam flags: user only asks for `num_beams=1`, hub default has `length_penalty=0.8`
+        #    -> NO warning.
+        def case_hub_length_penalty_user_num_beams_only():
+            cfg = GenerationConfig(num_beams=4, length_penalty=0.8)  # stands in for the hub default
+            cfg.update(num_beams=1)
+
+        self.assertEqual(_warn_count(case_hub_length_penalty_user_num_beams_only), 0)
+
+        # 7. User sets BOTH `num_beams=1` and `length_penalty=0.8` explicitly -> WARN.
+        self.assertNotEqual(_warn_count(lambda: GenerationConfig(num_beams=1, length_penalty=0.8)), 0)
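The test's `_warn_count` helper clears `logger.warning_once.cache_clear()` before every case because "warn once" helpers memoize on the message: without clearing, a warning emitted by an earlier case would be silently deduplicated in a later one and the count assertions would pass or fail for the wrong reason. A minimal standalone sketch of that dedup behavior (assuming an `lru_cache`-based implementation, as in transformers' logging utilities):

```python
from functools import lru_cache

# Sketch of "warn once" semantics: the cache key is the message, so a
# repeated message is swallowed until the cache is cleared.
emitted = []

@lru_cache(None)
def warning_once(message: str) -> None:
    emitted.append(message)

warning_once("conflict")    # emitted
warning_once("conflict")    # deduplicated: cached call, nothing appended
assert len(emitted) == 1

warning_once.cache_clear()  # what the tests do between cases
warning_once("conflict")    # emitted again
assert len(emitted) == 2
```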
     def test_refuse_to_save(self):
         """Tests that we refuse to save a generation config that fails validation."""
Review comment:
> great opportunity to move to hub-dataclass validation 😄