fix: correct NaN preprocessing and torch.compile dead assignment #401

Open

Angelopgit wants to merge 1 commit into google-research:master from Angelopgit:fix/torch-compile-dead-assignment

Conversation

@Angelopgit

Summary

Three silent bugs in the inference preprocessing path, all present in the current main:

1. strip_leading_nans returns the wrong value for all-NaN input (src/ and v1/)

np.argmax(~isnan) returns 0 when the input is entirely NaN (all-False mask — NumPy does not raise). So arr[0:] silently returns the full NaN array instead of the empty array the docstring promises. Downstream, linear_interpolation converts the NaN array to all-zeros via np.where(np.isfinite(...), ..., 0.0), and the model receives a plausible-looking but meaningless zero-padded context with no error.

```python
# Reproduce
import numpy as np
arr = np.full(5, np.nan)
isnan = np.isnan(arr)
print(np.argmax(~isnan))  # 0  — looks like a valid index, is not
print(arr[0:])            # [nan nan nan nan nan]  — should be []
```
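A minimal sketch of the guarded version (the function name comes from the PR; the body here is illustrative, not the exact patch):

```python
import numpy as np

def strip_leading_nans(arr: np.ndarray) -> np.ndarray:
  """Strips leading NaNs; returns an empty array when every value is NaN."""
  isnan = np.isnan(arr)
  if isnan.all():
    # Early-exit guard: np.argmax(~isnan) would return 0 here,
    # making arr[0:] hand back the full NaN array.
    return arr[:0]
  return arr[np.argmax(~isnan):]
```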

2. linear_interpolation uses ambiguous bool coercion on a NumPy array (src/)

The except ValueError fallback does if non_nans_values: where non_nans_values is a NumPy array. Truth-value testing a multi-element array raises ValueError outright, and the empty-array case, long deprecated (False plus a DeprecationWarning), also raises ValueError as of NumPy 2. The v1 version of this function already uses len(...) > 0 — the 2.5 version regressed. Changed to .size > 0 to match.
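For illustration, the ambiguity is easy to trigger (a standalone sketch, not code from the patch):

```python
import numpy as np

non_nans_values = np.array([1.0, 2.0])

# Truth-value testing a multi-element array raises ValueError:
try:
  bool(non_nans_values)
  ambiguous = False
except ValueError:
  ambiguous = True
print(ambiguous)  # True

# Unambiguous checks, either of which avoids the problem:
print(non_nans_values.size > 0)  # True — the spelling this PR adopts
print(len(non_nans_values) > 0)  # True — the v1 spelling
```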

3. torch.compile dead assignment in load_checkpoint (src/, PyTorch backend)

```python
# In TimesFM_2p5_200M_torch_module.load_checkpoint:
self = torch.compile(self)   # only rebinds the local variable
```

Rebinding self inside a method never affects the caller's reference. torch.compile returns a new OptimizedModule wrapper; assigning it to the local self discards it immediately. Since torch_compile=True is the default, every PyTorch user was getting unoptimized inference while believing the model was compiled.

Moved the torch.compile call to _from_pretrained where instance.model = torch.compile(instance.model) actually takes effect.
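The rebinding pitfall is plain Python, independent of torch. A minimal analogue (class names here are illustrative; `Wrapper` stands in for the `OptimizedModule` that `torch.compile` returns):

```python
class Wrapper:
  """Stands in for the OptimizedModule wrapper torch.compile returns."""

class Model:
  def load_checkpoint(self):
    # Rebinds only the local name `self`; the caller's reference to
    # the instance is untouched, so the wrapper is discarded here.
    self = Wrapper()

model = Model()
model.load_checkpoint()
print(type(model).__name__)  # Model — still the original object
```

The fix in the PR works because `instance.model = torch.compile(instance.model)` assigns through an attribute on an object the caller actually holds, rather than rebinding a method-local name.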

Changes

| File | Change |
| --- | --- |
| `src/timesfm/timesfm_2p5/timesfm_2p5_base.py` | Guard in `strip_leading_nans`; fix bool coercion in `linear_interpolation` |
| `v1/src/timesfm/timesfm_base.py` | Guard in `strip_leading_nans` |
| `src/timesfm/timesfm_2p5/timesfm_2p5_torch.py` | Move `torch.compile` to `_from_pretrained`; simplify `load_checkpoint` |

Three bugs in the inference path, all silently producing wrong results:

1. strip_leading_nans (src/ and v1/): np.argmax(~isnan) returns 0 when
   the input is all-NaN, so arr[0:] returns the full NaN array instead of
   the empty array promised by the docstring. Add an early-exit guard.

2. linear_interpolation (src/): the except-ValueError fallback uses
   `if non_nans_values:` where non_nans_values is a NumPy array.
   Truth-value testing a multi-element array raises ValueError, and the
   long-deprecated empty-array case raises in NumPy 2. Use .size > 0.

3. load_checkpoint / _from_pretrained (torch): `self = torch.compile(self)`
   inside a method only rebinds the local variable — the caller's
   reference to instance.model is never replaced. torch_compile=True is
   the default, so every PyTorch user was getting uncompiled inference
   while believing the model was compiled. Move the torch.compile call
   to _from_pretrained where the assignment can actually take effect.
@google-cla

google-cla Bot commented Apr 9, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

