diff --git a/web/server.go b/web/server.go index dfe6ed8..cb1d8f1 100644 --- a/web/server.go +++ b/web/server.go @@ -80,10 +80,49 @@ func NewServer(clock *derby.DerbyClock, events <-chan derby.Event, dbPath string // routes sets up the routes for the server func (s *Server) routes() { // Middleware + s.router.Use(middleware.RequestID) + s.router.Use(middleware.RealIP) s.router.Use(middleware.Logger) s.router.Use(middleware.Recoverer) - // Add timeout middleware with a longer duration for SSE connections - s.router.Use(middleware.Timeout(120 * time.Second)) + + // Custom middleware to log all requests + s.router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestID := middleware.GetReqID(r.Context()) + s.logger.Info("Request started", + "method", r.Method, + "path", r.URL.Path, + "requestID", requestID, + "remoteAddr", r.RemoteAddr) + + start := time.Now() + next.ServeHTTP(w, r) + + s.logger.Info("Request completed", + "method", r.Method, + "path", r.URL.Path, + "requestID", requestID, + "duration", time.Since(start)) + }) + }) + + // Use a very long timeout for SSE endpoints, shorter for others + s.router.Use(func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/api/events") || strings.Contains(r.URL.Path, "/api/admin/events") { + // For SSE endpoints, use a context with a very long timeout + ctx, cancel := context.WithTimeout(r.Context(), 24*time.Hour) + defer cancel() + next.ServeHTTP(w, r.WithContext(ctx)) + } else { + // For regular endpoints, use a shorter timeout + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + next.ServeHTTP(w, r.WithContext(ctx)) + } + }) + }) + // Add middleware to set appropriate headers for SSE s.router.Use(func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -91,6 +130,7 @@ func (s *Server) routes() { w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") w.Header().Set("Access-Control-Allow-Origin", "*") + s.logger.Debug("SSE headers set", "path", r.URL.Path) } next.ServeHTTP(w, r) }) @@ -184,50 +224,76 @@ func (s *Server) Start() error { Addr: addr, Handler: s.router, // Add these settings to handle multiple concurrent connections - ReadTimeout: 30 * time.Second, - WriteTimeout: 60 * time.Second, // Longer timeout for SSE connections - MaxHeaderBytes: 1 << 20, // 1 MB + ReadTimeout: 120 * time.Second, // Longer timeout for SSE + WriteTimeout: 120 * time.Second, // Longer timeout for SSE IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB } - s.logger.Info("Starting web server", "port", s.port) - return s.server.ListenAndServe() -} + // Log middleware configuration + s.logger.Info("Chi router middleware configuration", + "timeout", "120s", + "recoverer", true, + "logger", true) -// Stop gracefully shuts down the server -func (s *Server) Stop() error { - // Close database connection - if s.db != nil { - if err := s.db.Close(); err != nil { - s.logger.Error("Error closing database", "error", err) + // Start server in a goroutine + go func() { + s.logger.Info("Web server starting", "port", s.port) + if err := s.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + s.logger.Error("HTTP server error", "error", err) } - } + }() - // Create a context with timeout for shutdown - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + return nil +} +// Shutdown gracefully shuts down the server +func (s *Server) Shutdown(ctx context.Context) error { s.logger.Info("Shutting down web server") - // Shutdown the HTTP server - if s.server != nil { - return s.server.Shutdown(ctx) + // Signal event forwarder to stop + close(s.shutdown) + + // Close all client connections + s.clientsMux.Lock() + for clientChan := range s.clients { + close(clientChan) } + s.clients = make(map[chan string]bool) + s.clientsMux.Unlock() - return nil + s.adminclientsMux.Lock() + for clientChan := range s.adminclients { + close(clientChan) + } + s.adminclients = make(map[chan string]bool) + s.adminclientsMux.Unlock() + + // Shutdown the HTTP server + return s.server.Shutdown(ctx) } // forwardEvents forwards derby events to SSE clients func (s *Server) forwardEvents() { + s.logger.Info("Starting event forwarder") + for { select { case event, ok := <-s.events: if !ok { + s.logger.Warn("Events channel closed, stopping forwarder") return } + + s.logger.Debug("Received event from derby clock", + "type", event.Type, + "clientCount", len(s.clients)) + // Process the event and send to clients s.broadcastRaceEvent(event) + case <-s.shutdown: + s.logger.Info("Shutdown signal received, stopping forwarder") return } } @@ -316,102 +382,113 @@ func (s *Server) broadcastAdminEvent(event models.AdminEvent) { } } +// sendRaceEventToAllClients sends an event to all connected clients func (s *Server) sendRaceEventToAllClients(message string) { - if message == "" { - return + // Log the message being sent (truncate if too long) + msgToLog := message + if len(msgToLog) > 100 { + msgToLog = msgToLog[:100] + "..." } + s.logger.Debug("Sending message to all clients", + "message", msgToLog, + "clientCount", len(s.clients)) - // Send to all clients + // Make a copy of the clients map to avoid holding the lock while sending s.clientsMux.Lock() - clientCount := len(s.clients) - sentCount := 0 + clientsToSend := make([]chan string, 0, len(s.clients)) for clientChan := range s.clients { - clientChan <- message - sentCount++ + clientsToSend = append(clientsToSend, clientChan) } - s.logger.Debug("Event broadcast complete", - "sentCount", sentCount, - "totalClients", clientCount, - "message", message) s.clientsMux.Unlock() + + s.logger.Debug("Prepared to send to clients", "count", len(clientsToSend)) + + // Count successful and failed sends + successCount := 0 + failCount := 0 + + // Send to all clients without holding the lock + for _, clientChan := range clientsToSend { + select { + case clientChan <- message: + // Message sent successfully + successCount++ + default: + // Client is not receiving, remove it + s.clientsMux.Lock() + delete(s.clients, clientChan) + s.clientsMux.Unlock() + close(clientChan) + failCount++ + } + } + + s.logger.Debug("Finished sending message", + "successCount", successCount, + "failCount", failCount, + "remainingClients", len(s.clients)) } +// sendAdminEventToAllClients sends an event to all connected clients func (s *Server) sendAdminEventToAllClients(message string) { - if message == "" { - return + // Log the message being sent (truncate if too long) + msgToLog := message + if len(msgToLog) > 100 { + msgToLog = msgToLog[:100] + "..." } + s.logger.Debug("Sending message to all clients", + "message", msgToLog, + "clientCount", len(s.adminclients)) - // Send to all clients + // Make a copy of the clients map to avoid holding the lock while sending s.adminclientsMux.Lock() - clientCount := len(s.adminclients) - sentCount := 0 + clientsToSend := make([]chan string, 0, len(s.adminclients)) for clientChan := range s.adminclients { - clientChan <- message - sentCount++ + clientsToSend = append(clientsToSend, clientChan) } - s.logger.Debug("Event broadcast complete", - "sentCount", sentCount, - "totalClients", clientCount, - "message", message) s.adminclientsMux.Unlock() -} -// handleIndex handles the index page -func (s *Server) handleIndex() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - templates.Index().Render(r.Context(), w) - } -} + s.logger.Debug("Prepared to send to clients", "count", len(clientsToSend)) -// handleReset handles the reset API endpoint -func (s *Server) handleReset() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if err := s.clock.Reset(); err != nil { - http.Error(w, fmt.Sprintf("Failed to reset clock: %v", err), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"status": "reset"}`)) - } -} + // Count successful and failed sends + successCount := 0 + failCount := 0 -// handleForceEnd handles the force end API endpoint -func (s *Server) handleForceEnd() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if err := s.clock.ForceEnd(); err != nil { - http.Error(w, fmt.Sprintf("Failed to force end race: %v", err), http.StatusInternalServerError) - return + // Send to all clients without holding the lock + for _, clientChan := range clientsToSend { + select { + case clientChan <- message: + // Message sent successfully + successCount++ + default: + // Client is not receiving, remove it + s.adminclientsMux.Lock() + delete(s.adminclients, clientChan) + s.adminclientsMux.Unlock() + close(clientChan) + failCount++ } - - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{"status": "forced"}`)) } -} -// handleStatus handles the status API endpoint -func (s *Server) handleStatus() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - status := s.clock.Status() - - var statusStr string - switch status { - case derby.StatusIdle: - statusStr = "idle" - case derby.StatusRunning: - statusStr = "running" - case derby.StatusFinished: - statusStr = "finished" - } - - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(fmt.Sprintf(`{"status": "%s"}`, statusStr))) - } + s.logger.Debug("Finished sending message", + "successCount", successCount, + "failCount", failCount, + "remainingClients", len(s.adminclients)) } // handleEvents handles SSE events func (s *Server) handleEvents() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + requestID := middleware.GetReqID(r.Context()) + if requestID == "" { + requestID = fmt.Sprintf("req-%d", time.Now().UnixNano()) + } + + s.logger.Info("SSE connection request received", + "requestID", requestID, + "remoteAddr", r.RemoteAddr, + "userAgent", r.UserAgent()) + // Set headers for SSE w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -420,8 +497,10 @@ func (s *Server) handleEvents() http.HandlerFunc { // Flush headers to ensure they're sent to the client if flusher, ok := w.(http.Flusher); ok { + s.logger.Debug("Flushing headers", "requestID", requestID) flusher.Flush() } else { + s.logger.Error("Streaming unsupported!", "requestID", requestID) http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) return } @@ -436,9 +515,18 @@ func (s *Server) handleEvents() http.HandlerFunc { s.clientsMux.Unlock() s.logger.Info("New client connected", + "requestID", requestID, "clientIP", r.RemoteAddr, "totalClients", clientCount) + // Send a ping immediately to test the connection + pingMsg := "event: ping\ndata: connection established\n\n" + fmt.Fprint(w, pingMsg) + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + s.logger.Debug("Initial ping sent", "requestID", requestID) + } + // Remove client when connection is closed defer func() { s.clientsMux.Lock() @@ -448,28 +536,107 @@ func (s *Server) handleEvents() http.HandlerFunc { close(clientChan) s.logger.Info("Client disconnected", + "requestID", requestID, "clientIP", r.RemoteAddr, "remainingClients", remainingClients) }() + // Set up a heartbeat to keep the connection alive + heartbeat := time.NewTicker(30 * time.Second) + defer heartbeat.Stop() + // Keep connection open and send events as they arrive for { select { case msg, ok := <-clientChan: if !ok { + s.logger.Warn("Client channel closed", "requestID", requestID) return } - fmt.Fprintf(w, "%s\n\n", msg) + + s.logger.Debug("Sending message to client", + "requestID", requestID, + "messageLength", len(msg)) + + fmt.Fprint(w, msg) if flusher, ok := w.(http.Flusher); ok { flusher.Flush() + s.logger.Debug("Message flushed", "requestID", requestID) + } else { + s.logger.Warn("Could not flush - client may not receive updates", "requestID", requestID) } + + case <-heartbeat.C: + // Send a heartbeat to keep the connection alive + s.logger.Debug("Sending heartbeat", "requestID", requestID) + fmt.Fprint(w, "event: ping\ndata: heartbeat\n\n") + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + case <-r.Context().Done(): + s.logger.Info("Client context done", + "requestID", requestID, + "error", r.Context().Err()) return } } } } +// handleIndex handles the index page +func (s *Server) handleIndex() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + templates.Index().Render(r.Context(), w) + } +} + +// handleReset handles the reset API endpoint +func (s *Server) handleReset() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := s.clock.Reset(); err != nil { + http.Error(w, fmt.Sprintf("Failed to reset clock: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status": "reset"}`)) + } +} + +// handleForceEnd handles the force end API endpoint +func (s *Server) handleForceEnd() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if err := s.clock.ForceEnd(); err != nil { + http.Error(w, fmt.Sprintf("Failed to force end race: %v", err), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"status": "forced"}`)) + } +} + +// handleStatus handles the status API endpoint +func (s *Server) handleStatus() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + status := s.clock.Status() + + var statusStr string + switch status { + case derby.StatusIdle: + statusStr = "idle" + case derby.StatusRunning: + statusStr = "running" + case derby.StatusFinished: + statusStr = "finished" + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(fmt.Sprintf(`{"status": "%s"}`, statusStr))) + } +} + // handleAdminEvents handles SSE events func (s *Server) handleAdminEvents() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) {