Skip to content

Commit 34928a2

Browse files
committed
added early stopping with max, min and percentage mode
1 parent 6ddff48 commit 34928a2

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

tools/train.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ def make_parser():
8888
Implemented loggers include `tensorboard`, `mlflow` and `wandb`.",
8989
default="tensorboard"
9090
)
91+
parser.add_argument(
92+
"--early-stopping",
93+
dest="early_stopping",
94+
default=False,
95+
action="store_true",
96+
help="Use early stopping to prevent overfitting.",
97+
)
9198
parser.add_argument(
9299
"opts",
93100
help="Modify config options using the command-line",
@@ -115,6 +122,13 @@ def main(exp: Exp, args):
115122
cudnn.benchmark = True
116123

117124
trainer = exp.get_trainer(args)
125+
126+
# configure early stopping parameters
127+
if args.early_stopping:
128+
# requires 1% relative improvement over 10 epochs to reset patience
129+
# available modes: "max", "min", "percentage"
130+
trainer.early_stopper = exp.get_early_stopping(patience=10, min_delta=0.01, mode="percentage")
131+
118132
trainer.train()
119133

120134

yolox/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44

55
from .launch import launch
66
from .trainer import Trainer
7+
from .trainer import EarlyStopping

yolox/core/trainer.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,40 @@
3333
synchronize
3434
)
3535

36+
class EarlyStopping:
37+
def __init__(self, patience: int, min_delta: float, mode="max"):
38+
self.patience = patience
39+
self.min_delta = min_delta
40+
self.mode = mode # "max", "min", "percentage"
41+
self.best = None
42+
self.counter = 0
43+
44+
def step(self, value):
45+
# Initialize best value on first call
46+
if self.best is None:
47+
self.best = value
48+
return False
49+
50+
# Compute improvement depending on mode
51+
if self.mode == "max":
52+
improvement = value - self.best
53+
elif self.mode == "min":
54+
improvement = self.best - value
55+
elif self.mode == "percentage":
56+
if self.best == 0:
57+
improvement = 0 # avoid division by zero
58+
else:
59+
improvement = (value - self.best) / abs(self.best)
60+
61+
# Check if improvement is sufficient
62+
if improvement > self.min_delta:
63+
self.best = value
64+
self.counter = 0
65+
else:
66+
self.counter += 1
67+
68+
return self.counter >= self.patience
69+
3670

3771
class Trainer:
3872
def __init__(self, exp: Exp, args):
@@ -234,7 +268,15 @@ def after_epoch(self):
234268

235269
if (self.epoch + 1) % self.exp.eval_interval == 0:
236270
all_reduce_norm(self.model)
237-
self.evaluate_and_save_model()
271+
ap50_95 = self.evaluate_and_save_model()
272+
273+
# Early stopping
274+
if self.early_stopper is not None:
275+
if self.early_stopper.step(ap50_95):
276+
logger.info(f"Early stopping triggered at epoch {self.epoch}. " f"Best AP: {self.early_stopper.best}")
277+
# save best checkpoint before exiting
278+
self.save_ckpt("best_ckpt")
279+
raise SystemExit
238280

239281
def before_iter(self):
240282
pass
@@ -395,6 +437,7 @@ def evaluate_and_save_model(self):
395437
}
396438
self.mlflow_logger.save_checkpoints(self.args, self.exp, self.file_name, self.epoch,
397439
metadata, update_best_ckpt)
440+
return ap50_95
398441

399442
def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
400443
if self.rank == 0:

0 commit comments

Comments
 (0)