Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 59 additions & 30 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ class GenerationConfig(PushToHubMixin):
_original_object_hash: int | None

def __init__(self, **kwargs):
# Snapshot of the attributes the caller explicitly provided (before the `kwargs.pop(...)` calls below
# consume them). Used by `validate()` to restrict "minor issue" warnings to flags actually set by the user,
# as opposed to defaults inherited from a model's `generation_config.json`.
user_set_attributes = {k for k in kwargs if not k.startswith("_")}

# Parameters that control the length of the output
self.max_length = kwargs.pop("max_length", None)
self.max_new_tokens = kwargs.pop("max_new_tokens", None)
Expand Down Expand Up @@ -466,7 +471,7 @@ def __init__(self, **kwargs):
)

# Validate the values of the attributes
self.validate()
self.validate(user_set_attributes=user_set_attributes)

def __hash__(self):
    """Hash the config via its serialized JSON (metadata excluded), so configs with equal effective contents hash equally."""
    serialized = self.to_json_string(ignore_metadata=True)
    return hash(serialized)
Expand Down Expand Up @@ -587,7 +592,7 @@ def _get_default_generation_params() -> dict[str, Any]:
"diversity_penalty": 0.0,
}

def validate(self, strict=False):
def validate(self, strict=False, user_set_attributes: set[str] | None = None):
"""
Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
of parameterization that can be detected as incorrect from the configuration instance alone.
Expand All @@ -597,9 +602,17 @@ def validate(self, strict=False):

Args:
strict (bool): If True, raise an exception for any issues found. If False, only log issues.
user_set_attributes (set[str], *optional*): Names of attributes the caller explicitly provided. When
supplied, "minor issue" warnings about conflicting flag combinations (e.g. sampling-only flags set
while `do_sample=False`) only fire if the conflicting flag is in this set -- avoiding noisy warnings
when the value was inherited from a model's default `generation_config.json`. When `None`, all set
attributes are considered user-set (backward-compatible behavior for direct `validate()` calls).
"""
minor_issues = {} # format: {attribute_name: issue_description}

def _is_user_set(attr: str) -> bool:
return user_set_attributes is None or attr in user_set_attributes

# 1. Validation of individual attributes
# 1.1. Decoding attributes
if self.early_stopping not in {None, True, False, "never"}:
Expand Down Expand Up @@ -633,50 +646,64 @@ def validate(self, strict=False):

# 2. Validation of attribute combinations
# 2.1. detect sampling-only parameterization when not in sampling mode

#
# Note that we check `is not True` in purpose. Boolean fields can also be `None` so we
# have to be explicit. Value of `None` is same as having `False`, i.e. the default value
if self.do_sample is not True:
#
# The warning is suppressed for flags that weren't explicitly set by the caller (see `_is_user_set`): values
# inherited from a model's `generation_config.json` are harmless when the user opts into greedy decoding.
# We also require `do_sample` itself to be user-set -- otherwise the non-sampling mode was inherited and the
# user never expressed intent to skip sampling, so flagging their sampling kwargs would be misleading.
if self.do_sample is not True and _is_user_set("do_sample"):
greedy_wrong_parameter_msg = (
"`do_sample` is set not to set `True`. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
"used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
"`do_sample` is set to `{do_sample}`. However, `{flag_name}` is set to `{flag_value}` -- this flag is "
"only used in sample-based generation modes. You should set `do_sample=True` or unset `{flag_name}`."
Comment on lines -639 to +660
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great opportunity to move to hub-dataclass validation 😄

)
if self.temperature is not None and self.temperature != 1.0:
if self.temperature is not None and self.temperature != 1.0 and _is_user_set("temperature"):
minor_issues["temperature"] = greedy_wrong_parameter_msg.format(
flag_name="temperature", flag_value=self.temperature
do_sample=self.do_sample, flag_name="temperature", flag_value=self.temperature
)
if self.top_p is not None and self.top_p != 1.0 and _is_user_set("top_p"):
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(
do_sample=self.do_sample, flag_name="top_p", flag_value=self.top_p
)
if self.top_p is not None and self.top_p != 1.0:
minor_issues["top_p"] = greedy_wrong_parameter_msg.format(flag_name="top_p", flag_value=self.top_p)
if self.min_p is not None:
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(flag_name="min_p", flag_value=self.min_p)
if self.top_h is not None:
minor_issues["top_h"] = greedy_wrong_parameter_msg.format(flag_name="top_h", flag_value=self.top_h)
if self.typical_p is not None and self.typical_p != 1.0:
if self.min_p is not None and _is_user_set("min_p"):
minor_issues["min_p"] = greedy_wrong_parameter_msg.format(
do_sample=self.do_sample, flag_name="min_p", flag_value=self.min_p
)
if self.top_h is not None and _is_user_set("top_h"):
minor_issues["top_h"] = greedy_wrong_parameter_msg.format(
do_sample=self.do_sample, flag_name="top_h", flag_value=self.top_h
)
if self.typical_p is not None and self.typical_p != 1.0 and _is_user_set("typical_p"):
minor_issues["typical_p"] = greedy_wrong_parameter_msg.format(
flag_name="typical_p", flag_value=self.typical_p
do_sample=self.do_sample, flag_name="typical_p", flag_value=self.typical_p
)
if self.top_k is not None and self.top_k != 50 and _is_user_set("top_k"):
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(
do_sample=self.do_sample, flag_name="top_k", flag_value=self.top_k
)
if self.top_k is not None and self.top_k != 50:
minor_issues["top_k"] = greedy_wrong_parameter_msg.format(flag_name="top_k", flag_value=self.top_k)
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0:
if self.epsilon_cutoff is not None and self.epsilon_cutoff != 0.0 and _is_user_set("epsilon_cutoff"):
minor_issues["epsilon_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
do_sample=self.do_sample, flag_name="epsilon_cutoff", flag_value=self.epsilon_cutoff
)
if self.eta_cutoff is not None and self.eta_cutoff != 0.0:
if self.eta_cutoff is not None and self.eta_cutoff != 0.0 and _is_user_set("eta_cutoff"):
minor_issues["eta_cutoff"] = greedy_wrong_parameter_msg.format(
flag_name="eta_cutoff", flag_value=self.eta_cutoff
do_sample=self.do_sample, flag_name="eta_cutoff", flag_value=self.eta_cutoff
)

