sessions.go (5855B)
1 // Copyright 2012 The Gorilla Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package sessions 6 7 import ( 8 "context" 9 "encoding/gob" 10 "fmt" 11 "net/http" 12 "time" 13 ) 14 15 // Default flashes key. 16 const flashesKey = "_flash" 17 18 // Session -------------------------------------------------------------------- 19 20 // NewSession is called by session stores to create a new session instance. 21 func NewSession(store Store, name string) *Session { 22 return &Session{ 23 Values: make(map[interface{}]interface{}), 24 store: store, 25 name: name, 26 Options: new(Options), 27 } 28 } 29 30 // Session stores the values and optional configuration for a session. 31 type Session struct { 32 // The ID of the session, generated by stores. It should not be used for 33 // user data. 34 ID string 35 // Values contains the user-data for the session. 36 Values map[interface{}]interface{} 37 Options *Options 38 IsNew bool 39 store Store 40 name string 41 } 42 43 // Flashes returns a slice of flash messages from the session. 44 // 45 // A single variadic argument is accepted, and it is optional: it defines 46 // the flash key. If not defined "_flash" is used by default. 47 func (s *Session) Flashes(vars ...string) []interface{} { 48 var flashes []interface{} 49 key := flashesKey 50 if len(vars) > 0 { 51 key = vars[0] 52 } 53 if v, ok := s.Values[key]; ok { 54 // Drop the flashes and return it. 55 delete(s.Values, key) 56 flashes = v.([]interface{}) 57 } 58 return flashes 59 } 60 61 // AddFlash adds a flash message to the session. 62 // 63 // A single variadic argument is accepted, and it is optional: it defines 64 // the flash key. If not defined "_flash" is used by default. 65 func (s *Session) AddFlash(value interface{}, vars ...string) { 66 key := flashesKey 67 if len(vars) > 0 { 68 key = vars[0] 69 } 70 var flashes []interface{} 71 if v, ok := s.Values[key]; ok { 72 flashes = v.([]interface{}) 73 } 74 s.Values[key] = append(flashes, value) 75 } 76 77 // Save is a convenience method to save this session. It is the same as calling 78 // store.Save(request, response, session). You should call Save before writing to 79 // the response or returning from the handler. 80 func (s *Session) Save(r *http.Request, w http.ResponseWriter) error { 81 return s.store.Save(r, w, s) 82 } 83 84 // Name returns the name used to register the session. 85 func (s *Session) Name() string { 86 return s.name 87 } 88 89 // Store returns the session store used to register the session. 90 func (s *Session) Store() Store { 91 return s.store 92 } 93 94 // Registry ------------------------------------------------------------------- 95 96 // sessionInfo stores a session tracked by the registry. 97 type sessionInfo struct { 98 s *Session 99 e error 100 } 101 102 // contextKey is the type used to store the registry in the context. 103 type contextKey int 104 105 // registryKey is the key used to store the registry in the context. 106 const registryKey contextKey = 0 107 108 // GetRegistry returns a registry instance for the current request. 109 func GetRegistry(r *http.Request) *Registry { 110 var ctx = r.Context() 111 registry := ctx.Value(registryKey) 112 if registry != nil { 113 return registry.(*Registry) 114 } 115 newRegistry := &Registry{ 116 request: r, 117 sessions: make(map[string]sessionInfo), 118 } 119 *r = *r.WithContext(context.WithValue(ctx, registryKey, newRegistry)) 120 return newRegistry 121 } 122 123 // Registry stores sessions used during a request. 124 type Registry struct { 125 request *http.Request 126 sessions map[string]sessionInfo 127 } 128 129 // Get registers and returns a session for the given name and session store. 130 // 131 // It returns a new session if there are no sessions registered for the name. 132 func (s *Registry) Get(store Store, name string) (session *Session, err error) { 133 if !isCookieNameValid(name) { 134 return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name) 135 } 136 if info, ok := s.sessions[name]; ok { 137 session, err = info.s, info.e 138 } else { 139 session, err = store.New(s.request, name) 140 session.name = name 141 s.sessions[name] = sessionInfo{s: session, e: err} 142 } 143 session.store = store 144 return 145 } 146 147 // Save saves all sessions registered for the current request. 148 func (s *Registry) Save(w http.ResponseWriter) error { 149 var errMulti MultiError 150 for name, info := range s.sessions { 151 session := info.s 152 if session.store == nil { 153 errMulti = append(errMulti, fmt.Errorf( 154 "sessions: missing store for session %q", name)) 155 } else if err := session.store.Save(s.request, w, session); err != nil { 156 errMulti = append(errMulti, fmt.Errorf( 157 "sessions: error saving session %q -- %v", name, err)) 158 } 159 } 160 if errMulti != nil { 161 return errMulti 162 } 163 return nil 164 } 165 166 // Helpers -------------------------------------------------------------------- 167 168 func init() { 169 gob.Register([]interface{}{}) 170 } 171 172 // Save saves all sessions used during the current request. 173 func Save(r *http.Request, w http.ResponseWriter) error { 174 return GetRegistry(r).Save(w) 175 } 176 177 // NewCookie returns an http.Cookie with the options set. It also sets 178 // the Expires field calculated based on the MaxAge value, for Internet 179 // Explorer compatibility. 180 func NewCookie(name, value string, options *Options) *http.Cookie { 181 cookie := newCookieFromOptions(name, value, options) 182 if options.MaxAge > 0 { 183 d := time.Duration(options.MaxAge) * time.Second 184 cookie.Expires = time.Now().Add(d) 185 } else if options.MaxAge < 0 { 186 // Set it to the past to expire now. 187 cookie.Expires = time.Unix(1, 0) 188 } 189 return cookie 190 } 191 192 // Error ---------------------------------------------------------------------- 193 194 // MultiError stores multiple errors. 195 // 196 // Borrowed from the App Engine SDK. 197 type MultiError []error 198 199 func (m MultiError) Error() string { 200 s, n := "", 0 201 for _, e := range m { 202 if e != nil { 203 if n == 0 { 204 s = e.Error() 205 } 206 n++ 207 } 208 } 209 switch n { 210 case 0: 211 return "(0 errors)" 212 case 1: 213 return s 214 case 2: 215 return s + " (and 1 other error)" 216 } 217 return fmt.Sprintf("%s (and %d other errors)", s, n-1) 218 }