From 3ad2c294d2d06dc7c0ff1989d7233aa5d45f06ac Mon Sep 17 00:00:00 2001 From: Dustin Pianalto Date: Thu, 6 Mar 2025 15:25:37 -0900 Subject: [PATCH] fix shutdown bug --- examples/main.go | 149 ++++++++++++++++++++++++++++------------------- web/server.go | 142 +++++++++++++++++++++++++++++++------------- 2 files changed, 190 insertions(+), 101 deletions(-) diff --git a/examples/main.go b/examples/main.go index d851dae..6cb157a 100644 --- a/examples/main.go +++ b/examples/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "context" "flag" "fmt" "os" @@ -49,12 +50,16 @@ func main() { close(eventBroadcaster) }() + // Create a context for graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Start web interface if enabled if !*noWeb { wg.Add(1) go func() { defer wg.Done() - startWebInterface(clock, eventBroadcaster, *webPort, sigChan) + startWebInterface(clock, eventBroadcaster, *webPort, ctx) }() } @@ -63,26 +68,30 @@ func main() { wg.Add(1) go func() { defer wg.Done() - startTerminalInterface(clock, eventBroadcaster, sigChan) + startTerminalInterface(clock, eventBroadcaster, ctx) }() } // Wait for signal to exit <-sigChan fmt.Println("Shutting down...") - time.Sleep(500 * time.Millisecond) // Give a moment for any pending operations to complete + + // Cancel context to signal all components to shut down + cancel() + + // Give a moment for any pending operations to complete + time.Sleep(500 * time.Millisecond) // Wait for all interfaces to shut down wg.Wait() } // startWebInterface initializes and runs the web interface -func startWebInterface(clock *derby.DerbyClock, events <-chan derby.Event, webPort int, sigChan chan os.Signal) { +func startWebInterface(clock *derby.DerbyClock, events <-chan derby.Event, webPort int, ctx context.Context) { // Create and start the web server server, err := web.NewServer(clock, events, webPort) if err != nil { fmt.Printf("Error creating web server: %v\n", err) - sigChan <- syscall.SIGTERM return } @@ -91,18 +100,26 @@ func startWebInterface(clock *derby.DerbyClock, events <-chan derby.Event, webPo // Start the web server if err := server.Start(); err != nil { fmt.Printf("Web server error: %v\n", err) - sigChan <- syscall.SIGTERM + return + } + + // Wait for context cancellation + <-ctx.Done() + + // Gracefully shut down the server + fmt.Println("Shutting down web server...") + if err := server.Stop(); err != nil { + fmt.Printf("Error shutting down web server: %v\n", err) } } // startTerminalInterface initializes and runs the terminal interface -func startTerminalInterface(clock *derby.DerbyClock, events <-chan derby.Event, sigChan chan os.Signal) { +func startTerminalInterface(clock *derby.DerbyClock, events <-chan derby.Event, ctx context.Context) { fmt.Println("Terminal interface started") // Reset the clock to start fresh if err := clock.Reset(); err != nil { fmt.Printf("Error resetting clock: %v\n", err) - sigChan <- syscall.SIGTERM return } fmt.Println("Clock reset. Ready to start race.") @@ -117,25 +134,33 @@ func startTerminalInterface(clock *derby.DerbyClock, events <-chan derby.Event, // Process events from the clock go func() { raceResults := make([]*derby.Result, 0) - for event := range events { - switch event.Type { - case derby.EventRaceStart: - fmt.Println("\nšŸ Race started!") - - case derby.EventLaneFinish: - result := event.Result - fmt.Printf("šŸš— Lane %d finished in place %d with time %.4f seconds\n", - result.Lane, result.FinishPlace, result.Time) - raceResults = append(raceResults, result) - - case derby.EventRaceComplete: - fmt.Println("\nšŸ† Race complete! Final results:") - for _, result := range raceResults { - fmt.Printf("Place %d: Lane %d - %.4f seconds\n", - result.FinishPlace, result.Lane, result.Time) + for { + select { + case event, ok := <-events: + if !ok { + return + } + switch event.Type { + case derby.EventRaceStart: + fmt.Println("\nšŸ Race started!") + + case derby.EventLaneFinish: + result := event.Result + fmt.Printf("šŸš— Lane %d finished in place %d with time %.4f seconds\n", + result.Lane, result.FinishPlace, result.Time) + raceResults = append(raceResults, result) + + case derby.EventRaceComplete: + fmt.Println("\nšŸ† Race complete! Final results:") + for _, result := range raceResults { + fmt.Printf("Place %d: Lane %d - %.4f seconds\n", + result.FinishPlace, result.Lane, result.Time) + } + fmt.Println("\nEnter command (r/f/q/?):") + raceResults = nil } - fmt.Println("\nEnter command (r/f/q/?):") - raceResults = nil + case <-ctx.Done(): + return } } }() @@ -143,46 +168,50 @@ func startTerminalInterface(clock *derby.DerbyClock, events <-chan derby.Event, // Handle keyboard input reader := bufio.NewReader(os.Stdin) for { - fmt.Print("Enter command (r/f/q/?): ") - input, err := reader.ReadString('\n') - if err != nil { - fmt.Printf("Error reading input: %v\n", err) - continue - } + select { + case <-ctx.Done(): + return + default: + fmt.Print("Enter command (r/f/q/?): ") + input, err := reader.ReadString('\n') + if err != nil { + fmt.Printf("Error reading input: %v\n", err) + continue + } - // Trim whitespace and convert to lowercase - command := strings.TrimSpace(strings.ToLower(input)) + // Trim whitespace and convert to lowercase + command := strings.TrimSpace(strings.ToLower(input)) - switch command { - case "r": - fmt.Println("Resetting clock...") - if err := clock.Reset(); err != nil { - fmt.Printf("Error resetting clock: %v\n", err) - } else { - fmt.Println("Clock reset. Ready to start race.") - } + switch command { + case "r": + fmt.Println("Resetting clock...") + if err := clock.Reset(); err != nil { + fmt.Printf("Error resetting clock: %v\n", err) + } else { + fmt.Println("Clock reset. Ready to start race.") + } - case "f": - fmt.Println("Forcing race to end...") - if err := clock.ForceEnd(); err != nil { - fmt.Printf("Error forcing race end: %v\n", err) - } + case "f": + fmt.Println("Forcing race to end...") + if err := clock.ForceEnd(); err != nil { + fmt.Printf("Error forcing race end: %v\n", err) + } - case "q": - fmt.Println("Quitting...") - sigChan <- syscall.SIGTERM - return + case "q": + fmt.Println("Quitting...") + return - case "?": - fmt.Println("\nCommands:") - fmt.Println(" r - Reset the clock") - fmt.Println(" f - Force end the race") - fmt.Println(" q - Quit the program") - fmt.Println(" ? - Show this help message") + case "?": + fmt.Println("\nCommands:") + fmt.Println(" r - Reset the clock") + fmt.Println(" f - Force end the race") + fmt.Println(" q - Quit the program") + fmt.Println(" ? - Show this help message") - default: - if command != "" { - fmt.Println("Unknown command. Type ? for help.") + default: + if command != "" { + fmt.Println("Unknown command. Type ? for help.") + } } } } diff --git a/web/server.go b/web/server.go index 1250548..e4156d3 100644 --- a/web/server.go +++ b/web/server.go @@ -1,10 +1,12 @@ package web import ( + "context" "embed" "fmt" "io/fs" "net/http" + "sync" "time" "github.com/go-chi/chi/v5" @@ -19,22 +21,27 @@ var content embed.FS // Server represents the web server for the derby clock type Server struct { - router *chi.Mux - clock *derby.DerbyClock - events <-chan derby.Event - clients map[chan string]bool - port int + router *chi.Mux + clock *derby.DerbyClock + events <-chan derby.Event + clients map[chan string]bool + clientsMux sync.Mutex + port int + server *http.Server + shutdown chan struct{} } // NewServer creates a new web server func NewServer(clock *derby.DerbyClock, events <-chan derby.Event, port int) (*Server, error) { // Create server s := &Server{ - router: chi.NewRouter(), - clock: clock, - events: events, - clients: make(map[chan string]bool), - port: port, + router: chi.NewRouter(), + clock: clock, + events: events, + clients: make(map[chan string]bool), + clientsMux: sync.Mutex{}, + port: port, + shutdown: make(chan struct{}), } // Set up routes @@ -69,43 +76,90 @@ func (s *Server) routes() { // Start starts the web server func (s *Server) Start() error { - fmt.Printf("Starting web server on port %d...\n", s.port) - return http.ListenAndServe(fmt.Sprintf(":%d", s.port), s.router) + addr := fmt.Sprintf(":%d", s.port) + s.server = &http.Server{ + Addr: addr, + Handler: s.router, + } + + // Start server in a goroutine + go func() { + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + fmt.Printf("HTTP server error: %v\n", err) + } + }() + + return nil +} + +// Stop gracefully shuts down the server +func (s *Server) Stop() error { + // Signal event forwarder to stop + close(s.shutdown) + + // Close all client connections + s.clientsMux.Lock() + for clientChan := range s.clients { + delete(s.clients, clientChan) + close(clientChan) + } + s.clientsMux.Unlock() + + // Create a context with timeout for shutdown + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Shutdown the HTTP server + if s.server != nil { + return s.server.Shutdown(ctx) + } + + return nil } // forwardEvents forwards derby events to SSE clients func (s *Server) forwardEvents() { - for event := range s.events { - // Store the event for new clients - // s.raceEvents <- event - - // Create the SSE message based on the event type - var message string - switch event.Type { - case derby.EventRaceStart: - message = "event: race-start\ndata: {\"status\": \"running\"}\n\n" - - case derby.EventLaneFinish: - result := event.Result - message = fmt.Sprintf("event: lane-finish\ndata: {\"lane\": %d, \"place\": %d, \"time\": %.4f}\n\n", - result.Lane, result.FinishPlace, result.Time) - - case derby.EventRaceComplete: - message = "event: race-complete\ndata: {\"status\": \"finished\"}\n\n" + for { + select { + case event, ok := <-s.events: + if !ok { + return + } + // Process the event and send to clients + s.broadcastEvent(event) + case <-s.shutdown: + return } + } +} - // Send to all connected clients - for clientChan := range s.clients { - select { - case clientChan <- message: - // Message sent successfully - default: - // Client is not receiving, remove it - delete(s.clients, clientChan) - close(clientChan) - } +// broadcastEvent sends an event to all connected clients +func (s *Server) broadcastEvent(event derby.Event) { + var message string + switch event.Type { + case derby.EventRaceStart: + message = "event: race-start\ndata: {\"status\": \"running\"}\n\n" + + case derby.EventLaneFinish: + result := event.Result + message = fmt.Sprintf("event: lane-finish\ndata: {\"lane\": %d, \"place\": %d, \"time\": %.4f}\n\n", + result.Lane, result.FinishPlace, result.Time) + + case derby.EventRaceComplete: + message = "event: race-complete\ndata: {\"status\": \"finished\"}\n\n" + } + + // Send to all clients + s.clientsMux.Lock() + for clientChan := range s.clients { + // Non-blocking send to avoid slow clients blocking others + select { + case clientChan <- message: + default: + // Client channel is full, could log this or take other action } } + s.clientsMux.Unlock() } // handleIndex handles the index page @@ -161,7 +215,7 @@ func (s *Server) handleStatus() http.HandlerFunc { } } -// handleEvents handles the SSE events endpoint +// handleEvents handles SSE events func (s *Server) handleEvents() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // Set headers for SSE @@ -172,11 +226,17 @@ func (s *Server) handleEvents() http.HandlerFunc { // Create a channel for this client clientChan := make(chan string, 10) + + // Add client to map with mutex protection + s.clientsMux.Lock() s.clients[clientChan] = true + s.clientsMux.Unlock() - // Clean up when the client disconnects + // Remove client when connection is closed defer func() { + s.clientsMux.Lock() delete(s.clients, clientChan) + s.clientsMux.Unlock() close(clientChan) }()