# 2.2. detect beam-only parameterization when not in beam mode
if self.num_beams is None or self.num_beams == 1:
# 2.2. detect beam-only parameterization when not in beam mode. Same provenance filtering as above --
# both `num_beams` and the beam-only flag must be user-set for the warning to fire.
if (self.num_beams is None or self.num_beams == 1) and _is_user_set("num_beams"):
single_beam_wrong_parameter_msg = (
"`num_beams` is set to {num_beams}. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
"in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
"`num_beams` is set to {num_beams}. However, `{flag_name}` is set to `{flag_value}` -- this flag is "
"only used in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
)
if self.early_stopping is not None and self.early_stopping is not False:
if self.early_stopping is not None and self.early_stopping is not False and _is_user_set("early_stopping"):
minor_issues["early_stopping"] = single_beam_wrong_parameter_msg.format(
num_beams=self.num_beams, flag_name="early_stopping", flag_value=self.early_stopping
)
if self.length_penalty is not None and self.length_penalty != 1.0:
if self.length_penalty is not None and self.length_penalty != 1.0 and _is_user_set("length_penalty"):
minor_issues["length_penalty"] = single_beam_wrong_parameter_msg.format(
num_beams=self.num_beams, flag_name="length_penalty", flag_value=self.length_penalty
)
Expand Down Expand Up @@ -1232,8 +1259,10 @@ def update(self, defaults_only=False, allow_custom_entries=False, **kwargs):
setattr(self, key, value)
to_remove.append(key)

# Confirm that the updated instance is still valid
self.validate()
# Confirm that the updated instance is still valid. Only attributes *explicitly* updated in this call count
# as user-set for warning purposes: defaults inherited from a model's config shouldn't emit warnings.
user_set_attributes = set() if defaults_only else set(to_remove)
self.validate(user_set_attributes=user_set_attributes)

# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
Expand Down
96 changes: 86 additions & 10 deletions tests/generation/test_configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,32 +157,49 @@ def test_validate(self):
GenerationConfig()
self.assertEqual(len(captured_logs.out), 0)

# Inconsequent but technically wrong configuration will throw a warning (e.g. setting sampling
# parameters with `do_sample=False`). May be escalated to an error in the future.
# Inconsequent but technically wrong configuration will throw a warning (e.g. requesting an extra output
# without `return_dict_in_generate=True`). May be escalated to an error in the future.
logger.warning_once.cache_clear()
with CaptureLogger(logger) as captured_logs:
GenerationConfig(return_dict_in_generate=False, output_scores=True)
self.assertNotEqual(len(captured_logs.out), 0)

# Explicitly setting a sampling flag alongside `do_sample=False` still warns: this is a user-level mistake.
logger.warning_once.cache_clear()
with CaptureLogger(logger) as captured_logs:
generation_config_bad_temperature = GenerationConfig(do_sample=False, temperature=0.5) # store for later
self.assertNotEqual(len(captured_logs.out), 0)

# Expanding on the case above, we can update a bad configuration to get rid of the warning. Ideally,
# that is done by unsetting the parameter (i.e. setting it to None)
# But a value inherited from a model's default config (i.e. not in this update's kwargs) does NOT warn: in
# the real world, `generate(do_sample=False)` on a model whose `generation_config.json` has `temperature=0.6`
# would otherwise log a useless warning.
logger.warning_once.cache_clear()
base_config = GenerationConfig(do_sample=True, temperature=0.6) # mimics a model's default config
with CaptureLogger(logger) as captured_logs:
# BAD - 0.9 means it is still set, we should warn
generation_config_bad_temperature.update(temperature=0.9)
self.assertNotEqual(len(captured_logs.out), 0)
base_config.update(do_sample=False)
self.assertEqual(len(captured_logs.out), 0)

