77 "errors"
88 "flag"
99 "fmt"
10- "log"
1110 "net/http"
1211 "os"
12+ "os/exec"
1313 "os/signal"
1414 "strings"
1515 "sync"
4545 tokenRevalidationPeriod time.Duration
4646 logFile string
4747 logLevel string
48+ daemonize bool
49+ pidFile string
4850)
4951
5052type wsConnection struct {
@@ -101,14 +103,39 @@ var (
101103 lm sync.Mutex
102104)
103105
104- // ///////////////////
105- // RATE LIMITER
106- // ///////////////////
107106type limiter struct {
108107 tokens int
109108 last time.Time
110109}
111110
111+ // parseFlags parses command-line flags into global configuration variables.
112+ // It supports database driver, DSN, server address, log file/level, PID file, and various WS server limits.
113+ func parseFlags () {
114+ var origins string
115+
116+ flag .StringVar (& dbDriver , "db" , "sqlite" , "Database driver" )
117+ flag .StringVar (& dbDSN , "dsn" , "file:ws_tokens.db?cache=shared" , "Database DSN" )
118+ flag .StringVar (& serverAddr , "addr" , ":8080" , "Server address" )
119+ flag .StringVar (& origins , "origins" , "" , "Allowed WS origins" )
120+ flag .StringVar (& logFile , "log-file" , "" , "Path to log file" )
121+ flag .StringVar (& logLevel , "log-level" , "info" , "Log level" )
122+ flag .StringVar (& pidFile , "pid-file" , "" , "Path to PID file" )
123+ flag .IntVar (& maxQueuedMessagesPerPlayer , "max-queued" , 100 , "" )
124+ flag .IntVar (& rateLimit , "rate-limit" , 10 , "" )
125+ flag .DurationVar (& ratePeriod , "rate-period" , time .Second , "" )
126+ flag .IntVar (& maxConnectionsPerPlayer , "max-conns" , 5 , "" )
127+ flag .DurationVar (& tokenRevalidationPeriod , "revalidate-period" , time .Minute , "" )
128+ flag .DurationVar (& offlineTTL , "offline-ttl" , 10 * time .Second , "" )
129+ flag .BoolVar (& daemonize , "daemon" , false , "" )
130+ flag .Parse ()
131+
132+ if origins != "" {
133+ allowedOrigins = strings .Split (origins , "," )
134+ }
135+ }
136+
137+ // allow implements a simple token-based rate limiter for a given key.
138+ // Returns true if the action is allowed, false if the rate limit has been exceeded.
112139func allow (key string , rate int , per time.Duration ) bool {
113140 lm .Lock ()
114141 defer lm .Unlock ()
@@ -161,9 +188,8 @@ type pendingMessage struct {
161188 timestamp time.Time
162189}
163190
164- // ///////////////////
165- // DB
166- // ///////////////////
191+ // initDB initializes the global database connection based on dbDriver and dbDSN.
192+ // Returns an error if the driver is unsupported or if the connection cannot be established.
167193func initDB () error {
168194 var err error
169195
@@ -187,7 +213,8 @@ func initDB() error {
187213 return db .Ping ()
188214}
189215
190- // validateToken checks token validity in DB
216+ // validateToken checks whether a given token is valid in the database.
217+ // Returns the associated player/subject ID and true if valid, or empty string and false if invalid.
191218func validateToken (token string , isServer bool ) (string , bool ) {
192219 const q = `
193220 SELECT IFNULL(player_id, subject_id)
@@ -208,11 +235,8 @@ func validateToken(token string, isServer bool) (string, bool) {
208235 return "" , false
209236}
210237
211- // ///////////////////
212- // CONNECTION MANAGEMENT
213- // ///////////////////
214-
215- // registerConnection registers a WS connection and stores its token
238+ // registerConnection registers a websocket connection for a player, storing the associated token.
239+ // It also increments the active connections metric and flushes any pending messages to the new connection.
216240func registerConnection (playerID string , c * websocket.Conn , token string ) {
217241 mu .Lock ()
218242 if players [playerID ] == nil {
@@ -226,7 +250,8 @@ func registerConnection(playerID string, c *websocket.Conn, token string) {
226250 flushPendingMessages (playerID , c )
227251}
228252
229- // unregisterConnection removes a WS connection
253+ // unregisterConnection removes a websocket connection for a player and decrements the active connections metric.
254+ // If no connections remain for the player, the player's entry is removed from the players map.
230255func unregisterConnection (playerID string , c * websocket.Conn ) {
231256 mu .Lock ()
232257 defer mu .Unlock ()
@@ -237,7 +262,8 @@ func unregisterConnection(playerID string, c *websocket.Conn) {
237262 connections .Dec ()
238263}
239264
240- // closeAllConnections closes all WS connections (on shutdown)
265+ // closeAllConnections closes all active websocket connections for all players.
266+ // Typically used during server shutdown.
241267func closeAllConnections () {
242268 mu .Lock ()
243269 defer mu .Unlock ()
@@ -248,7 +274,8 @@ func closeAllConnections() {
248274 }
249275}
250276
251- // flushPendingMessages sends queued messages to a newly connected player
277+ // flushPendingMessages sends any queued offline messages to a newly connected websocket.
278+ // Messages older than offlineTTL are ignored and removed.
252279func flushPendingMessages (playerID string , c * websocket.Conn ) {
253280 pendingMu .Lock ()
254281 msgs := pendingMessages [playerID ]
@@ -271,9 +298,9 @@ func flushPendingMessages(playerID string, c *websocket.Conn) {
271298 pendingMu .Unlock ()
272299}
273300
274- // ///////////////////
275- // WEBSOCKET HANDLER
276- // ///////////////////
301+ // wsHandler handles incoming websocket upgrade requests from clients.
302+ // Validates the token, enforces connection limits, sets up heartbeat, and reads messages.
303+ // Connections are automatically unregistered on disconnect.
277304func wsHandler (w http.ResponseWriter , r * http.Request ) {
278305 token := r .URL .Query ().Get ("token" )
279306 if token == "" {
@@ -345,9 +372,9 @@ func wsHandler(w http.ResponseWriter, r *http.Request) {
345372 }).Info ("Player disconnected" )
346373}
347374
348- // ///////////////////
349- // PUBLISH HANDLER
350- // ///////////////////
375+ // publishHandler handles incoming messages from authorized servers to a specific player.
376+ // Validates server token, enforces rate limits, delivers message immediately if the player is connected,
377+ // or queues the message for offline delivery if not.
351378func publishHandler (w http.ResponseWriter , r * http.Request ) {
352379 auth := r .Header .Get ("Authorization" )
353380 if ! strings .HasPrefix (auth , "Bearer " ) {
@@ -416,9 +443,9 @@ func publishHandler(w http.ResponseWriter, r *http.Request) {
416443 w .WriteHeader (http .StatusOK )
417444}
418445
419- // ///////////////////
420- // BROADCAST HANDLER
421- // ///////////////////
446+ // broadcastHandler handles incoming broadcast messages from authorized servers.
447+ // Can target a specific player or all connected players.
448+ // Enforces rate limits and increments metrics for delivered messages.
422449func broadcastHandler (w http.ResponseWriter , r * http.Request ) {
423450 auth := r .Header .Get ("Authorization" )
424451 if ! strings .HasPrefix (auth , "Bearer " ) {
@@ -467,18 +494,13 @@ func broadcastHandler(w http.ResponseWriter, r *http.Request) {
467494 w .WriteHeader (http .StatusOK )
468495}
469496
470- // ///////////////////
471- // METRICS
472- // ///////////////////
497+ // initMetrics registers Prometheus metrics for connections, messages published, and messages delivered.
473498func initMetrics () {
474499 prometheus .MustRegister (connections , messagesPublished , messagesDelivered )
475500}
476501
477- // ///////////////////
478- // TOKEN REVALIDATION
479- // ///////////////////
480-
481- // startTokenRevalidation periodically checks all WS tokens and closes invalid ones
502+ // startTokenRevalidation periodically validates all active websocket tokens.
503+ // Invalid tokens cause connections to be closed and removed.
482504func startTokenRevalidation (interval time.Duration ) {
483505 ticker := time .NewTicker (interval )
484506 go func () {
@@ -503,36 +525,76 @@ func startTokenRevalidation(interval time.Duration) {
503525 }()
504526}
505527
506- // ///////////////////
507- // MAIN
508- // ///////////////////
509- func main () {
510- var origins string
528+ // writePIDFile writes the current process PID to the specified file path.
529+ // Returns an error if writing fails.
530+ func writePIDFile (path string ) error {
531+ pid := os .Getpid ()
532+ data := []byte (fmt .Sprintf ("%d\n " , pid ))
533+ return os .WriteFile (path , data , 0644 )
534+ }
511535
512- flag .StringVar (& dbDriver , "db" , "sqlite" , "Database driver" )
513- flag .StringVar (& dbDSN , "dsn" , "file:ws_tokens.db?cache=shared" , "Database DSN" )
514- flag .StringVar (& serverAddr , "addr" , ":8080" , "Server address" )
515- flag .StringVar (& origins , "origins" , "" , "Allowed WS origins" )
516- flag .StringVar (& logFile , "log-file" , "" , "Path to log file (default: stdout)" )
517- flag .StringVar (& logLevel , "log-level" , "info" , "Log level (panic, fatal, error, warn, info, debug, trace)" )
518- flag .IntVar (& maxQueuedMessagesPerPlayer , "max-queued" , 100 , "Maximum queued messages per player" )
519- flag .IntVar (& rateLimit , "rate-limit" , 10 , "Number of messages allowed per rate-period per server token" )
520- flag .DurationVar (& ratePeriod , "rate-period" , time .Second , "Duration for rate limiting (e.g., 1s, 500ms)" )
521- flag .IntVar (& maxConnectionsPerPlayer , "max-conns" , 5 , "Maximum concurrent WebSocket connections per player" )
522- flag .DurationVar (& tokenRevalidationPeriod , "revalidate-period" , time .Minute , "Period for WS token revalidation (e.g., 30s, 1m)" )
523- flag .DurationVar (& offlineTTL , "offline-ttl" , 10 * time .Second , "Duration that messages will be stored offline (e.g., 30s, 1m)" )
524- flag .Parse ()
536+ // removePIDFile deletes the PID file at the specified path.
537+ // Any errors are ignored.
538+ func removePIDFile (path string ) {
539+ _ = os .Remove (path )
540+ }
525541
526- if origins != "" {
527- allowedOrigins = strings .Split (origins , "," )
542+ // pidFileExists checks if the PID file exists and reads its PID.
543+ // Returns the PID and true if the file exists and contains a valid integer, otherwise 0 and false.
544+ func pidFileExists (path string ) (int , bool ) {
545+ data , err := os .ReadFile (path )
546+ if err != nil {
547+ return 0 , false
528548 }
549+ var pid int
550+ if _ , err := fmt .Sscanf (string (data ), "%d" , & pid ); err != nil {
551+ return 0 , false
552+ }
553+ return pid , true
554+ }
555+
556+ // daemonizeSelf re-launches the current executable as a background daemon process.
557+ // It returns an error if the executable cannot be determined or if the child process fails to start.
558+ // If successful, the parent process will exit immediately using os.Exit(0) to allow the daemon to continue independently.
559+ func daemonizeSelf () error {
560+ if os .Getenv ("DAEMONIZED" ) == "1" {
561+ return nil
562+ }
563+
564+ exe , err := os .Executable ()
565+ if err != nil {
566+ return fmt .Errorf ("cannot get executable path: %w" , err )
567+ }
568+
569+ args := []string {}
570+ for _ , a := range os .Args [1 :] {
571+ if a != "-daemon" {
572+ args = append (args , a )
573+ }
574+ }
575+
576+ cmd := exec .Command (exe , args ... )
577+ cmd .Env = append (os .Environ (), "DAEMONIZED=1" )
578+ cmd .SysProcAttr = & syscall.SysProcAttr {Setsid : true }
529579
580+ if err := cmd .Start (); err != nil {
581+ return fmt .Errorf ("failed to daemonize: %w" , err )
582+ }
583+
584+ os .Exit (0 ) // safe here, because no defers in daemon parent context
585+ return nil // unreachable, but satisfies compiler
586+ }
587+
588+ // setupLogging configures logrus logging for the application.
589+ // It sets the output destination and log level based on global flags.
590+ // Returns an error if the log file cannot be opened or if the log level is invalid.
591+ func setupLogging () error {
530592 logrus .SetFormatter (& logrus.JSONFormatter {})
531593
532594 if logFile != "" {
533595 f , err := os .OpenFile (logFile , os .O_CREATE | os .O_WRONLY | os .O_APPEND , 0644 )
534596 if err != nil {
535- log . Fatalf ("failed to open log file %s: %v " , logFile , err )
597+ return fmt . Errorf ("failed to open log file %s: %w " , logFile , err )
536598 }
537599 logrus .SetOutput (f )
538600 } else {
@@ -541,17 +603,75 @@ func main() {
541603
542604 level , err := logrus .ParseLevel (strings .ToLower (logLevel ))
543605 if err != nil {
544- log . Fatalf ("invalid log level: %s" , logLevel )
606+ return fmt . Errorf ("invalid log level: %s" , logLevel )
545607 }
546608 logrus .SetLevel (level )
547609
610+ return nil
611+ }
612+
613+ // handlePIDFile ensures that the PID file is created and removed properly.
614+ // If the PID file already exists, it returns an error.
615+ // The PID file is automatically removed when the function that called this defers cleanup.
616+ // Returns an error if writing the PID file fails.
617+ func handlePIDFile () error {
618+ if pidFile == "" {
619+ return nil
620+ }
621+
622+ if pid , ok := pidFileExists (pidFile ); ok {
623+ return fmt .Errorf ("pid file already exists for PID: %d" , pid )
624+ }
625+
626+ if err := writePIDFile (pidFile ); err != nil {
627+ return fmt .Errorf ("failed to write pid file: %w" , err )
628+ }
629+
630+ // The caller should defer removePIDFile(pidFile) to ensure cleanup
631+ return nil
632+ }
633+
634+ // ///////////////////
635+ // MAIN
636+ // ///////////////////
637+ func main () {
638+ if err := run (); err != nil {
639+ fmt .Fprintln (os .Stderr , err )
640+ os .Exit (1 )
641+ }
642+ }
643+
644+ func run () error {
645+ parseFlags ()
646+
647+ // Handle PID file
648+ if err := handlePIDFile (); err != nil {
649+ return err
650+ }
651+ if pidFile != "" {
652+ defer removePIDFile (pidFile )
653+ }
654+
655+ // Setup logging
656+ if err := setupLogging (); err != nil {
657+ return fmt .Errorf ("failed to setup logging: %w" , err )
658+ }
659+
660+ // Initialize DB
548661 if err := initDB (); err != nil {
549- log .Fatal (err )
662+ return fmt .Errorf ("failed to init DB: %w" , err )
663+ }
664+
665+ // Daemonize if needed
666+ if daemonize {
667+ if err := daemonizeSelf (); err != nil {
668+ return fmt .Errorf ("failed to daemonize: %w" , err )
669+ }
550670 }
551671
552672 initMetrics ()
553673
554- // cleanup expired offline messages
674+ // Start offline message cleanup
555675 go func () {
556676 ticker := time .NewTicker (30 * time .Second )
557677 for range ticker .C {
@@ -585,7 +705,7 @@ func main() {
585705
586706 server := & http.Server {Addr : serverAddr , Handler : mux }
587707
588- // graceful shutdown
708+ // Graceful shutdown
589709 quit := make (chan os.Signal , 1 )
590710 signal .Notify (quit , syscall .SIGINT , syscall .SIGTERM )
591711 go func () {
@@ -598,5 +718,10 @@ func main() {
598718 }()
599719
600720 logrus .Infof ("Server listening on %s" , serverAddr )
601- log .Fatal (server .ListenAndServe ())
721+ err := server .ListenAndServe ()
722+ if err != nil && err != http .ErrServerClosed {
723+ return fmt .Errorf ("server error: %w" , err )
724+ }
725+
726+ return nil
602727}
0 commit comments