From 570fc6a2985807a10d117876050826d55dcabbb0 Mon Sep 17 00:00:00 2001 From: Branden J Brown Date: Thu, 1 Feb 2024 08:26:27 -0600 Subject: [PATCH] get player ids from sessions, not ips --- main.go | 14 +++++++++++- player/session.go | 9 ++++++++ serve/proxy.go | 53 -------------------------------------------- serve/proxy_test.go | 12 ---------- serve/session.go | 54 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 66 deletions(-) delete mode 100644 serve/proxy.go delete mode 100644 serve/proxy_test.go create mode 100644 serve/session.go diff --git a/main.go b/main.go index e83989d..be2e77c 100644 --- a/main.go +++ b/main.go @@ -1,19 +1,31 @@ package main import ( + "context" "net/http" "github.com/go-chi/chi/v5" + "gitlab.com/zephyrtronium/sq" "git.sunturtle.xyz/studio/shotgun/lobby" "git.sunturtle.xyz/studio/shotgun/serve" + + _ "modernc.org/sqlite" ) func main() { s := Server{ l: lobby.New(), } + sessiondb, err := sq.Open("sqlite", ":memory:") + if err != nil { + panic(err) + } + sessions, err := sessiondb.Conn(context.Background()) + if err != nil { + panic(err) + } r := chi.NewRouter() - r.With(serve.WithPlayerID).Get("/queue", s.Queue) + r.With(serve.WithPlayerID(sessions)).Get("/queue", s.Queue) http.ListenAndServe(":8080", r) } diff --git a/player/session.go b/player/session.go index c707d43..145d96d 100644 --- a/player/session.go +++ b/player/session.go @@ -15,6 +15,15 @@ type Session struct { uuid.UUID } +// ParseSession parses a session ID. +func ParseSession(s string) (Session, error) { + id, err := uuid.Parse(s) + if err != nil { + return Session{}, fmt.Errorf("couldn't parse session ID: %w", err) + } + return Session{id}, nil +} + // InitSessions initializes an SQLite table relating player IDs to sessions. func InitSessions(ctx context.Context, db Execer) error { _, err := db.Exec(ctx, initSessions) diff --git a/serve/proxy.go b/serve/proxy.go deleted file mode 100644 index 6a89f27..0000000 --- a/serve/proxy.go +++ /dev/null @@ -1,53 +0,0 @@ -package serve - -import ( - "context" - "net/http" - "net/netip" - "strings" - - "github.com/google/uuid" - - "git.sunturtle.xyz/studio/shotgun/player" -) - -// WithPlayerID is a middleware that adds a player ID to the request context -// based on the X-Forwarded-For header. If there is no such header, or the -// originator addr otherwise cannot be parsed from it, the request fails with -// a 500 error. -func WithPlayerID(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ff := r.Header.Get("X-Forwarded-For") - addr, err := originator(ff) - if err != nil { - http.Error(w, "missing or invalid X-Forwarded-For header; check server configuration", http.StatusInternalServerError) - return - } - id := player.ID{UUID: uuid.UUID(addr.As16())} - ctx := ctxWith(r.Context(), id) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -// Player returns the player ID set by WithPlayerID in the request context. -func PlayerID(ctx context.Context) player.ID { - return ctxValue[player.ID](ctx) -} - -// originator parses the IP of the client that originated a request from the -// content of its X-Forwarded-For header. -func originator(ff string) (netip.Addr, error) { - ff, _, _ = strings.Cut(ff, ",") - return netip.ParseAddr(ff) -} - -type ctxKey[T any] struct{} - -func ctxValue[T any](ctx context.Context) T { - r, _ := ctx.Value(ctxKey[T]{}).(T) - return r -} - -func ctxWith[T any](ctx context.Context, v T) context.Context { - return context.WithValue(ctx, ctxKey[T]{}, v) -} diff --git a/serve/proxy_test.go b/serve/proxy_test.go deleted file mode 100644 index 29ca935..0000000 --- a/serve/proxy_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package serve - -import "testing" - -func TestOriginator(t *testing.T) { - // We could do plenty of tests here, but the one I really care about is - // that we get an error on an empty string. - _, err := originator("") - if err == nil { - t.Error("originator should have returned an error on an empty string") - } -} diff --git a/serve/session.go b/serve/session.go new file mode 100644 index 0000000..3fa807f --- /dev/null +++ b/serve/session.go @@ -0,0 +1,54 @@ +package serve + +import ( + "context" + "net/http" + "time" + + "git.sunturtle.xyz/studio/shotgun/player" +) + +type ctxKey[T any] struct{} + +func value[T any](ctx context.Context) T { + r, _ := ctx.Value(ctxKey[T]{}).(T) + return r +} + +func with[T any](ctx context.Context, v T) context.Context { + return context.WithValue(ctx, ctxKey[T]{}, v) +} + +const sessionCookie = "__Host-id-v1" + +// WithPlayerID is a middleware that adds a player ID to the request context +// based on the session cookie content. If there is no such cookie, or its +// value is invalid, the request fails with a 403 error. +func WithPlayerID(sessions player.RowQuerier) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := r.Cookie(sessionCookie) + if err != nil { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + id, err := player.ParseSession(c.Value) + if err != nil { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + p, err := player.FromSession(r.Context(), sessions, id, time.Now()) + if err != nil { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + ctx := with(r.Context(), p) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// Player returns the player ID set by WithPlayerID in the request context. +func PlayerID(ctx context.Context) player.ID { + return value[player.ID](ctx) +}