Skip to content

Commit d5fb617

Browse files
authored
Merge pull request #513 from bingyang-lei/feat/wandb-offline-dir
feat: add wandb offline mode and custom wandb directory support
2 parents e0fbb86 + 284db42 commit d5fb617

2 files changed

Lines changed: 39 additions & 4 deletions

File tree

specforge/args.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ class TrackerArgs:
1111
wandb_project: str = None
1212
wandb_name: str = None
1313
wandb_key: str = None
14+
wandb_offline: bool = False
15+
wandb_dir: str = None
1416
swanlab_project: str = None
1517
swanlab_name: str = None
1618
swanlab_key: str = None
@@ -33,6 +35,17 @@ def add_args(parser: argparse.ArgumentParser) -> None:
3335
parser.add_argument("--wandb-project", type=str, default=None)
3436
parser.add_argument("--wandb-name", type=str, default=None)
3537
parser.add_argument("--wandb-key", type=str, default=None, help="W&B API key.")
38+
parser.add_argument(
39+
"--wandb-offline",
40+
action="store_true",
41+
help="Enable W&B offline mode and store logs locally.",
42+
)
43+
parser.add_argument(
44+
"--wandb-dir",
45+
type=str,
46+
default=None,
47+
help="Directory to store W&B files. Defaults to './wandb' under the project root when using W&B.",
48+
)
3649
# swanlab-specific args
3750
parser.add_argument(
3851
"--swanlab-project",

specforge/tracker.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,24 @@ def close(self):
9191
class WandbTracker(Tracker):
9292
"""Tracks experiments using Weights & Biases."""
9393

94+
@staticmethod
95+
def _default_wandb_dir() -> str:
96+
# specforge/tracker.py -> project root is one level up
97+
return os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "wandb"))
98+
9499
@classmethod
95100
def validate_args(cls, parser, args):
96101
if wandb is None:
97102
parser.error(
98103
"To use --report-to wandb, you must install wandb: 'pip install wandb'"
99104
)
100105

106+
if args.wandb_dir is None:
107+
args.wandb_dir = cls._default_wandb_dir()
108+
109+
if args.wandb_offline:
110+
return
111+
101112
if args.wandb_key is not None:
102113
return
103114

@@ -128,10 +139,21 @@ def validate_args(cls, parser, args):
128139
def __init__(self, args, output_dir: str):
129140
super().__init__(args, output_dir)
130141
if self.rank == 0:
131-
wandb.login(key=args.wandb_key)
132-
wandb.init(
133-
project=args.wandb_project, name=args.wandb_name, config=vars(args)
134-
)
142+
if args.wandb_dir is None:
143+
args.wandb_dir = self._default_wandb_dir()
144+
os.makedirs(args.wandb_dir, exist_ok=True)
145+
146+
if not args.wandb_offline:
147+
wandb.login(key=args.wandb_key)
148+
init_kwargs = {
149+
"project": args.wandb_project,
150+
"name": args.wandb_name,
151+
"config": vars(args),
152+
"dir": args.wandb_dir,
153+
}
154+
if args.wandb_offline:
155+
init_kwargs["mode"] = "offline"
156+
wandb.init(**init_kwargs)
135157
self.is_initialized = True
136158

137159
def log(self, log_dict: Dict[str, Any], step: Optional[int] = None):

0 commit comments

Comments
 (0)