# Inverse provenance case: `do_sample` inherited from a model's config (so not user-set this call), user only
# sets a sampling flag. The conflict shouldn't produce noise because the user never asked for greedy.
logger.warning_once.cache_clear()
greedy_hub_config = GenerationConfig(do_sample=False) # mimics a model's default config forcing greedy
with CaptureLogger(logger) as captured_logs:
greedy_hub_config.update(top_p=0.8)
self.assertEqual(len(captured_logs.out), 0)
Comment on lines +182 to +188
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this one. Beginner users might expect this to just work and sample with top-p, while we silently fall back to greedy.


# Updating only `temperature` (do_sample was pre-existing, i.e. "from the hub") does NOT warn: we only flag
# the conflict when both sides of the pair were set by the caller in the same context.
logger.warning_once.cache_clear()
with CaptureLogger(logger) as captured_logs:
# CORNER CASE - 1.0 is the default, we can't detect whether it is set by the user or not, we shouldn't warn
generation_config_bad_temperature.update(temperature=1.0)
generation_config_bad_temperature.update(temperature=0.9)
self.assertEqual(len(captured_logs.out), 0)

# But setting both in the same `update()` call DOES warn.
logger.warning_once.cache_clear()
with CaptureLogger(logger) as captured_logs:
generation_config_bad_temperature.update(do_sample=False, temperature=0.9)
self.assertNotEqual(len(captured_logs.out), 0)

logger.warning_once.cache_clear()
with CaptureLogger(logger) as captured_logs:
# OK - None means it is unset, nothing to warn about
Expand Down Expand Up @@ -223,13 +240,72 @@ def test_validate(self):
self.assertTrue(len(captured_logs.out) > 400) # long log
self.assertNotIn("Set `TRANSFORMERS_VERBOSITY=info` for more details", captured_logs.out)

# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning.
# Finally, we can set `strict=True` to raise an exception on what would otherwise be a warning. Note:
# `validate()` called directly (no `user_set_attributes`) treats every set value as user-set, preserving the
# "refuse to save a bad config" behavior below.
generation_config = GenerationConfig()
generation_config.temperature = 0.5
generation_config.do_sample = False
with self.assertRaises(ValueError):
generation_config.validate(strict=True)

def test_validate_sampling_flag_provenance(self):
    """
    Dedicated coverage for the provenance-aware warning rule on sampling-only flags:
    we only warn when BOTH `do_sample=False` AND a conflicting sampling flag (e.g. `top_p`, `temperature`)
    were explicitly provided by the caller in the same context -- not when one of them was inherited from a
    model's `generation_config.json`.
    """
    logger = transformers_logging.get_logger("transformers.generation.configuration_utils")

    def simulate(init_kwargs, update_kwargs=None):
        """Build a config (standing in for a hub default), optionally `update()` it, and return the warning-log size."""
        # Reset the once-per-message cache so each scenario's warnings are observable independently.
        logger.warning_once.cache_clear()
        with CaptureLogger(logger) as captured:
            config = GenerationConfig(**init_kwargs)
            if update_kwargs:
                config.update(**update_kwargs)
        return len(captured.out)

    # 1. Hub config sets `temperature`, user does only `generate(do_sample=False)` -> NO warning.
    # (Emulates: model whose `generation_config.json` carries `do_sample=True, temperature=0.6`, user
    # explicitly asks for greedy decoding.)
    self.assertEqual(simulate({"do_sample": True, "temperature": 0.6}, {"do_sample": False}), 0)

    # 2. User explicitly sets BOTH `do_sample=False` and `top_p=0.8` in the same call -> WARN.
    self.assertGreater(simulate({"do_sample": False, "top_p": 0.8}), 0)

    # 3. User explicitly sets only `do_sample=False` (no sampling flag) -> NO warning, even though
    # attribute defaults (like `top_k=50`) may be present.
    self.assertEqual(simulate({"do_sample": False}), 0)

    # 4. Hub config forces greedy (`do_sample=False`), user sets only `top_p=0.8` -> NO warning:
    # `do_sample` was inherited, not user-expressed intent, so flagging their `top_p` would be misleading.
    self.assertEqual(simulate({"do_sample": False}, {"top_p": 0.8}), 0)

    # 5. User sets `do_sample=False` and `temperature=0.5` via a single `update()` call -> WARN.
    self.assertGreater(simulate({}, {"do_sample": False, "temperature": 0.5}), 0)

    # 6. Same idea for beam flags: user only asks for `num_beams=1`, hub default has `length_penalty=0.8`
    # -> NO warning.
    self.assertEqual(simulate({"num_beams": 4, "length_penalty": 0.8}, {"num_beams": 1}), 0)

    # 7. User sets BOTH `num_beams=1` and `length_penalty=0.8` explicitly -> WARN.
    self.assertGreater(simulate({"num_beams": 1, "length_penalty": 0.8}), 0)

def test_refuse_to_save(self):
"""Tests that we refuse to save a generation config that fails validation."""

Expand Down
Loading