Add "last access" to access tokens
This commit is contained in:
		
							parent
							
								
									000bf27c87
								
							
						
					
					
						commit
						e596834096
					
				
					 15 changed files with 276 additions and 145 deletions
				
			
		|  | @ -77,6 +77,7 @@ type Config struct { | |||
| 	AuthStartupQueries                   string | ||||
| 	AuthDefault                          user.Permission | ||||
| 	AuthBcryptCost                       int | ||||
| 	AuthStatsQueueWriterInterval         time.Duration | ||||
| 	AttachmentCacheDir                   string | ||||
| 	AttachmentTotalSizeLimit             int64 | ||||
| 	AttachmentFileSizeLimit              int64 | ||||
|  | @ -145,6 +146,7 @@ func NewConfig() *Config { | |||
| 		AuthStartupQueries:                   "", | ||||
| 		AuthDefault:                          user.NewPermission(true, true), | ||||
| 		AuthBcryptCost:                       user.DefaultUserPasswordBcryptCost, | ||||
| 		AuthStatsQueueWriterInterval:         user.DefaultUserStatsQueueWriterInterval, | ||||
| 		AttachmentCacheDir:                   "", | ||||
| 		AttachmentTotalSizeLimit:             DefaultAttachmentTotalSizeLimit, | ||||
| 		AttachmentFileSizeLimit:              DefaultAttachmentFileSizeLimit, | ||||
|  |  | |||
|  | @ -171,7 +171,7 @@ func New(conf *Config) (*Server, error) { | |||
| 	} | ||||
| 	var userManager *user.Manager | ||||
| 	if conf.AuthFile != "" { | ||||
| 		userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, user.DefaultUserStatsQueueWriterInterval) | ||||
| 		userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, conf.AuthStatsQueueWriterInterval) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
|  | @ -598,7 +598,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes | |||
| 	} | ||||
| 	u := v.User() | ||||
| 	if s.userManager != nil && u != nil && u.Tier != nil { | ||||
| 		s.userManager.EnqueueStats(u.ID, v.Stats()) | ||||
| 		go s.userManager.EnqueueStats(u.ID, v.Stats()) | ||||
| 	} | ||||
| 	s.mu.Lock() | ||||
| 	s.messages++ | ||||
|  | @ -1620,7 +1620,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { | |||
| 		return nil, errHTTPUnauthorized | ||||
| 	} | ||||
| 	if strings.HasPrefix(value, "Bearer") { | ||||
| 		return s.authenticateBearerAuth(value) | ||||
| 		return s.authenticateBearerAuth(r, value) | ||||
| 	} | ||||
| 	return s.authenticateBasicAuth(r, value) | ||||
| } | ||||
|  | @ -1634,9 +1634,18 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use | |||
| 	return s.userManager.Authenticate(username, password) | ||||
| } | ||||
| 
 | ||||
| func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) { | ||||
| func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) { | ||||
| 	token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer")) | ||||
| 	return s.userManager.AuthenticateToken(token) | ||||
| 	u, err := s.userManager.AuthenticateToken(token) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ip := extractIPAddress(r, s.config.BehindProxy) | ||||
| 	go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{ | ||||
| 		LastAccess: time.Now(), | ||||
| 		LastOrigin: ip, | ||||
| 	}) | ||||
| 	return u, nil | ||||
| } | ||||
| 
 | ||||
