diff --git a/serve/proxy.go b/serve/proxy.go new file mode 100644 index 0000000..f478689 --- /dev/null +++ b/serve/proxy.go @@ -0,0 +1,51 @@ +package serve + +import ( + "context" + "net/http" + "net/netip" + "strings" + + "git.sunturtle.xyz/studio/shotgun/player" +) + +// WithPlayerID is a middleware that adds a player ID to the request context +// based on the X-Forwarded-For header. If there is no such header, or the +// originator addr otherwise cannot be parsed from it, the request fails with +// a 500 error. +func WithPlayerID(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ff := r.Header.Get("X-Forwarded-For") + addr, err := originator(ff) + if err != nil { + http.Error(w, "missing or invalid X-Forwarded-For header; check server configuration", http.StatusInternalServerError) + return + } + id := player.ID(addr.As16()) + ctx := ctxWith(r.Context(), id) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// Player returns the player ID set by WithPlayerID in the request context. +func PlayerID(ctx context.Context) player.ID { + return ctxValue[player.ID](ctx) +} + +// originator parses the IP of the client that originated a request from the +// content of its X-Forwarded-For header. +func originator(ff string) (netip.Addr, error) { + ff, _, _ = strings.Cut(ff, ",") + return netip.ParseAddr(ff) +} + +type ctxKey[T any] struct{} + +func ctxValue[T any](ctx context.Context) T { + r, _ := ctx.Value(ctxKey[T]{}).(T) + return r +} + +func ctxWith[T any](ctx context.Context, v T) context.Context { + return context.WithValue(ctx, ctxKey[T]{}, v) +} diff --git a/serve/proxy_test.go b/serve/proxy_test.go new file mode 100644 index 0000000..29ca935 --- /dev/null +++ b/serve/proxy_test.go @@ -0,0 +1,12 @@ +package serve + +import "testing" + +func TestOriginator(t *testing.T) { + // We could do plenty of tests here, but the one I really care about is + // that we get an error on an empty string. + _, err := originator("") + if err == nil { + t.Error("originator should have returned an error on an empty string") + } +}