Skip to content

Commit 0c475c2

Browse files
authored
Daemonize (#21)
* daemonize support * add pid writting support * ensure defer cleanups work * reduce cyclomatic complexity of main * move setupLogging into its own function * Split operations so that we can display certain errors to stderr
1 parent 96fb7f0 commit 0c475c2

File tree

1 file changed

+184
-59
lines changed

1 file changed

+184
-59
lines changed

main.go

Lines changed: 184 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import (
77
"errors"
88
"flag"
99
"fmt"
10-
"log"
1110
"net/http"
1211
"os"
12+
"os/exec"
1313
"os/signal"
1414
"strings"
1515
"sync"
@@ -45,6 +45,8 @@ var (
4545
tokenRevalidationPeriod time.Duration
4646
logFile string
4747
logLevel string
48+
daemonize bool
49+
pidFile string
4850
)
4951

5052
type wsConnection struct {
@@ -101,14 +103,39 @@ var (
101103
lm sync.Mutex
102104
)
103105

104-
// ///////////////////
105-
// RATE LIMITER
106-
// ///////////////////
107106
type limiter struct {
108107
tokens int
109108
last time.Time
110109
}
111110

111+
// parseFlags parses command-line flags into global configuration variables.
112+
// It supports database driver, DSN, server address, log file/level, PID file, and various WS server limits.
113+
func parseFlags() {
114+
var origins string
115+
116+
flag.StringVar(&dbDriver, "db", "sqlite", "Database driver")
117+
flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN")
118+
flag.StringVar(&serverAddr, "addr", ":8080", "Server address")
119+
flag.StringVar(&origins, "origins", "", "Allowed WS origins")
120+
flag.StringVar(&logFile, "log-file", "", "Path to log file")
121+
flag.StringVar(&logLevel, "log-level", "info", "Log level")
122+
flag.StringVar(&pidFile, "pid-file", "", "Path to PID file")
123+
flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "")
124+
flag.IntVar(&rateLimit, "rate-limit", 10, "")
125+
flag.DurationVar(&ratePeriod, "rate-period", time.Second, "")
126+
flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "")
127+
flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "")
128+
flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "")
129+
flag.BoolVar(&daemonize, "daemon", false, "")
130+
flag.Parse()
131+
132+
if origins != "" {
133+
allowedOrigins = strings.Split(origins, ",")
134+
}
135+
}
136+
137+
// allow implements a simple token-based rate limiter for a given key.
138+
// Returns true if the action is allowed, false if the rate limit has been exceeded.
112139
func allow(key string, rate int, per time.Duration) bool {
113140
lm.Lock()
114141
defer lm.Unlock()
@@ -161,9 +188,8 @@ type pendingMessage struct {
161188
timestamp time.Time
162189
}
163190

164-
// ///////////////////
165-
// DB
166-
// ///////////////////
191+
// initDB initializes the global database connection based on dbDriver and dbDSN.
192+
// Returns an error if the driver is unsupported or if the connection cannot be established.
167193
func initDB() error {
168194
var err error
169195

@@ -187,7 +213,8 @@ func initDB() error {
187213
return db.Ping()
188214
}
189215

190-
// validateToken checks token validity in DB
216+
// validateToken checks whether a given token is valid in the database.
217+
// Returns the associated player/subject ID and true if valid, or empty string and false if invalid.
191218
func validateToken(token string, isServer bool) (string, bool) {
192219
const q = `
193220
SELECT IFNULL(player_id, subject_id)
@@ -208,11 +235,8 @@ func validateToken(token string, isServer bool) (string, bool) {
208235
return "", false
209236
}
210237

211-
// ///////////////////
212-
// CONNECTION MANAGEMENT
213-
// ///////////////////
214-
215-
// registerConnection registers a WS connection and stores its token
238+
// registerConnection registers a websocket connection for a player, storing the associated token.
239+
// It also increments the active connections metric and flushes any pending messages to the new connection.
216240
func registerConnection(playerID string, c *websocket.Conn, token string) {
217241
mu.Lock()
218242
if players[playerID] == nil {
@@ -226,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
226250
flushPendingMessages(playerID, c)
227251
}
228252

229-
// unregisterConnection removes a WS connection
253+
// unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
254+
// If no connections remain for the player, the player's entry is removed from the players map.
230255
func unregisterConnection(playerID string, c *websocket.Conn) {
231256
mu.Lock()
232257
defer mu.Unlock()
@@ -237,7 +262,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
237262
connections.Dec()
238263
}
239264

240-
// closeAllConnections closes all WS connections (on shutdown)
265+
// closeAllConnections closes all active websocket connections for all players.
266+
// Typically used during server shutdown.
241267
func closeAllConnections() {
242268
mu.Lock()
243269
defer mu.Unlock()
@@ -248,7 +274,8 @@ func closeAllConnections() {
248274
}
249275
}
250276

251-
// flushPendingMessages sends queued messages to a newly connected player
277+
// flushPendingMessages sends any queued offline messages to a newly connected websocket.
278+
// Messages older than offlineTTL are ignored and removed.
252279
func flushPendingMessages(playerID string, c *websocket.Conn) {
253280
pendingMu.Lock()
254281
msgs := pendingMessages[playerID]
@@ -271,9 +298,9 @@ func flushPendingMessages(playerID string, c *websocket.Conn) {
271298
pendingMu.Unlock()
272299
}
273300

274-
// ///////////////////
275-
// WEBSOCKET HANDLER
276-
// ///////////////////
301+
// wsHandler handles incoming websocket upgrade requests from clients.
302+
// Validates the token, enforces connection limits, sets up heartbeat, and reads messages.
303+
// Connections are automatically unregistered on disconnect.
277304
func wsHandler(w http.ResponseWriter, r *http.Request) {
278305
token := r.URL.Query().Get("token")
279306
if token == "" {
@@ -345,9 +372,9 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
345372
}).Info("Player disconnected")
346373
}
347374

348-
// ///////////////////
349-
// PUBLISH HANDLER
350-
// ///////////////////
375+
// publishHandler handles incoming messages from authorized servers to a specific player.
376+
// Validates server token, enforces rate limits, delivers message immediately if the player is connected,
377+
// or queues the message for offline delivery if not.
351378
func publishHandler(w http.ResponseWriter, r *http.Request) {
352379
auth := r.Header.Get("Authorization")
353380
if !strings.HasPrefix(auth, "Bearer ") {
@@ -416,9 +443,9 @@ func publishHandler(w http.ResponseWriter, r *http.Request) {
416443
w.WriteHeader(http.StatusOK)
417444
}
418445

419-
// ///////////////////
420-
// BROADCAST HANDLER
421-
// ///////////////////
446+
// broadcastHandler handles incoming broadcast messages from authorized servers.
447+
// Can target a specific player or all connected players.
448+
// Enforces rate limits and increments metrics for delivered messages.
422449
func broadcastHandler(w http.ResponseWriter, r *http.Request) {
423450
auth := r.Header.Get("Authorization")
424451
if !strings.HasPrefix(auth, "Bearer ") {
@@ -467,18 +494,13 @@ func broadcastHandler(w http.ResponseWriter, r *http.Request) {
467494
w.WriteHeader(http.StatusOK)
468495
}
469496

470-
// ///////////////////
471-
// METRICS
472-
// ///////////////////
497+
// initMetrics registers Prometheus metrics for connections, messages published, and messages delivered.
473498
func initMetrics() {
474499
prometheus.MustRegister(connections, messagesPublished, messagesDelivered)
475500
}
476501

477-
// ///////////////////
478-
// TOKEN REVALIDATION
479-
// ///////////////////
480-
481-
// startTokenRevalidation periodically checks all WS tokens and closes invalid ones
502+
// startTokenRevalidation periodically validates all active websocket tokens.
503+
// Invalid tokens cause connections to be closed and removed.
482504
func startTokenRevalidation(interval time.Duration) {
483505
ticker := time.NewTicker(interval)
484506
go func() {
@@ -503,36 +525,76 @@ func startTokenRevalidation(interval time.Duration) {
503525
}()
504526
}
505527

506-
// ///////////////////
507-
// MAIN
508-
// ///////////////////
509-
func main() {
510-
var origins string
528+
// writePIDFile writes the current process PID to the specified file path.
529+
// Returns an error if writing fails.
530+
func writePIDFile(path string) error {
531+
pid := os.Getpid()
532+
data := []byte(fmt.Sprintf("%d\n", pid))
533+
return os.WriteFile(path, data, 0644)
534+
}
511535

512-
flag.StringVar(&dbDriver, "db", "sqlite", "Database driver")
513-
flag.StringVar(&dbDSN, "dsn", "file:ws_tokens.db?cache=shared", "Database DSN")
514-
flag.StringVar(&serverAddr, "addr", ":8080", "Server address")
515-
flag.StringVar(&origins, "origins", "", "Allowed WS origins")
516-
flag.StringVar(&logFile, "log-file", "", "Path to log file (default: stdout)")
517-
flag.StringVar(&logLevel, "log-level", "info", "Log level (panic, fatal, error, warn, info, debug, trace)")
518-
flag.IntVar(&maxQueuedMessagesPerPlayer, "max-queued", 100, "Maximum queued messages per player")
519-
flag.IntVar(&rateLimit, "rate-limit", 10, "Number of messages allowed per rate-period per server token")
520-
flag.DurationVar(&ratePeriod, "rate-period", time.Second, "Duration for rate limiting (e.g., 1s, 500ms)")
521-
flag.IntVar(&maxConnectionsPerPlayer, "max-conns", 5, "Maximum concurrent WebSocket connections per player")
522-
flag.DurationVar(&tokenRevalidationPeriod, "revalidate-period", time.Minute, "Period for WS token revalidation (e.g., 30s, 1m)")
523-
flag.DurationVar(&offlineTTL, "offline-ttl", 10*time.Second, "Duration that messages will be stored offline (e.g., 30s, 1m)")
524-
flag.Parse()
536+
// removePIDFile deletes the PID file at the specified path.
537+
// Any errors are ignored.
538+
func removePIDFile(path string) {
539+
_ = os.Remove(path)
540+
}
525541

526-
if origins != "" {
527-
allowedOrigins = strings.Split(origins, ",")
542+
// pidFileExists checks if the PID file exists and reads its PID.
543+
// Returns the PID and true if the file exists and contains a valid integer, otherwise 0 and false.
544+
func pidFileExists(path string) (int, bool) {
545+
data, err := os.ReadFile(path)
546+
if err != nil {
547+
return 0, false
528548
}
549+
var pid int
550+
if _, err := fmt.Sscanf(string(data), "%d", &pid); err != nil {
551+
return 0, false
552+
}
553+
return pid, true
554+
}
555+
556+
// daemonizeSelf re-launches the current executable as a background daemon process.
557+
// It returns an error if the executable cannot be determined or if the child process fails to start.
558+
// If successful, the parent process will exit immediately using os.Exit(0) to allow the daemon to continue independently.
559+
func daemonizeSelf() error {
560+
if os.Getenv("DAEMONIZED") == "1" {
561+
return nil
562+
}
563+
564+
exe, err := os.Executable()
565+
if err != nil {
566+
return fmt.Errorf("cannot get executable path: %w", err)
567+
}
568+
569+
args := []string{}
570+
for _, a := range os.Args[1:] {
571+
if a != "-daemon" {
572+
args = append(args, a)
573+
}
574+
}
575+
576+
cmd := exec.Command(exe, args...)
577+
cmd.Env = append(os.Environ(), "DAEMONIZED=1")
578+
cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true}
529579

580+
if err := cmd.Start(); err != nil {
581+
return fmt.Errorf("failed to daemonize: %w", err)
582+
}
583+
584+
os.Exit(0) // safe here, because no defers in daemon parent context
585+
return nil // unreachable, but satisfies compiler
586+
}
587+
588+
// setupLogging configures logrus logging for the application.
589+
// It sets the output destination and log level based on global flags.
590+
// Returns an error if the log file cannot be opened or if the log level is invalid.
591+
func setupLogging() error {
530592
logrus.SetFormatter(&logrus.JSONFormatter{})
531593

532594
if logFile != "" {
533595
f, err := os.OpenFile(logFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
534596
if err != nil {
535-
log.Fatalf("failed to open log file %s: %v", logFile, err)
597+
return fmt.Errorf("failed to open log file %s: %w", logFile, err)
536598
}
537599
logrus.SetOutput(f)
538600
} else {
@@ -541,17 +603,75 @@ func main() {
541603

542604
level, err := logrus.ParseLevel(strings.ToLower(logLevel))
543605
if err != nil {
544-
log.Fatalf("invalid log level: %s", logLevel)
606+
return fmt.Errorf("invalid log level: %s", logLevel)
545607
}
546608
logrus.SetLevel(level)
547609

610+
return nil
611+
}
612+
613+
// handlePIDFile ensures that the PID file is created and removed properly.
614+
// If the PID file already exists, it returns an error.
615+
// The PID file is automatically removed when the function that called this defers cleanup.
616+
// Returns an error if writing the PID file fails.
617+
func handlePIDFile() error {
618+
if pidFile == "" {
619+
return nil
620+
}
621+
622+
if pid, ok := pidFileExists(pidFile); ok {
623+
return fmt.Errorf("pid file already exists for PID: %d", pid)
624+
}
625+
626+
if err := writePIDFile(pidFile); err != nil {
627+
return fmt.Errorf("failed to write pid file: %w", err)
628+
}
629+
630+
// The caller should defer removePIDFile(pidFile) to ensure cleanup
631+
return nil
632+
}
633+
634+
// ///////////////////
635+
// MAIN
636+
// ///////////////////
637+
func main() {
638+
if err := run(); err != nil {
639+
fmt.Fprintln(os.Stderr, err)
640+
os.Exit(1)
641+
}
642+
}
643+
644+
func run() error {
645+
parseFlags()
646+
647+
// Handle PID file
648+
if err := handlePIDFile(); err != nil {
649+
return err
650+
}
651+
if pidFile != "" {
652+
defer removePIDFile(pidFile)
653+
}
654+
655+
// Setup logging
656+
if err := setupLogging(); err != nil {
657+
return fmt.Errorf("failed to setup logging: %w", err)
658+
}
659+
660+
// Initialize DB
548661
if err := initDB(); err != nil {
549-
log.Fatal(err)
662+
return fmt.Errorf("failed to init DB: %w", err)
663+
}
664+
665+
// Daemonize if needed
666+
if daemonize {
667+
if err := daemonizeSelf(); err != nil {
668+
return fmt.Errorf("failed to daemonize: %w", err)
669+
}
550670
}
551671

552672
initMetrics()
553673

554-
// cleanup expired offline messages
674+
// Start offline message cleanup
555675
go func() {
556676
ticker := time.NewTicker(30 * time.Second)
557677
for range ticker.C {
@@ -585,7 +705,7 @@ func main() {
585705

586706
server := &http.Server{Addr: serverAddr, Handler: mux}
587707

588-
// graceful shutdown
708+
// Graceful shutdown
589709
quit := make(chan os.Signal, 1)
590710
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
591711
go func() {
@@ -598,5 +718,10 @@ func main() {
598718
}()
599719

600720
logrus.Infof("Server listening on %s", serverAddr)
601-
log.Fatal(server.ListenAndServe())
721+
err := server.ListenAndServe()
722+
if err != nil && err != http.ErrServerClosed {
723+
return fmt.Errorf("server error: %w", err)
724+
}
725+
726+
return nil
602727
}

0 commit comments

Comments
 (0)