| func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor { | ||||
|  |  | |||
|  | @ -6,6 +6,7 @@ import ( | |||
| 	"heckel.io/ntfy/user" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"net/http" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | @ -122,10 +123,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis | |||
| 		if len(tokens) > 0 { | ||||
| 			response.Tokens = make([]*apiAccountTokenResponse, 0) | ||||
| 			for _, t := range tokens { | ||||
| 				var lastOrigin string | ||||
| 				if t.LastOrigin != netip.IPv4Unspecified() { | ||||
| 					lastOrigin = t.LastOrigin.String() | ||||
| 				} | ||||
| 				response.Tokens = append(response.Tokens, &apiAccountTokenResponse{ | ||||
| 					Token:   t.Value, | ||||
| 					Label:   t.Label, | ||||
| 					Expires: t.Expires.Unix(), | ||||
| 					Token:      t.Value, | ||||
| 					Label:      t.Label, | ||||
| 					LastAccess: t.LastAccess.Unix(), | ||||
| 					LastOrigin: lastOrigin, | ||||
| 					Expires:    t.Expires.Unix(), | ||||
| 				}) | ||||
| 			} | ||||
| 		} | ||||
|  | @ -192,14 +199,16 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request | |||
| 	if req.Expires != nil { | ||||
| 		expires = time.Unix(*req.Expires, 0) | ||||
| 	} | ||||
| 	token, err := s.userManager.CreateToken(v.User().ID, label, expires) | ||||
| 	token, err := s.userManager.CreateToken(v.User().ID, label, expires, v.IP()) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	response := &apiAccountTokenResponse{ | ||||
| 		Token:   token.Value, | ||||
| 		Label:   token.Label, | ||||
| 		Expires: token.Expires.Unix(), | ||||
| 		Token:      token.Value, | ||||
| 		Label:      token.Label, | ||||
| 		LastAccess: token.LastAccess.Unix(), | ||||
| 		LastOrigin: token.LastOrigin.String(), | ||||
| 		Expires:    token.Expires.Unix(), | ||||
| 	} | ||||
| 	return s.writeJSON(w, response) | ||||
| } | ||||
|  | @ -228,9 +237,11 @@ func (s *Server) handleAccountTokenUpdate(w http.ResponseWriter, r *http.Request | |||
| 		return err | ||||
| 	} | ||||
| 	response := &apiAccountTokenResponse{ | ||||
| 		Token:   token.Value, | ||||
| 		Label:   token.Label, | ||||
| 		Expires: token.Expires.Unix(), | ||||
| 		Token:      token.Value, | ||||
| 		Label:      token.Label, | ||||
| 		LastAccess: token.LastAccess.Unix(), | ||||
| 		LastOrigin: token.LastOrigin.String(), | ||||
| 		Expires:    token.Expires.Unix(), | ||||
| 	} | ||||
| 	return s.writeJSON(w, response) | ||||
| } | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ import ( | |||
| 	"heckel.io/ntfy/user" | ||||
| 	"heckel.io/ntfy/util" | ||||
| 	"io" | ||||
| 	"net/netip" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | @ -28,6 +29,10 @@ func TestAccount_Signup_Success(t *testing.T) { | |||
| 	token, _ := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body)) | ||||
| 	require.NotEmpty(t, token.Token) | ||||
| 	require.True(t, time.Now().Add(71*time.Hour).Unix() < token.Expires) | ||||
| 	require.True(t, strings.HasPrefix(token.Token, "tk_")) | ||||
| 	require.Equal(t, "9.9.9.9", token.LastOrigin) | ||||
| 	require.True(t, token.LastAccess > time.Now().Unix()-1) | ||||
| 	require.True(t, token.LastAccess < time.Now().Unix()+1) | ||||
| 
 | ||||
| 	rr = request(t, s, "GET", "/v1/account", "", map[string]string{ | ||||
| 		"Authorization": util.BearerAuth(token.Token), | ||||
|  | @ -161,7 +166,7 @@ func TestAccount_ChangeSettings(t *testing.T) { | |||
| 
 | ||||
| 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) | ||||
| 	u, _ := s.userManager.User("phil") | ||||
| 	token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0)) | ||||
| 	token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified()) | ||||
| 
 | ||||
| 	rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{ | ||||
| 		"Authorization": util.BasicAuth("phil", "phil"), | ||||
|  | @ -558,7 +563,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | |||
| 	// Subscribe anonymously | ||||
| 	anonCh, userCh := make(chan bool), make(chan bool) | ||||
| 	go func() { | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, nil) | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, nil) // This blocks until it's killed! | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
| 		messages := toMessages(t, rr.Body.String()) | ||||
| 		require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message! | ||||
|  | @ -570,7 +575,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | |||
| 
 | ||||
| 	// Subscribe with user | ||||
| 	go func() { | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ | ||||
| 		rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ // Blocks! | ||||
| 			"Authorization": util.BasicAuth("phil", "mypass"), | ||||
| 		}) | ||||
| 		require.Equal(t, 200, rr.Code) | ||||
|  | @ -584,10 +589,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | |||
| 	}() | ||||
| 
 | ||||
| 	// Publish message (before reservation) | ||||
| 	time.Sleep(time.Second) // Wait for subscribers | ||||
| 	time.Sleep(2 * time.Second) // Wait for subscribers | ||||
| 	rr = request(t, s, "POST", "/mytopic", "message before reservation", nil) | ||||
| 	require.Equal(t, 200, rr.Code) | ||||
| 	time.Sleep(time.Second) // Wait for subscribers to receive message | ||||
| 	time.Sleep(2 * time.Second) // Wait for subscribers to receive message | ||||
| 
 | ||||
| 	// Reserve a topic | ||||
| 	rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ | ||||
|  | @ -596,7 +601,11 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | |||
| 	require.Equal(t, 200, rr.Code) | ||||
| 
 | ||||
| 	// Everyone but phil should be killed | ||||
| 	<-anonCh | ||||
| 	select { | ||||
| 	case <-anonCh: | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Fatal("Waiting for anonymous subscription to be killed failed") | ||||
| 	} | ||||
| 
 | ||||
