@@ -91,13 +91,24 @@ def close(self):
9191class WandbTracker (Tracker ):
9292 """Tracks experiments using Weights & Biases."""
9393
94+ @staticmethod
95+ def _default_wandb_dir () -> str :
96+ # specforge/tracker.py -> project root is one level up
97+ return os .path .normpath (os .path .join (os .path .dirname (__file__ ), ".." , "wandb" ))
98+
9499 @classmethod
95100 def validate_args (cls , parser , args ):
96101 if wandb is None :
97102 parser .error (
98103 "To use --report-to wandb, you must install wandb: 'pip install wandb'"
99104 )
100105
106+ if args .wandb_dir is None :
107+ args .wandb_dir = cls ._default_wandb_dir ()
108+
109+ if args .wandb_offline :
110+ return
111+
101112 if args .wandb_key is not None :
102113 return
103114
@@ -128,10 +139,21 @@ def validate_args(cls, parser, args):
128139 def __init__ (self , args , output_dir : str ):
129140 super ().__init__ (args , output_dir )
130141 if self .rank == 0 :
131- wandb .login (key = args .wandb_key )
132- wandb .init (
133- project = args .wandb_project , name = args .wandb_name , config = vars (args )
134- )
142+ if args .wandb_dir is None :
143+ args .wandb_dir = self ._default_wandb_dir ()
144+ os .makedirs (args .wandb_dir , exist_ok = True )
145+
146+ if not args .wandb_offline :
147+ wandb .login (key = args .wandb_key )
148+ init_kwargs = {
149+ "project" : args .wandb_project ,
150+ "name" : args .wandb_name ,
151+ "config" : vars (args ),
152+ "dir" : args .wandb_dir ,
153+ }
154+ if args .wandb_offline :
155+ init_kwargs ["mode" ] = "offline"
156+ wandb .init (** init_kwargs )
135157 self .is_initialized = True
136158
137159 def log (self , log_dict : Dict [str , Any ], step : Optional [int ] = None ):
0 commit comments