diff --git a/pkg/cloud/cloud.go b/pkg/cloud/cloud.go index 27dff6bc..368fd1df 100644 --- a/pkg/cloud/cloud.go +++ b/pkg/cloud/cloud.go @@ -321,7 +321,7 @@ func New(projectName string, opts LocalCloudOptions) (*LocalCloud, error) { return nil, err } - localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, opts.LocalCloudMode == StartMode) + localWebsites := websites.NewLocalWebsitesService(localGateway.GetApiAddress, localGateway.GetWebsocketAddress, opts.LocalCloudMode == StartMode) return &LocalCloud{ servers: make(map[string]*server.NitricServer), diff --git a/pkg/cloud/gateway/gateway.go b/pkg/cloud/gateway/gateway.go index 5db46cf4..fd857609 100644 --- a/pkg/cloud/gateway/gateway.go +++ b/pkg/cloud/gateway/gateway.go @@ -46,7 +46,6 @@ import ( "github.com/nitrictech/cli/pkg/netx" "github.com/nitrictech/cli/pkg/project/localconfig" "github.com/nitrictech/cli/pkg/system" - "github.com/nitrictech/cli/pkg/view/tui" base_http "github.com/nitrictech/nitric/cloud/common/runtime/gateway" @@ -157,6 +156,19 @@ func (s *LocalGatewayService) GetApiAddress(apiName string) string { return "" } +func (s *LocalGatewayService) GetWebsocketAddress(socketName string) string { + s.lock.RLock() + defer s.lock.RUnlock() + + addresses := s.GetWebsocketAddresses() + + if address, ok := addresses[socketName]; ok { + return address + } + + return "" +} + func (s *LocalGatewayService) GetHttpWorkerAddresses() map[string]string { s.lock.RLock() defer s.lock.RUnlock() @@ -349,14 +361,14 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx SocketName: socketName, }) if err != nil { - tui.Error.Println(err.Error()) + system.Logf("Websocket error: %s", err.Error()) return } }() err = s.websocketPlugin.RegisterConnection(socketName, connectionId, ws) if err != nil { - tui.Error.Println(err.Error()) + system.Logf("Websocket error: %s", err.Error()) return } @@ -372,7 +384,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx if err != nil && websocket.IsCloseError(err, 1001, 1005) { break } else if err != nil { - log.Println("read:", err) + system.Logf("websocket read error: %v", err) break } @@ -390,7 +402,7 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx }, }) if err != nil { - tui.Error.Println(err.Error()) + system.Logf("Websocket error: %s", err.Error()) return } } @@ -407,13 +419,13 @@ func (s *LocalGatewayService) handleWebsocketRequest(socketName string) func(ctx }, }) if err != nil { - tui.Error.Println(err.Error()) + system.Logf("Websocket error: %s", err.Error()) return } }) if err != nil { if _, ok := err.(websocket.HandshakeError); ok { - tui.Error.Println(err.Error()) + system.Logf("Websocket error: %s", err.Error()) } return diff --git a/pkg/cloud/websites/websites.go b/pkg/cloud/websites/websites.go index 3f583f0d..26afb3be 100644 --- a/pkg/cloud/websites/websites.go +++ b/pkg/cloud/websites/websites.go @@ -31,8 +31,10 @@ import ( "sync" "github.com/asaskevich/EventBus" + "github.com/gorilla/websocket" "github.com/nitrictech/cli/pkg/netx" + "github.com/nitrictech/cli/pkg/system" deploymentspb "github.com/nitrictech/nitric/core/pkg/proto/deployments/v1" ) @@ -55,10 +57,11 @@ type ( ) type LocalWebsiteService struct { - websiteRegLock sync.RWMutex - state State - getApiAddress GetApiAddress - isStartCmd bool + websiteRegLock sync.RWMutex + state State + getApiAddress GetApiAddress + getWebsocketAddress GetApiAddress + isStartCmd bool bus EventBus.Bus } @@ -74,6 +77,22 @@ func (l *LocalWebsiteService) SubscribeToState(fn func(State)) { _ = l.bus.Subscribe(localWebsitesTopic, fn) } +func proxyWebSocketMessages(src, dst *websocket.Conn, errChan chan error) { + for { + messageType, message, err := src.ReadMessage() + if err != nil { + errChan <- err + return + } + + err = dst.WriteMessage(messageType, message) + if err != nil { + errChan <- err + return + } + } +} + // register - Register a new website func (l *LocalWebsiteService) register(website Website, port int) { l.websiteRegLock.Lock() @@ -182,25 +201,72 @@ func (h staticSiteHandler) ServeHTTP(res http.ResponseWriter, req *http.Request) h.serveStatic(res, req) } -// createAPIPathHandler creates a handler for API proxy requests -func (l *LocalWebsiteService) createAPIPathHandler() http.HandlerFunc { - return func(res http.ResponseWriter, req *http.Request) { - apiName := req.PathValue("name") +// websocketPathHandler creates a handler for WebSocket proxy requests +func (l *LocalWebsiteService) websocketPathHandler(w http.ResponseWriter, r *http.Request) { + // Get the WebSocket API name from the request path + apiName := r.PathValue("name") - apiAddress := l.getApiAddress(apiName) - if apiAddress == "" { - http.Error(res, fmt.Sprintf("api %s not found", apiName), http.StatusNotFound) - return - } + // Get the address of the WebSocket API + apiAddress := l.getWebsocketAddress(apiName) + if apiAddress == "" { + http.Error(w, fmt.Sprintf("WebSocket API %s not found", apiName), http.StatusNotFound) + return + } - targetPath := strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/api/%s", apiName)) - targetUrl, _ := url.Parse(apiAddress) + // Dial the backend WebSocket server + targetURL := fmt.Sprintf("ws://%s%s", apiAddress, r.URL.Path) + if r.URL.RawQuery != "" { + targetURL = fmt.Sprintf("%s?%s", targetURL, r.URL.RawQuery) + } + + targetConn, _, err := websocket.DefaultDialer.Dial(targetURL, nil) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to connect to backend WebSocket server: %v", err), http.StatusInternalServerError) + return + } + defer targetConn.Close() - proxy := httputil.NewSingleHostReverseProxy(targetUrl) - req.URL.Path = targetPath + // Upgrade the HTTP connection to a WebSocket connection + upgrader := websocket.Upgrader{} - proxy.ServeHTTP(res, req) + clientConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to upgrade to WebSocket: %v", err), http.StatusInternalServerError) + return + } + + defer clientConn.Close() + + // Proxy messages between the client and the backend WebSocket server + errChan := make(chan error, 2) + go proxyWebSocketMessages(clientConn, targetConn, errChan) + go proxyWebSocketMessages(targetConn, clientConn, errChan) + + // Wait for an error to occur + err = <-errChan + if err != nil && !errors.Is(err, websocket.ErrCloseSent) { + // Because the error is already proxied through by the connection we can just log the error here + system.Logf("received error on websocket %s: %v", apiName, err) + } +} + +// apiPathHandler creates a handler for API proxy requests +func (l *LocalWebsiteService) apiPathHandler(res http.ResponseWriter, req *http.Request) { + apiName := req.PathValue("name") + + apiAddress := l.getApiAddress(apiName) + if apiAddress == "" { + http.Error(res, fmt.Sprintf("api %s not found", apiName), http.StatusNotFound) + return } + + targetPath := strings.TrimPrefix(req.URL.Path, fmt.Sprintf("/api/%s", apiName)) + targetUrl, _ := url.Parse(apiAddress) + + proxy := httputil.NewSingleHostReverseProxy(targetUrl) + req.URL.Path = targetPath + + proxy.ServeHTTP(res, req) } // createServer creates and configures an HTTP server with the given mux @@ -250,7 +316,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error { mux := http.NewServeMux() // Register the API proxy handler for this website - mux.HandleFunc("/api/{name}/", l.createAPIPathHandler()) + mux.HandleFunc("/api/{name}/", l.apiPathHandler) + + // Register the WebSocket proxy handler for this website + mux.HandleFunc("/ws/{name}", l.websocketPathHandler) // Create the SPA handler for this website spa := staticSiteHandler{ @@ -287,7 +356,10 @@ func (l *LocalWebsiteService) Start(websites []Website) error { mux := http.NewServeMux() // Register the API proxy handler - mux.HandleFunc("/api/{name}/", l.createAPIPathHandler()) + mux.HandleFunc("/api/{name}/", l.apiPathHandler) + + // Register the WebSocket proxy handler for this website + mux.HandleFunc("/ws/{name}", l.websocketPathHandler) // Register the SPA handler for each website for i := range websites { @@ -325,11 +397,12 @@ func (l *LocalWebsiteService) Start(websites []Website) error { return nil } -func NewLocalWebsitesService(getApiAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService { +func NewLocalWebsitesService(getApiAddress GetApiAddress, getWebsocketAddress GetApiAddress, isStartCmd bool) *LocalWebsiteService { return &LocalWebsiteService{ - state: State{}, - bus: EventBus.New(), - getApiAddress: getApiAddress, - isStartCmd: isStartCmd, + state: State{}, + bus: EventBus.New(), + getApiAddress: getApiAddress, + getWebsocketAddress: getWebsocketAddress, + isStartCmd: isStartCmd, } } diff --git a/pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts b/pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts index 571cdfe6..4bcc970d 100644 --- a/pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts +++ b/pkg/dashboard/frontend/src/lib/utils/generate-architecture-data.ts @@ -609,6 +609,23 @@ export function generateArchitectureData(data: WebSocketResponse): { label: `Rewrites to /api/${api.name}`, }) }) + + data.websockets.forEach((websocket) => { + edges.push({ + id: `e-${websocket.name}-websites`, + source: websitesNode.id, + target: `websocket-${websocket.name}`, + animated: true, + markerEnd: { + type: MarkerType.ArrowClosed, + }, + markerStart: { + type: MarkerType.ArrowClosed, + orient: 'auto-start-reverse', + }, + label: `Rewrites to /ws/${websocket.name}`, + }) + }) } data.services.forEach((service) => {