| 	// Publish a message | ||||
| 	rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{ | ||||
|  | @ -606,62 +615,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { | |||
| 
 | ||||
| 	// Kill user Go routine | ||||
| 	s.topics["mytopic"].CancelSubscribers("<invalid>") | ||||
| 	<-userCh | ||||
| } | ||||
| 
 | ||||
| func TestAccount_Tier_Create(t *testing.T) { | ||||
| 	conf := newTestConfigWithAuthFile(t) | ||||
| 	s := newTestServer(t, conf) | ||||
| 
 | ||||
| 	// Create tier and user | ||||
| 	require.Nil(t, s.userManager.CreateTier(&user.Tier{ | ||||
| 		Code:                     "pro", | ||||
| 		Name:                     "Pro", | ||||
| 		MessageLimit:             123, | ||||
| 		MessageExpiryDuration:    86400 * time.Second, | ||||
| 		EmailLimit:               32, | ||||
| 		ReservationLimit:         2, | ||||
| 		AttachmentFileSizeLimit:  1231231, | ||||
| 		AttachmentTotalSizeLimit: 123123, | ||||
| 		AttachmentExpiryDuration: 10800 * time.Second, | ||||
| 		AttachmentBandwidthLimit: 21474836480, | ||||
| 	})) | ||||
| 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) | ||||
| 	require.Nil(t, s.userManager.ChangeTier("phil", "pro")) | ||||
| 
 | ||||
| 	ti, err := s.userManager.Tier("pro") | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	u, err := s.userManager.User("phil") | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// These are populated by different SQL queries | ||||
| 	require.Equal(t, ti, u.Tier) | ||||
| 
 | ||||
| 	// Fields | ||||
| 	require.True(t, strings.HasPrefix(ti.ID, "ti_")) | ||||
| 	require.Equal(t, "pro", ti.Code) | ||||
| 	require.Equal(t, "Pro", ti.Name) | ||||
| 	require.Equal(t, int64(123), ti.MessageLimit) | ||||
| 	require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration) | ||||
| 	require.Equal(t, int64(32), ti.EmailLimit) | ||||
| 	require.Equal(t, int64(2), ti.ReservationLimit) | ||||
| 	require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) | ||||
| 	require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) | ||||
| 	require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) | ||||
| 	require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) | ||||
| } | ||||
| 
 | ||||
| func TestAccount_Tier_Create_With_ID(t *testing.T) { | ||||
| 	conf := newTestConfigWithAuthFile(t) | ||||
| 	s := newTestServer(t, conf) | ||||
| 
 | ||||
| 	require.Nil(t, s.userManager.CreateTier(&user.Tier{ | ||||
| 		ID:   "ti_123", | ||||
| 		Code: "pro", | ||||
| 	})) | ||||
| 
 | ||||
| 	ti, err := s.userManager.Tier("pro") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Equal(t, "ti_123", ti.ID) | ||||
| 
 | ||||
| 	select { | ||||
| 	case <-userCh: | ||||
| 	case <-time.After(5 * time.Second): | ||||
| 		t.Fatal("Waiting for user subscription to be killed failed") | ||||
| 	} | ||||
| } | ||||
|  |  | |||
|  | @ -258,11 +258,6 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes | |||
| 	c.StripeWebhookKey = "webhook key" | ||||
| 	c.VisitorRequestLimitBurst = 5 | ||||
| 	c.VisitorRequestLimitReplenish = time.Hour | ||||
| 	c.CacheStartupQueries = ` | ||||
| pragma journal_mode = WAL; | ||||
| pragma synchronous = normal; | ||||
| pragma temp_store = memory; | ||||
| ` | ||||
| 	c.CacheBatchSize = 500 | ||||
| 	c.CacheBatchTimeout = time.Second | ||||
| 	s := newTestServer(t, c) | ||||
|  | @ -324,6 +319,18 @@ pragma temp_store = memory; | |||
| 	}) | ||||
| 	require.Equal(t, 429, rr.Code) | ||||
| 
 | ||||
| 	// Verify some "before-stats" | ||||
| 	u, err = s.userManager.User("phil") | ||||
| 	require.Nil(t, err) | ||||
| 	require.Nil(t, u.Tier) | ||||
| 	require.Equal(t, "", u.Billing.StripeCustomerID) | ||||
| 	require.Equal(t, "", u.Billing.StripeSubscriptionID) | ||||
| 	require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus) | ||||
| 	require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix()) | ||||
| 	require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) | ||||
| 	require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users! | ||||
| 	require.Equal(t, int64(0), u.Stats.Emails) | ||||
| 
 | ||||
