Payment checkout test, rate limit resetting on tier change; failing
parent
236254d907
commit
593e0748a8
|
@ -46,8 +46,8 @@ const (
|
||||||
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
||||||
DefaultVisitorEmailLimitBurst = 16
|
DefaultVisitorEmailLimitBurst = 16
|
||||||
DefaultVisitorEmailLimitReplenish = time.Hour
|
DefaultVisitorEmailLimitReplenish = time.Hour
|
||||||
DefaultVisitorAccountCreateLimitBurst = 3
|
DefaultVisitorAccountCreationLimitBurst = 3
|
||||||
DefaultVisitorAccountCreateLimitReplenish = 24 * time.Hour
|
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
|
||||||
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
|
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
|
||||||
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
|
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
|
||||||
)
|
)
|
||||||
|
@ -107,8 +107,8 @@ type Config struct {
|
||||||
VisitorRequestExemptIPAddrs []netip.Prefix
|
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||||
VisitorEmailLimitBurst int
|
VisitorEmailLimitBurst int
|
||||||
VisitorEmailLimitReplenish time.Duration
|
VisitorEmailLimitReplenish time.Duration
|
||||||
VisitorAccountCreateLimitBurst int
|
VisitorAccountCreationLimitBurst int
|
||||||
VisitorAccountCreateLimitReplenish time.Duration
|
VisitorAccountCreationLimitReplenish time.Duration
|
||||||
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
||||||
BehindProxy bool
|
BehindProxy bool
|
||||||
StripeSecretKey string
|
StripeSecretKey string
|
||||||
|
@ -173,8 +173,8 @@ func NewConfig() *Config {
|
||||||
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||||
VisitorAccountCreateLimitBurst: DefaultVisitorAccountCreateLimitBurst,
|
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
||||||
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
|
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
|
||||||
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
||||||
BehindProxy: false,
|
BehindProxy: false,
|
||||||
StripeSecretKey: "",
|
StripeSecretKey: "",
|
||||||
|
|
|
@ -40,6 +40,8 @@ TODO
|
||||||
|
|
||||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
||||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||||
|
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
|
||||||
|
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
||||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||||
|
@ -50,8 +52,6 @@ TODO
|
||||||
|
|
||||||
Limits & rate limiting:
|
Limits & rate limiting:
|
||||||
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
|
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
|
||||||
when ResetStats() is run, reset messagesLimiter (and others)?
|
|
||||||
Delete visitor when tier is changed to refresh rate limiters
|
|
||||||
|
|
||||||
Make sure account endpoints make sense for admins
|
Make sure account endpoints make sense for admins
|
||||||
|
|
||||||
|
@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
||||||
} else {
|
} else {
|
||||||
v = s.visitorFromIP(ip)
|
v = s.visitorFromIP(ip)
|
||||||
}
|
}
|
||||||
v.mu.Lock()
|
v.SetUser(u) // Update visitor user with latest from database!
|
||||||
v.user = u
|
|
||||||
v.mu.Unlock()
|
|
||||||
return v, err // Always return visitor, even when error occurs!
|
return v, err // Always return visitor, even when error occurs!
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
||||||
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
||||||
return errHTTPConflictUserExists
|
return errHTTPConflictUserExists
|
||||||
}
|
}
|
||||||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
if err := v.AccountCreationAllowed(); err != nil {
|
||||||
return errHTTPTooManyRequestsLimitAccountCreation
|
return errHTTPTooManyRequestsLimitAccountCreation
|
||||||
}
|
}
|
||||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
|
||||||
require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
|
require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_Signup_Rate_Limit(t *testing.T) {
|
||||||
|
conf := newTestConfigWithAuthFile(t)
|
||||||
|
conf.EnableSignup = true
|
||||||
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
|
||||||
|
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
|
||||||
|
}
|
||||||
|
rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestAccount_Get_Anonymous(t *testing.T) {
|
func TestAccount_Get_Anonymous(t *testing.T) {
|
||||||
conf := newTestConfigWithAuthFile(t)
|
conf := newTestConfigWithAuthFile(t)
|
||||||
conf.VisitorRequestLimitReplenish = 86 * time.Second
|
conf.VisitorRequestLimitReplenish = 86 * time.Second
|
||||||
|
@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
||||||
<-userCh
|
<-userCh
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAccount_Tier_Create(t *testing.T) {
|
||||||
|
conf := newTestConfigWithAuthFile(t)
|
||||||
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
|
// Create tier and user
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
MessagesLimit: 123,
|
||||||
|
MessagesExpiryDuration: 86400 * time.Second,
|
||||||
|
EmailsLimit: 32,
|
||||||
|
ReservationsLimit: 2,
|
||||||
|
AttachmentFileSizeLimit: 1231231,
|
||||||
|
AttachmentTotalSizeLimit: 123123,
|
||||||
|
AttachmentExpiryDuration: 10800 * time.Second,
|
||||||
|
AttachmentBandwidthLimit: 21474836480,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
|
ti, err := s.userManager.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// These are populated by different SQL queries
|
||||||
|
require.Equal(t, ti, u.Tier)
|
||||||
|
|
||||||
|
// Fields
|
||||||
|
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||||
|
require.Equal(t, "pro", ti.Code)
|
||||||
|
require.Equal(t, "Pro", ti.Name)
|
||||||
|
require.Equal(t, int64(123), ti.MessagesLimit)
|
||||||
|
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
|
||||||
|
require.Equal(t, int64(32), ti.EmailsLimit)
|
||||||
|
require.Equal(t, int64(2), ti.ReservationsLimit)
|
||||||
|
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||||
|
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||||
|
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||||
|
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||||
|
conf := newTestConfigWithAuthFile(t)
|
||||||
|
s := newTestServer(t, conf)
|
||||||
|
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_123",
|
||||||
|
Code: "pro",
|
||||||
|
}))
|
||||||
|
|
||||||
|
ti, err := s.userManager.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "ti_123", ti.ID)
|
||||||
|
}
|
||||||
|
|
|
@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
||||||
|
|
||||||
// Create tier and user
|
// Create tier and user
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_123",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
StripePriceID: "price_123",
|
StripePriceID: "price_123",
|
||||||
}))
|
}))
|
||||||
|
@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
||||||
|
|
||||||
// Create tier and user
|
// Create tier and user
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_123",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
StripePriceID: "price_123",
|
StripePriceID: "price_123",
|
||||||
}))
|
}))
|
||||||
|
@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
||||||
|
|
||||||
// Create tier and user
|
// Create tier and user
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_123",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
StripePriceID: "price_123",
|
StripePriceID: "price_123",
|
||||||
}))
|
}))
|
||||||
|
@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
||||||
require.Equal(t, 401, rr.Code)
|
require.Equal(t, 401, rr.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
|
||||||
|
// This tests a successful checkout flow (not a paying customer -> paying customer),
|
||||||
|
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
|
||||||
|
|
||||||
|
stripeMock := &testStripeAPI{}
|
||||||
|
defer stripeMock.AssertExpectations(t)
|
||||||
|
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.StripeSecretKey = "secret key"
|
||||||
|
c.StripeWebhookKey = "webhook key"
|
||||||
|
c.VisitorRequestLimitBurst = 10
|
||||||
|
c.VisitorRequestLimitReplenish = time.Hour
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
s.stripe = stripeMock
|
||||||
|
|
||||||
|
// Create a user with a Stripe subscription and 3 reservations
|
||||||
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_123",
|
||||||
|
Code: "starter",
|
||||||
|
StripePriceID: "price_1234",
|
||||||
|
ReservationsLimit: 1,
|
||||||
|
MessagesLimit: 100,
|
||||||
|
MessagesExpiryDuration: time.Hour,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
|
||||||
|
u, err := s.userManager.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Define how the mock should react
|
||||||
|
stripeMock.
|
||||||
|
On("GetSession", "SOMETOKEN").
|
||||||
|
Return(&stripe.CheckoutSession{
|
||||||
|
ClientReferenceID: u.ID, // ntfy user ID
|
||||||
|
Customer: &stripe.Customer{
|
||||||
|
ID: "acct_5555",
|
||||||
|
},
|
||||||
|
Subscription: &stripe.Subscription{
|
||||||
|
ID: "sub_1234",
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
stripeMock.
|
||||||
|
On("GetSubscription", "sub_1234").
|
||||||
|
Return(&stripe.Subscription{
|
||||||
|
ID: "sub_1234",
|
||||||
|
Status: stripe.SubscriptionStatusActive,
|
||||||
|
CurrentPeriodEnd: 123456789,
|
||||||
|
CancelAt: 0,
|
||||||
|
Items: &stripe.SubscriptionItemList{
|
||||||
|
Data: []*stripe.SubscriptionItem{
|
||||||
|
{
|
||||||
|
Price: &stripe.Price{ID: "price_1234"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil)
|
||||||
|
stripeMock.
|
||||||
|
On("UpdateCustomer", mock.Anything).
|
||||||
|
Return(&stripe.Customer{}, nil)
|
||||||
|
|
||||||
|
// Send messages until rate limit of free tier is hit
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
}
|
||||||
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
// Simulate Stripe success return URL call (no user context)
|
||||||
|
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
|
||||||
|
require.Equal(t, 303, rr.Code)
|
||||||
|
|
||||||
|
// Verify that database columns were updated
|
||||||
|
u, err = s.userManager.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
|
||||||
|
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
|
||||||
|
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
|
||||||
|
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
||||||
|
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||||
|
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||||
|
|
||||||
|
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
|
||||||
|
|
||||||
|
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
|
||||||
|
}
|
||||||
|
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
})
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||||
// This tests incoming webhooks from Stripe to update a subscription:
|
// This tests incoming webhooks from Stripe to update a subscription:
|
||||||
// - All Stripe columns are updated in the user table
|
// - All Stripe columns are updated in the user table
|
||||||
|
@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||||
|
|
||||||
// Create a user with a Stripe subscription and 3 reservations
|
// Create a user with a Stripe subscription and 3 reservations
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_1",
|
||||||
Code: "starter",
|
Code: "starter",
|
||||||
StripePriceID: "price_1234", // !
|
StripePriceID: "price_1234", // !
|
||||||
ReservationsLimit: 1, // !
|
ReservationsLimit: 1, // !
|
||||||
|
@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
||||||
AttachmentBandwidthLimit: 1000000,
|
AttachmentBandwidthLimit: 1000000,
|
||||||
}))
|
}))
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
ID: "ti_2",
|
||||||
Code: "pro",
|
Code: "pro",
|
||||||
StripePriceID: "price_1111", // !
|
StripePriceID: "price_1111", // !
|
||||||
ReservationsLimit: 3, // !
|
ReservationsLimit: 3, // !
|
||||||
|
|
|
@ -3,6 +3,7 @@ package server
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -41,7 +42,7 @@ type visitor struct {
|
||||||
emailsLimiter *rate.Limiter // Rate limiter for emails
|
emailsLimiter *rate.Limiter // Rate limiter for emails
|
||||||
subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
|
subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
|
||||||
bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads
|
bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads
|
||||||
accountLimiter *rate.Limiter // Rate limiter for account creation
|
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
|
||||||
firebase time.Time // Next allowed Firebase message
|
firebase time.Time // Next allowed Firebase message
|
||||||
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
|
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
@ -85,26 +86,12 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||||
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
|
|
||||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
|
||||||
var messages, emails int64
|
var messages, emails int64
|
||||||
if user != nil {
|
if user != nil {
|
||||||
messages = user.Stats.Messages
|
messages = user.Stats.Messages
|
||||||
emails = user.Stats.Emails
|
emails = user.Stats.Emails
|
||||||
} else {
|
|
||||||
accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
|
|
||||||
}
|
}
|
||||||
if user != nil && user.Tier != nil {
|
v := &visitor{
|
||||||
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
|
|
||||||
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
|
|
||||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
|
|
||||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
|
||||||
} else {
|
|
||||||
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
|
|
||||||
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
|
|
||||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
|
||||||
}
|
|
||||||
return &visitor{
|
|
||||||
config: conf,
|
config: conf,
|
||||||
messageCache: messageCache,
|
messageCache: messageCache,
|
||||||
userManager: userManager, // May be nil
|
userManager: userManager, // May be nil
|
||||||
|
@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
||||||
user: user,
|
user: user,
|
||||||
messages: messages,
|
messages: messages,
|
||||||
emails: emails,
|
emails: emails,
|
||||||
requestLimiter: requestLimiter,
|
|
||||||
messagesLimiter: messagesLimiter, // May be nil
|
|
||||||
emailsLimiter: emailsLimiter,
|
|
||||||
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
|
||||||
bandwidthLimiter: attachmentBandwidthLimiter,
|
|
||||||
accountLimiter: accountLimiter, // May be nil
|
|
||||||
firebase: time.Unix(0, 0),
|
firebase: time.Unix(0, 0),
|
||||||
seen: time.Now(),
|
seen: time.Now(),
|
||||||
|
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||||
|
requestLimiter: nil, // Set in resetLimiters
|
||||||
|
messagesLimiter: nil, // Set in resetLimiters, may be nil
|
||||||
|
emailsLimiter: nil, // Set in resetLimiters
|
||||||
|
bandwidthLimiter: nil, // Set in resetLimiters
|
||||||
|
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||||
}
|
}
|
||||||
|
v.resetLimiters()
|
||||||
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) String() string {
|
func (v *visitor) String() string {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
|
return v.stringNoLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *visitor) stringNoLock() string {
|
||||||
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
|
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
|
||||||
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
|
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
|
||||||
} else if v.user != nil {
|
} else if v.user != nil {
|
||||||
|
@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (v *visitor) AccountCreationAllowed() error {
|
||||||
|
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||||
|
return errVisitorLimitReached
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (v *visitor) RemoveSubscription() {
|
func (v *visitor) RemoveSubscription() {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
|
@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
|
||||||
func (v *visitor) SetUser(u *user.User) {
|
func (v *visitor) SetUser(u *user.User) {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
|
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||||
v.user = u
|
v.user = u
|
||||||
|
if shouldResetLimiters {
|
||||||
|
v.resetLimiters()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *visitor) resetLimiters() {
|
||||||
|
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
|
||||||
|
var messagesLimiter, bandwidthLimiter util.Limiter
|
||||||
|
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
||||||
|
if v.user != nil && v.user.Tier != nil {
|
||||||
|
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
|
||||||
|
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
|
||||||
|
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
|
||||||
|
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
||||||
|
accountLimiter = nil // A logged-in user cannot create an account
|
||||||
|
} else {
|
||||||
|
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
|
||||||
|
messagesLimiter = nil // Message limit is governed by the requestLimiter
|
||||||
|
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
|
||||||
|
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
||||||
|
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||||
|
}
|
||||||
|
v.requestLimiter = requestLimiter
|
||||||
|
v.messagesLimiter = messagesLimiter
|
||||||
|
v.emailsLimiter = emailsLimiter
|
||||||
|
v.bandwidthLimiter = bandwidthLimiter
|
||||||
|
v.accountLimiter = accountLimiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
||||||
|
|
|
@ -110,26 +110,26 @@ const (
|
||||||
`
|
`
|
||||||
|
|
||||||
selectUserByIDQuery = `
|
selectUserByIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE u.id = ?
|
WHERE u.id = ?
|
||||||
`
|
`
|
||||||
selectUserByNameQuery = `
|
selectUserByNameQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE user = ?
|
WHERE user = ?
|
||||||
`
|
`
|
||||||
selectUserByTokenQuery = `
|
selectUserByTokenQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
JOIN user_token t on u.id = t.user_id
|
JOIN user_token t on u.id = t.user_id
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE t.token = ? AND t.expires >= ?
|
WHERE t.token = ? AND t.expires >= ?
|
||||||
`
|
`
|
||||||
selectUserByStripeCustomerIDQuery = `
|
selectUserByStripeCustomerIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE u.stripe_customer_id = ?
|
WHERE u.stripe_customer_id = ?
|
||||||
|
@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
||||||
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
var id, username, hash, role, prefs, syncTopic string
|
var id, username, hash, role, prefs, syncTopic string
|
||||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
|
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
|
||||||
var messages, emails int64
|
var messages, emails int64
|
||||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return nil, ErrUserNotFound
|
return nil, ErrUserNotFound
|
||||||
}
|
}
|
||||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if err := rows.Err(); err != nil {
|
} else if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||||
if tierCode.Valid {
|
if tierCode.Valid {
|
||||||
// See readTier() when this is changed!
|
// See readTier() when this is changed!
|
||||||
user.Tier = &Tier{
|
user.Tier = &Tier{
|
||||||
|
ID: tierID.String,
|
||||||
Code: tierCode.String,
|
Code: tierCode.String,
|
||||||
Name: tierName.String,
|
Name: tierName.String,
|
||||||
MessagesLimit: messagesLimit.Int64,
|
MessagesLimit: messagesLimit.Int64,
|
||||||
|
@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
|
||||||
|
|
||||||
// CreateTier creates a new tier in the database
|
// CreateTier creates a new tier in the database
|
||||||
func (a *Manager) CreateTier(tier *Tier) error {
|
func (a *Manager) CreateTier(tier *Tier) error {
|
||||||
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
if tier.ID == "" {
|
||||||
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||||
|
}
|
||||||
|
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -23,6 +23,15 @@ type User struct {
|
||||||
Deleted bool
|
Deleted bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
|
||||||
|
// or if the user itself is nil.
|
||||||
|
func (u *User) TierID() string {
|
||||||
|
if u == nil || u.Tier == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.Tier.ID
|
||||||
|
}
|
||||||
|
|
||||||
// Auther is an interface for authentication and authorization
|
// Auther is an interface for authentication and authorization
|
||||||
type Auther interface {
|
type Auther interface {
|
||||||
// Authenticate checks username and password and returns a user if correct. The method
|
// Authenticate checks username and password and returns a user if correct. The method
|
||||||
|
|
Loading…
Reference in New Issue