package auth

import (
	"net/http"
	"net/url"
	"strconv"
	"strings"
)

// RequireSession wraps next so that only requests for which m.Validate
// returns a nil error reach it.
//
// When validation fails, clients that accept HTML (see acceptsHTML) receive a
// 303 See Other redirect to /login. The originally requested URI is carried,
// query-escaped, in the "next" parameter. All other clients receive a plain
// 401 Unauthorized response. The validation error itself is never written to
// the response.
func (m *Manager) RequireSession(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := m.Validate(r); err != nil {
			if acceptsHTML(r) {
				// RequestURI can carry its own query string ('?', '&', '='). Escape it
				// so the whole value survives as the single "next" parameter instead
				// of spilling extra parameters into /login's query.
				// NOTE(review): whatever consumes "next" should confirm it is a local
				// path before redirecting to it.
				target := "/login?next=" + url.QueryEscape(r.URL.RequestURI())
				http.Redirect(w, r, target, http.StatusSeeOther)
				return
			}
			http.Error(w, "unauthorized", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// acceptsHTML reports whether the client would accept an HTML response,
// according to its Accept header field(s).
//
// If no Accept value is present, or all present values are empty, the client
// counts as accepting anything (RFC 9110 §12.5.1). Otherwise the most specific
// media range covering text/html decides the answer. The order of precedence
// is text/html, then text/*, then */*. Media types are compared
// case-insensitively, and an explicit q=0 on the deciding range marks HTML as
// not acceptable.
func acceptsHTML(r *http.Request) bool {
	// Repeated Accept field lines form one comma-separated list.
	accept := strings.Join(r.Header.Values("Accept"), ",")
	if strings.Trim(accept, " \t,") == "" {
		return true
	}

	bestSpecificity := 0
	acceptable := false
	for _, mediaRange := range strings.Split(accept, ",") {
		mediaType, params, _ := strings.Cut(mediaRange, ";")
		specificity := htmlMatchSpecificity(trimSpace(mediaType))
		if specificity == 0 || specificity < bestSpecificity {
			continue
		}
		ok := mediaRangeQuality(params) > 0
		if specificity > bestSpecificity {
			// A more specific range overrides whatever a broader one said.
			bestSpecificity, acceptable = specificity, ok
		} else {
			acceptable = acceptable || ok
		}
	}
	return acceptable
}

// htmlMatchSpecificity returns how specifically mediaType matches text/html.
// It returns 3 for text/html, 2 for text/*, 1 for */*, and 0 when mediaType
// does not cover text/html at all.
func htmlMatchSpecificity(mediaType string) int {
	switch {
	case strings.EqualFold(mediaType, "text/html"):
		return 3
	case strings.EqualFold(mediaType, "text/*"):
		return 2
	case mediaType == "*/*":
		return 1
	default:
		return 0
	}
}

// mediaRangeQuality returns the q weight found in params, the text after the
// first ';' of a media range. A missing, malformed or out-of-range q counts as
// 1, so a range is excluded only by an explicit weight of 0.
func mediaRangeQuality(params string) float64 {
	for _, param := range strings.Split(params, ";") {
		name, value, found := strings.Cut(param, "=")
		if !found || !strings.EqualFold(trimSpace(name), "q") {
			continue
		}
		q, err := strconv.ParseFloat(trimSpace(value), 64)
		if err != nil || !(q >= 0 && q <= 1) { // also rejects NaN
			return 1
		}
		return q
	}
	return 1
}

// splitComma splits a comma-separated header value into its elements.
// For each element it drops everything from the first ';' onward (the
// parameters) and trims surrounding spaces and tabs. An empty final element,
// such as one produced by a trailing comma or an empty s, is omitted. Other
// empty elements are kept as "". It returns nil when no elements remain.
func splitComma(s string) []string {
	var out []string
	parts := strings.Split(s, ",")
	for i, part := range parts {
		if i == len(parts)-1 && part == "" {
			break
		}
		value, _, _ := strings.Cut(part, ";")
		out = append(out, trimSpace(value))
	}
	return out
}

// trimSpace removes leading and trailing spaces and tabs from s. Unlike
// strings.TrimSpace, it leaves other whitespace (such as CR or LF) in place.
func trimSpace(s string) string {
	return strings.Trim(s, " \t")
}