|
33 | 33 | synchronize |
34 | 34 | ) |
35 | 35 |
|
| 36 | +class EarlyStopping: |
| 37 | + def __init__(self, patience: int, min_delta: float, mode="max"): |
| 38 | + self.patience = patience |
| 39 | + self.min_delta = min_delta |
| 40 | + self.mode = mode # "max", "min", "percentage" |
| 41 | + self.best = None |
| 42 | + self.counter = 0 |
| 43 | + |
| 44 | + def step(self, value): |
| 45 | + # Initialize best value on first call |
| 46 | + if self.best is None: |
| 47 | + self.best = value |
| 48 | + return False |
| 49 | + |
| 50 | + # Compute improvement depending on mode |
| 51 | + if self.mode == "max": |
| 52 | + improvement = value - self.best |
| 53 | + elif self.mode == "min": |
| 54 | + improvement = self.best - value |
| 55 | + elif self.mode == "percentage": |
| 56 | + if self.best == 0: |
| 57 | + improvement = 0 # avoid division by zero |
| 58 | + else: |
| 59 | + improvement = (value - self.best) / abs(self.best) |
| 60 | + |
| 61 | + # Check if improvement is sufficient |
| 62 | + if improvement > self.min_delta: |
| 63 | + self.best = value |
| 64 | + self.counter = 0 |
| 65 | + else: |
| 66 | + self.counter += 1 |
| 67 | + |
| 68 | + return self.counter >= self.patience |
| 69 | + |
36 | 70 |
|
37 | 71 | class Trainer: |
38 | 72 | def __init__(self, exp: Exp, args): |
@@ -234,7 +268,15 @@ def after_epoch(self): |
234 | 268 |
|
235 | 269 | if (self.epoch + 1) % self.exp.eval_interval == 0: |
236 | 270 | all_reduce_norm(self.model) |
237 | | - self.evaluate_and_save_model() |
| 271 | + ap50_95 = self.evaluate_and_save_model() |
| 272 | + |
| 273 | + # Early stopping |
| 274 | + if self.early_stopper is not None: |
| 275 | + if self.early_stopper.step(ap50_95): |
| 276 | + logger.info(f"Early stopping triggered at epoch {self.epoch}. " f"Best AP: {self.early_stopper.best}") |
| 277 | + # save best checkpoint before exiting |
| 278 | + self.save_ckpt("best_ckpt") |
| 279 | + raise SystemExit |
238 | 280 |
|
239 | 281 | def before_iter(self): |
240 | 282 | pass |
@@ -395,6 +437,7 @@ def evaluate_and_save_model(self): |
395 | 437 | } |
396 | 438 | self.mlflow_logger.save_checkpoints(self.args, self.exp, self.file_name, self.epoch, |
397 | 439 | metadata, update_best_ckpt) |
| 440 | + return ap50_95 |
398 | 441 |
|
399 | 442 | def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None): |
400 | 443 | if self.rank == 0: |
|
0 commit comments