Payment checkout test, rate limit resetting on tier change; failing
parent
236254d907
commit
593e0748a8
|
@ -46,8 +46,8 @@ const (
|
|||
DefaultVisitorRequestLimitReplenish = 5 * time.Second
|
||||
DefaultVisitorEmailLimitBurst = 16
|
||||
DefaultVisitorEmailLimitReplenish = time.Hour
|
||||
DefaultVisitorAccountCreateLimitBurst = 3
|
||||
DefaultVisitorAccountCreateLimitReplenish = 24 * time.Hour
|
||||
DefaultVisitorAccountCreationLimitBurst = 3
|
||||
DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
|
||||
DefaultVisitorAttachmentTotalSizeLimit = 100 * 1024 * 1024 // 100 MB
|
||||
DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
|
||||
)
|
||||
|
@ -107,8 +107,8 @@ type Config struct {
|
|||
VisitorRequestExemptIPAddrs []netip.Prefix
|
||||
VisitorEmailLimitBurst int
|
||||
VisitorEmailLimitReplenish time.Duration
|
||||
VisitorAccountCreateLimitBurst int
|
||||
VisitorAccountCreateLimitReplenish time.Duration
|
||||
VisitorAccountCreationLimitBurst int
|
||||
VisitorAccountCreationLimitReplenish time.Duration
|
||||
VisitorStatsResetTime time.Time // Time of the day at which to reset visitor stats
|
||||
BehindProxy bool
|
||||
StripeSecretKey string
|
||||
|
@ -173,8 +173,8 @@ func NewConfig() *Config {
|
|||
VisitorRequestExemptIPAddrs: make([]netip.Prefix, 0),
|
||||
VisitorEmailLimitBurst: DefaultVisitorEmailLimitBurst,
|
||||
VisitorEmailLimitReplenish: DefaultVisitorEmailLimitReplenish,
|
||||
VisitorAccountCreateLimitBurst: DefaultVisitorAccountCreateLimitBurst,
|
||||
VisitorAccountCreateLimitReplenish: DefaultVisitorAccountCreateLimitReplenish,
|
||||
VisitorAccountCreationLimitBurst: DefaultVisitorAccountCreationLimitBurst,
|
||||
VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
|
||||
VisitorStatsResetTime: DefaultVisitorStatsResetTime,
|
||||
BehindProxy: false,
|
||||
StripeSecretKey: "",
|
||||
|
|
|
@ -40,6 +40,8 @@ TODO
|
|||
|
||||
- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
|
||||
- HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
|
||||
- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
|
||||
- MEDIUM: Races with v.user (see publishSyncEventAsync test)
|
||||
- MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
|
@ -50,8 +52,6 @@ TODO
|
|||
|
||||
Limits & rate limiting:
|
||||
users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
|
||||
when ResetStats() is run, reset messagesLimiter (and others)?
|
||||
Delete visitor when tier is changed to refresh rate limiters
|
||||
|
||||
Make sure account endpoints make sense for admins
|
||||
|
||||
|
@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
|||
} else {
|
||||
v = s.visitorFromIP(ip)
|
||||
}
|
||||
v.mu.Lock()
|
||||
v.user = u
|
||||
v.mu.Unlock()
|
||||
v.SetUser(u) // Update visitor user with latest from database!
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
|
|||
if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
|
||||
return errHTTPConflictUserExists
|
||||
}
|
||||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||
if err := v.AccountCreationAllowed(); err != nil {
|
||||
return errHTTPTooManyRequestsLimitAccountCreation
|
||||
}
|
||||
if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
|
|||
require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestAccount_Signup_Rate_Limit(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
|
||||
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
|
||||
}
|
||||
rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
|
||||
require.Equal(t, 429, rr.Code)
|
||||
require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
|
||||
}
|
||||
|
||||
func TestAccount_Get_Anonymous(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.VisitorRequestLimitReplenish = 86 * time.Second
|
||||
|
@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
||||
<-userCh
|
||||
}
|
||||
|
||||
func TestAccount_Tier_Create(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessagesLimit: 123,
|
||||
MessagesExpiryDuration: 86400 * time.Second,
|
||||
EmailsLimit: 32,
|
||||
ReservationsLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
AttachmentBandwidthLimit: 21474836480,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
ti, err := s.userManager.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// These are populated by different SQL queries
|
||||
require.Equal(t, ti, u.Tier)
|
||||
|
||||
// Fields
|
||||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||
require.Equal(t, "pro", ti.Code)
|
||||
require.Equal(t, "Pro", ti.Name)
|
||||
require.Equal(t, int64(123), ti.MessagesLimit)
|
||||
require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
|
||||
require.Equal(t, int64(32), ti.EmailsLimit)
|
||||
require.Equal(t, int64(2), ti.ReservationsLimit)
|
||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
||||
}
|
||||
|
||||
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
}))
|
||||
|
||||
ti, err := s.userManager.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "ti_123", ti.ID)
|
||||
}
|
||||
|
|
|
@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
|
@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
|
@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_123",
|
||||
}))
|
||||
|
@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
|
|||
require.Equal(t, 401, rr.Code)
|
||||
}
|
||||
|
||||
func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
|
||||
// This tests a successful checkout flow (not a paying customer -> paying customer),
|
||||
// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
|
||||
|
||||
stripeMock := &testStripeAPI{}
|
||||
defer stripeMock.AssertExpectations(t)
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.StripeSecretKey = "secret key"
|
||||
c.StripeWebhookKey = "webhook key"
|
||||
c.VisitorRequestLimitBurst = 10
|
||||
c.VisitorRequestLimitReplenish = time.Hour
|
||||
s := newTestServer(t, c)
|
||||
s.stripe = stripeMock
|
||||
|
||||
// Create a user with a Stripe subscription and 3 reservations
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234",
|
||||
ReservationsLimit: 1,
|
||||
MessagesLimit: 100,
|
||||
MessagesExpiryDuration: time.Hour,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// Define how the mock should react
|
||||
stripeMock.
|
||||
On("GetSession", "SOMETOKEN").
|
||||
Return(&stripe.CheckoutSession{
|
||||
ClientReferenceID: u.ID, // ntfy user ID
|
||||
Customer: &stripe.Customer{
|
||||
ID: "acct_5555",
|
||||
},
|
||||
Subscription: &stripe.Subscription{
|
||||
ID: "sub_1234",
|
||||
},
|
||||
}, nil)
|
||||
stripeMock.
|
||||
On("GetSubscription", "sub_1234").
|
||||
Return(&stripe.Subscription{
|
||||
ID: "sub_1234",
|
||||
Status: stripe.SubscriptionStatusActive,
|
||||
CurrentPeriodEnd: 123456789,
|
||||
CancelAt: 0,
|
||||
Items: &stripe.SubscriptionItemList{
|
||||
Data: []*stripe.SubscriptionItem{
|
||||
{
|
||||
Price: &stripe.Price{ID: "price_1234"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, nil)
|
||||
stripeMock.
|
||||
On("UpdateCustomer", mock.Anything).
|
||||
Return(&stripe.Customer{}, nil)
|
||||
|
||||
// Send messages until rate limit of free tier is hit
|
||||
for i := 0; i < 10; i++ {
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
}
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
// Simulate Stripe success return URL call (no user context)
|
||||
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
|
||||
require.Equal(t, 303, rr.Code)
|
||||
|
||||
// Verify that database columns were updated
|
||||
u, err = s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "starter", u.Tier.Code) // Not "pro"
|
||||
require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
||||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
|
||||
// FIXME FIXME This test is broken, because the rate limit logic is unclear!
|
||||
|
||||
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||
for i := 0; i < 100; i++ {
|
||||
rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
|
||||
}
|
||||
rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
}
|
||||
|
||||
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
|
||||
// This tests incoming webhooks from Stripe to update a subscription:
|
||||
// - All Stripe columns are updated in the user table
|
||||
|
@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
|
||||
// Create a user with a Stripe subscription and 3 reservations
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_1",
|
||||
Code: "starter",
|
||||
StripePriceID: "price_1234", // !
|
||||
ReservationsLimit: 1, // !
|
||||
|
@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
|
|||
AttachmentBandwidthLimit: 1000000,
|
||||
}))
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_2",
|
||||
Code: "pro",
|
||||
StripePriceID: "price_1111", // !
|
||||
ReservationsLimit: 3, // !
|
||||
|
|
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
@ -41,7 +42,7 @@ type visitor struct {
|
|||
emailsLimiter *rate.Limiter // Rate limiter for emails
|
||||
subscriptionLimiter util.Limiter // Fixed limiter for active subscriptions (ongoing connections)
|
||||
bandwidthLimiter util.Limiter // Limiter for attachment bandwidth downloads
|
||||
accountLimiter *rate.Limiter // Rate limiter for account creation
|
||||
accountLimiter *rate.Limiter // Rate limiter for account creation, may be nil
|
||||
firebase time.Time // Next allowed Firebase message
|
||||
seen time.Time // Last seen time of this visitor (needed for removal of stale visitors)
|
||||
mu sync.Mutex
|
||||
|
@ -85,26 +86,12 @@ const (
|
|||
)
|
||||
|
||||
func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
|
||||
var messagesLimiter, attachmentBandwidthLimiter util.Limiter
|
||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
||||
var messages, emails int64
|
||||
if user != nil {
|
||||
messages = user.Stats.Messages
|
||||
emails = user.Stats.Emails
|
||||
} else {
|
||||
accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
|
||||
}
|
||||
if user != nil && user.Tier != nil {
|
||||
requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
|
||||
messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
|
||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
|
||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
||||
} else {
|
||||
requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
|
||||
emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
|
||||
attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
||||
}
|
||||
return &visitor{
|
||||
v := &visitor{
|
||||
config: conf,
|
||||
messageCache: messageCache,
|
||||
userManager: userManager, // May be nil
|
||||
|
@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
user: user,
|
||||
messages: messages,
|
||||
emails: emails,
|
||||
requestLimiter: requestLimiter,
|
||||
messagesLimiter: messagesLimiter, // May be nil
|
||||
emailsLimiter: emailsLimiter,
|
||||
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
bandwidthLimiter: attachmentBandwidthLimiter,
|
||||
accountLimiter: accountLimiter, // May be nil
|
||||
firebase: time.Unix(0, 0),
|
||||
seen: time.Now(),
|
||||
subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
|
||||
requestLimiter: nil, // Set in resetLimiters
|
||||
messagesLimiter: nil, // Set in resetLimiters, may be nil
|
||||
emailsLimiter: nil, // Set in resetLimiters
|
||||
bandwidthLimiter: nil, // Set in resetLimiters
|
||||
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||
}
|
||||
v.resetLimiters()
|
||||
return v
|
||||
}
|
||||
|
||||
func (v *visitor) String() string {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
return v.stringNoLock()
|
||||
}
|
||||
|
||||
func (v *visitor) stringNoLock() string {
|
||||
if v.user != nil && v.user.Billing.StripeCustomerID != "" {
|
||||
return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
|
||||
} else if v.user != nil {
|
||||
|
@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (v *visitor) AccountCreationAllowed() error {
|
||||
if v.accountLimiter != nil && !v.accountLimiter.Allow() {
|
||||
return errVisitorLimitReached
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *visitor) RemoveSubscription() {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
|
|||
func (v *visitor) SetUser(u *user.User) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||
v.user = u
|
||||
if shouldResetLimiters {
|
||||
v.resetLimiters()
|
||||
}
|
||||
}
|
||||
|
||||
func (v *visitor) resetLimiters() {
|
||||
log.Info("%s Resetting limiters for visitor", v.stringNoLock())
|
||||
var messagesLimiter, bandwidthLimiter util.Limiter
|
||||
var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
|
||||
if v.user != nil && v.user.Tier != nil {
|
||||
requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
|
||||
messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
|
||||
emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
|
||||
bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
|
||||
accountLimiter = nil // A logged-in user cannot create an account
|
||||
} else {
|
||||
requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
|
||||
messagesLimiter = nil // Message limit is governed by the requestLimiter
|
||||
emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
|
||||
bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
|
||||
accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
|
||||
}
|
||||
v.requestLimiter = requestLimiter
|
||||
v.messagesLimiter = messagesLimiter
|
||||
v.emailsLimiter = emailsLimiter
|
||||
v.bandwidthLimiter = bandwidthLimiter
|
||||
v.accountLimiter = accountLimiter
|
||||
}
|
||||
|
||||
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
||||
|
|
|
@ -110,26 +110,26 @@ const (
|
|||
`
|
||||
|
||||
selectUserByIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.id = ?
|
||||
`
|
||||
selectUserByNameQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE user = ?
|
||||
`
|
||||
selectUserByTokenQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
JOIN user_token t on u.id = t.user_id
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE t.token = ? AND t.expires >= ?
|
||||
`
|
||||
selectUserByStripeCustomerIDQuery = `
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||
FROM user u
|
||||
LEFT JOIN tier t on t.id = u.tier_id
|
||||
WHERE u.stripe_customer_id = ?
|
||||
|
@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
|
|||
func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var id, username, hash, role, prefs, syncTopic string
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
|
||||
var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
|
||||
var messages, emails int64
|
||||
var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
|
||||
if !rows.Next() {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
||||
if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
|
|||
if tierCode.Valid {
|
||||
// See readTier() when this is changed!
|
||||
user.Tier = &Tier{
|
||||
ID: tierID.String,
|
||||
Code: tierCode.String,
|
||||
Name: tierName.String,
|
||||
MessagesLimit: messagesLimit.Int64,
|
||||
|
@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
|
|||
|
||||
// CreateTier creates a new tier in the database
|
||||
func (a *Manager) CreateTier(tier *Tier) error {
|
||||
tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||
if tier.ID == "" {
|
||||
tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
|
||||
}
|
||||
if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
|
@ -23,6 +23,15 @@ type User struct {
|
|||
Deleted bool
|
||||
}
|
||||
|
||||
// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
|
||||
// or if the user itself is nil.
|
||||
func (u *User) TierID() string {
|
||||
if u == nil || u.Tier == nil {
|
||||
return ""
|
||||
}
|
||||
return u.Tier.ID
|
||||
}
|
||||
|
||||
// Auther is an interface for authentication and authorization
|
||||
type Auther interface {
|
||||
// Authenticate checks username and password and returns a user if correct. The method
|
||||
|
|
Loading…
Reference in New Issue