| 	// Simulate Stripe success return URL call (no user context) | ||||
| 	rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil) | ||||
| 	require.Equal(t, 303, rr.Code) | ||||
|  | @ -337,6 +344,8 @@ pragma temp_store = memory; | |||
| 	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) | ||||
| 	require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) | ||||
| 	require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) | ||||
| 	require.Equal(t, int64(0), u.Stats.Messages) | ||||
| 	require.Equal(t, int64(0), u.Stats.Emails) | ||||
| 
 | ||||
| 	// Now for the fun part: Verify that new rate limits are immediately applied | ||||
| 	// This only tests the request limiter, which kicks in before the message limiter. | ||||
|  |  | |||
|  | @ -892,10 +892,8 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) { | |||
| 	// if the visitor is unknown | ||||
| 
 | ||||
| 	c := newTestConfigWithAuthFile(t) | ||||
| 	c.AuthStatsQueueWriterInterval = 100 * time.Millisecond | ||||
| 	s := newTestServer(t, c) | ||||
| 	var err error | ||||
| 	s.userManager, err = user.NewManager(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, c.AuthBcryptCost, 100*time.Millisecond) | ||||
| 	require.Nil(t, err) | ||||
| 
 | ||||
| 	// Create user, and update it with some message and email stats | ||||
| 	require.Nil(t, s.userManager.CreateTier(&user.Tier{ | ||||
|  |  | |||
|  | @ -247,9 +247,11 @@ type apiAccountTokenUpdateRequest struct { | |||
| } | ||||
| 
 | ||||
| type apiAccountTokenResponse struct { | ||||
| 	Token   string `json:"token"` | ||||
| 	Label   string `json:"label,omitempty"` | ||||
| 	Expires int64  `json:"expires,omitempty"` // Unix timestamp | ||||
| 	Token      string `json:"token"` | ||||
| 	Label      string `json:"label,omitempty"` | ||||
| 	LastAccess int64  `json:"last_access,omitempty"` | ||||
| 	LastOrigin string `json:"last_origin,omitempty"` | ||||
| 	Expires    int64  `json:"expires,omitempty"` // Unix timestamp | ||||
| } | ||||
| 
 | ||||
| type apiAccountTier struct { | ||||
|  |  | |||
|  | @ -131,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana | |||
| 		bandwidthLimiter:    nil, // Set in resetLimiters | ||||
| 		accountLimiter:      nil, // Set in resetLimiters, may be nil | ||||
| 	} | ||||
| 	v.resetLimitersNoLock(messages, emails) | ||||
| 	v.resetLimitersNoLock(messages, emails, false) | ||||
| 	return v | ||||
| } | ||||
| 
 | ||||
|  | @ -254,6 +254,13 @@ func (v *visitor) User() *user.User { | |||
| 	return v.user // May be nil | ||||
| } | ||||
| 
 | ||||
| // IP returns the visitor IP address | ||||
| func (v *visitor) IP() netip.Addr { | ||||
| 	v.mu.Lock() | ||||
| 	defer v.mu.Unlock() | ||||
| 	return v.ip | ||||
| } | ||||
| 
 | ||||
| // Authenticated returns true if a user successfully authenticated | ||||
| func (v *visitor) Authenticated() bool { | ||||
| 	v.mu.Lock() | ||||
|  | @ -268,7 +275,7 @@ func (v *visitor) SetUser(u *user.User) { | |||
| 	shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver | ||||
| 	v.user = u | ||||
| 	if shouldResetLimiters { | ||||
| 		v.resetLimitersNoLock(0, 0) | ||||
| 		v.resetLimitersNoLock(0, 0, true) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
|  | @ -283,7 +290,7 @@ func (v *visitor) MaybeUserID() string { | |||
| 	return "" | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) resetLimitersNoLock(messages, emails int64) { | ||||
| func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) { | ||||
| 	log.Debug("%s Resetting limiters for visitor", v.stringNoLock()) | ||||
| 	limits := v.limitsNoLock() | ||||
| 	v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) | ||||
|  | @ -295,6 +302,13 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64) { | |||
| 	} else { | ||||
| 		v.accountLimiter = nil // Users cannot create accounts when logged in | ||||
| 	} | ||||
| 	/* | ||||
| 		if enqueueUpdate && v.user != nil { | ||||
| 			go v.userManager.EnqueueStats(v.user.ID, &user.Stats{ | ||||
| 				Messages: messages, | ||||
| 				Emails:   emails, | ||||
| 			}) | ||||
| 		}*/ | ||||
| } | ||||
| 
 | ||||
| func (v *visitor) Limits() *visitorLimits { | ||||
|  | @ -361,7 +375,7 @@ func (v *visitor) Info() (*visitorInfo, error) { | |||
| 	if u != nil { | ||||
| 		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID) | ||||
| 	} else { | ||||
| 		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) | ||||
| 		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.IP().String()) | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue