diff --git a/zooid/api.go b/zooid/api.go index 4a7e3a8..92cffd9 100644 --- a/zooid/api.go +++ b/zooid/api.go @@ -19,6 +19,7 @@ import ( type APIHandler struct { whitelist map[string]bool configDir string + mux http.Handler } // NewAPIHandler creates a new API handler with the given whitelist @@ -30,78 +31,48 @@ func NewAPIHandler(whitelist string, configDir string) *APIHandler { w[pubkey] = true } } - return &APIHandler{ + api := &APIHandler{ whitelist: w, configDir: configDir, } + api.mux = api.buildMux() + return api +} + +func (api *APIHandler) buildMux() http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("POST /relay/{id}", api.auth(api.createRelay)) + mux.HandleFunc("PUT /relay/{id}", api.auth(api.updateRelay)) + mux.HandleFunc("PATCH /relay/{id}", api.auth(api.patchRelay)) + mux.HandleFunc("DELETE /relay/{id}", api.auth(api.deleteRelay)) + mux.HandleFunc("GET /relay/{id}/members", api.auth(api.listRelayMembers)) + return mux +} + +func (api *APIHandler) auth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + pubkey, err := validateNIP98Auth(r) + if err != nil { + writeError(w, http.StatusUnauthorized, err.Error()) + return + } + if !api.whitelist[pubkey.Hex()] { + writeError(w, http.StatusForbidden, "pubkey not in whitelist") + return + } + next(w, r) + } } // ServeHTTP implements the http.Handler interface func (api *APIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - - // Authenticate the request using NIP-98 - pubkey, err := validateNIP98Auth(r) - if err != nil { - writeError(w, http.StatusUnauthorized, err.Error()) - return - } - - // Check if pubkey is in whitelist - if !api.whitelist[pubkey.Hex()] { - writeError(w, http.StatusForbidden, "pubkey not in whitelist") - return - } - - // Route the request - path := strings.TrimPrefix(r.URL.Path, "/") - parts := strings.Split(path, "/") - - if len(parts) < 2 || parts[0] != "relay" { - writeError(w, http.StatusNotFound, "not found") - return - } - - id := parts[1] - if id == "" { - writeError(w, http.StatusBadRequest, "relay id is required") - return - } - - if len(parts) > 2 { - if len(parts) == 3 && parts[2] == "members" { - if r.Method != http.MethodGet { - writeError(w, http.StatusMethodNotAllowed, "method not allowed") - return - } - - api.listRelayMembers(w, id) - return - } - - // Keep trailing-slash compatibility for existing /relay/{id}/ calls. - if len(parts) != 3 || parts[2] != "" { - writeError(w, http.StatusNotFound, "not found") - return - } - } - - switch r.Method { - case http.MethodPost: - api.createRelay(w, r, id) - case http.MethodPut: - api.updateRelay(w, r, id) - case http.MethodPatch: - api.patchRelay(w, r, id) - case http.MethodDelete: - api.deleteRelay(w, r, id) - default: - writeError(w, http.StatusMethodNotAllowed, "method not allowed") - } + api.mux.ServeHTTP(w, r) } // listRelayMembers returns members for a relay as an array of pubkeys. -func (api *APIHandler) listRelayMembers(w http.ResponseWriter, id string) { +func (api *APIHandler) listRelayMembers(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") members, err := api.resolveRelayMembers(id) if err != nil { if os.IsNotExist(err) { @@ -179,7 +150,8 @@ func scheme(r *http.Request) string { } // createRelay creates a new relay config file -func (api *APIHandler) createRelay(w http.ResponseWriter, r *http.Request, id string) { +func (api *APIHandler) createRelay(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") configPath := api.configPath(id) if _, err := os.Stat(configPath); err == nil { @@ -207,7 +179,8 @@ func (api *APIHandler) createRelay(w http.ResponseWriter, r *http.Request, id st } // updateRelay updates an existing relay config file -func (api *APIHandler) updateRelay(w http.ResponseWriter, r *http.Request, id string) { +func (api *APIHandler) updateRelay(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") configPath := api.configPath(id) if err := api.checkConfigExists(configPath); err != nil { @@ -239,7 +212,8 @@ func (api *APIHandler) updateRelay(w http.ResponseWriter, r *http.Request, id st } // patchRelay partially updates an existing relay config -func (api *APIHandler) patchRelay(w http.ResponseWriter, r *http.Request, id string) { +func (api *APIHandler) patchRelay(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") configPath := api.configPath(id) if err := api.checkConfigExists(configPath); err != nil { @@ -382,7 +356,8 @@ func (api *APIHandler) validatePatchedConfig(config *Config) error { } // deleteRelay deletes a relay config file -func (api *APIHandler) deleteRelay(w http.ResponseWriter, r *http.Request, id string) { +func (api *APIHandler) deleteRelay(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") configPath := api.configPath(id) if err := api.checkConfigExists(configPath); err != nil { diff --git a/zooid/api_test.go b/zooid/api_test.go index 1ab6b4c..2e616c6 100644 --- a/zooid/api_test.go +++ b/zooid/api_test.go @@ -809,8 +809,8 @@ func TestAPIHandler_InvalidPath(t *testing.T) { api.ServeHTTP(w, req) - if w.Code != http.StatusBadRequest { - t.Errorf("expected status %d, got %d", http.StatusBadRequest, w.Code) + if w.Code != http.StatusNotFound { + t.Errorf("expected status %d, got %d", http.StatusNotFound, w.Code) } }) }